feat: add download target type for fetching files from URLs
This commit is contained in:
parent
a4600df4d5
commit
c1ad6e6e3c
14 changed files with 296 additions and 74 deletions
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import override
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -16,7 +16,7 @@ from hokusai.builder import (
|
|||
_resolve_prompt, # pyright: ignore[reportPrivateUsage]
|
||||
run_build,
|
||||
)
|
||||
from hokusai.config import ProjectConfig, TargetConfig
|
||||
from hokusai.config import GenerateTargetConfig, ProjectConfig
|
||||
from hokusai.providers import Provider
|
||||
from hokusai.providers.models import Capability, ModelInfo
|
||||
from hokusai.state import load_state
|
||||
|
|
@ -57,7 +57,7 @@ class FakeTextProvider(Provider):
|
|||
async def generate(
|
||||
self,
|
||||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
target_config: GenerateTargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
|
|
@ -78,7 +78,7 @@ class FakeImageProvider(Provider):
|
|||
async def generate(
|
||||
self,
|
||||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
target_config: GenerateTargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
|
|
@ -99,7 +99,7 @@ class FailingTextProvider(Provider):
|
|||
async def generate(
|
||||
self,
|
||||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
target_config: GenerateTargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
|
|
@ -295,7 +295,7 @@ class TestRunBuild:
|
|||
|
||||
async def selective_generate(
|
||||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
target_config: GenerateTargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
|
|
@ -426,3 +426,145 @@ class TestRunBuild:
|
|||
|
||||
assert set(result.built) == {"left.md", "right.md", "merge.txt"}
|
||||
assert result.failed == {}
|
||||
|
||||
|
||||
class TestDownloadTarget:
|
||||
"""Tests for download-type targets that fetch files from URLs."""
|
||||
|
||||
async def test_download_target_fetches_url(
|
||||
self, project_dir: Path, write_config: WriteConfig
|
||||
) -> None:
|
||||
config = write_config(
|
||||
{
|
||||
"targets": {
|
||||
"fish.png": {"download": "https://example.com/fish.png"},
|
||||
}
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"fake image bytes"
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with (
|
||||
patch("hokusai.builder._create_providers", return_value=_fake_providers()),
|
||||
patch("hokusai.builder.httpx.AsyncClient") as mock_client_cls,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
result = await run_build(config, project_dir, _PROJECT)
|
||||
|
||||
assert result.built == ["fish.png"]
|
||||
assert (project_dir / "fish.png").read_bytes() == b"fake image bytes"
|
||||
mock_client.get.assert_called_once_with( # pyright: ignore[reportAny]
|
||||
"https://example.com/fish.png", follow_redirects=True
|
||||
)
|
||||
|
||||
async def test_download_target_incremental_skip(
|
||||
self, project_dir: Path, write_config: WriteConfig
|
||||
) -> None:
|
||||
config = write_config(
|
||||
{
|
||||
"targets": {
|
||||
"fish.png": {"download": "https://example.com/fish.png"},
|
||||
}
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"fake image bytes"
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with (
|
||||
patch("hokusai.builder._create_providers", return_value=_fake_providers()),
|
||||
patch("hokusai.builder.httpx.AsyncClient") as mock_client_cls,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
r1 = await run_build(config, project_dir, _PROJECT)
|
||||
assert r1.built == ["fish.png"]
|
||||
|
||||
r2 = await run_build(config, project_dir, _PROJECT)
|
||||
assert r2.skipped == ["fish.png"]
|
||||
assert r2.built == []
|
||||
|
||||
async def test_download_target_rebuild_on_url_change(
|
||||
self, project_dir: Path, write_config: WriteConfig
|
||||
) -> None:
|
||||
config1 = write_config(
|
||||
{
|
||||
"targets": {
|
||||
"fish.png": {"download": "https://example.com/fish-v1.png"},
|
||||
}
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"v1 bytes"
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with (
|
||||
patch("hokusai.builder._create_providers", return_value=_fake_providers()),
|
||||
patch("hokusai.builder.httpx.AsyncClient") as mock_client_cls,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
r1 = await run_build(config1, project_dir, _PROJECT)
|
||||
assert r1.built == ["fish.png"]
|
||||
|
||||
config2 = write_config(
|
||||
{
|
||||
"targets": {
|
||||
"fish.png": {"download": "https://example.com/fish-v2.png"},
|
||||
}
|
||||
}
|
||||
)
|
||||
mock_response.content = b"v2 bytes"
|
||||
|
||||
r2 = await run_build(config2, project_dir, _PROJECT)
|
||||
assert r2.built == ["fish.png"]
|
||||
assert (project_dir / "fish.png").read_bytes() == b"v2 bytes"
|
||||
|
||||
async def test_download_target_as_dependency(
|
||||
self, project_dir: Path, write_config: WriteConfig
|
||||
) -> None:
|
||||
config = write_config(
|
||||
{
|
||||
"targets": {
|
||||
"fish.png": {"download": "https://example.com/fish.png"},
|
||||
"description.txt": {
|
||||
"prompt": "Describe the fish",
|
||||
"inputs": ["fish.png"],
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"fake fish image"
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with (
|
||||
patch("hokusai.builder._create_providers", return_value=_fake_providers()),
|
||||
patch("hokusai.builder.httpx.AsyncClient") as mock_client_cls,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
result = await run_build(config, project_dir, _PROJECT)
|
||||
|
||||
assert "fish.png" in result.built
|
||||
assert "description.txt" in result.built
|
||||
assert (project_dir / "fish.png").read_bytes() == b"fake fish image"
|
||||
assert (project_dir / "description.txt").exists()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue