From c1ad6e6e3cbc6002bf343db6f58003d620d54478 Mon Sep 17 00:00:00 2001 From: Konstantin Fickel Date: Fri, 20 Feb 2026 21:02:15 +0100 Subject: [PATCH] feat: add download target type for fetching files from URLs --- hokusai/builder.py | 57 ++++++++++- hokusai/cli.py | 8 +- hokusai/config.py | 28 +++++- hokusai/graph.py | 5 +- hokusai/providers/__init__.py | 4 +- hokusai/providers/blackforest.py | 4 +- hokusai/providers/mistral.py | 4 +- hokusai/providers/openai_image.py | 4 +- hokusai/providers/openai_text.py | 4 +- hokusai/resolve.py | 6 +- tests/test_builder.py | 154 ++++++++++++++++++++++++++++-- tests/test_config.py | 8 +- tests/test_providers.py | 40 ++++---- tests/test_resolve.py | 44 +++++---- 14 files changed, 296 insertions(+), 74 deletions(-) diff --git a/hokusai/builder.py b/hokusai/builder.py index 7e3a22d..a3d4653 100644 --- a/hokusai/builder.py +++ b/hokusai/builder.py @@ -9,7 +9,9 @@ from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path -from hokusai.config import ProjectConfig +import httpx + +from hokusai.config import DownloadTargetConfig, GenerateTargetConfig, ProjectConfig from hokusai.graph import build_graph, get_build_order, get_subgraph_for_target from hokusai.providers import Provider from hokusai.providers.blackforest import BlackForestProvider @@ -73,6 +75,8 @@ def _collect_dep_files( ) -> list[Path]: """Collect all dependency file paths for a target.""" target_cfg = config.targets[target_name] + if not isinstance(target_cfg, GenerateTargetConfig): + return [] deps: list[str] = list(target_cfg.inputs) deps.extend(target_cfg.reference_images) deps.extend(target_cfg.control_images) @@ -82,6 +86,8 @@ def _collect_dep_files( def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, object]: """Collect extra parameters that affect rebuild decisions.""" target_cfg = config.targets[target_name] + if not isinstance(target_cfg, GenerateTargetConfig): + return {} params: dict[str, object] = {} if target_cfg.width is not None: params["width"] = target_cfg.width @@ -97,6 +103,8 @@ def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]: """Collect all dependency names (inputs + reference_images + control_images).""" target_cfg = config.targets[target_name] + if not isinstance(target_cfg, GenerateTargetConfig): + return [] deps: list[str] = list(target_cfg.inputs) deps.extend(target_cfg.reference_images) deps.extend(target_cfg.control_images) @@ -128,6 +136,18 @@ def _build_provider_index(providers: list[Provider]) -> dict[str, Provider]: return index +async def _download_target( + target_name: str, + target_cfg: DownloadTargetConfig, + project_dir: Path, +) -> None: + """Download a file from a URL and write it to *project_dir / target_name*.""" + async with httpx.AsyncClient(headers={"User-Agent": "hokusai"}) as client: + response = await client.get(target_cfg.download, follow_redirects=True) + _ = response.raise_for_status() + _ = (project_dir / target_name).write_bytes(response.content) + + async def _build_single_target( target_name: str, config: ProjectConfig, @@ -136,6 +156,11 @@ async def _build_single_target( ) -> None: """Build a single target by dispatching to the appropriate provider.""" target_cfg = config.targets[target_name] + + if isinstance(target_cfg, DownloadTargetConfig): + await _download_target(target_name, target_cfg, project_dir) + return + model_info = resolve_model(target_name, target_cfg, config.defaults) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) @@ -227,6 +252,9 @@ def _should_skip_failed_dep( return any(d in result.failed for d in _collect_all_deps(target_name, config)) +_DOWNLOAD_MODEL_SENTINEL = "__download__" + + def _is_dirty( target_name: str, config: ProjectConfig, @@ -235,6 +263,18 @@ def _is_dirty( ) -> bool: """Check if a target needs rebuilding.""" target_cfg = config.targets[target_name] + + if isinstance(target_cfg, DownloadTargetConfig): + return is_target_dirty( + target_name, + resolved_prompt=target_cfg.download, + model=_DOWNLOAD_MODEL_SENTINEL, + dep_files=[], + extra_params={}, + state=state, + project_dir=project_dir, + ) + model_info = resolve_model(target_name, target_cfg, config.defaults) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) dep_files = _collect_dep_files(target_name, config, project_dir) @@ -260,6 +300,8 @@ def _has_provider( ) -> bool: """Check that the required provider is available; record failure if not.""" target_cfg = config.targets[target_name] + if isinstance(target_cfg, DownloadTargetConfig): + return True model_info = resolve_model(target_name, target_cfg, config.defaults) if model_info.name not in provider_index: msg = f"No provider available for model '{model_info.name}' (provider: {model_info.provider}) — check API key environment variables" @@ -302,15 +344,22 @@ def _process_outcomes( on_progress(BuildEvent.TARGET_FAILED, name, str(error)) else: target_cfg = config.targets[name] - model_info = resolve_model(name, target_cfg, config.defaults) - resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) + + if isinstance(target_cfg, DownloadTargetConfig): + resolved_prompt = target_cfg.download + model_name = _DOWNLOAD_MODEL_SENTINEL + else: + model_info = resolve_model(name, target_cfg, config.defaults) + resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) + model_name = model_info.name + dep_files = _collect_dep_files(name, config, project_dir) extra = _collect_extra_params(name, config) record_target_state( name, resolved_prompt=resolved_prompt, - model=model_info.name, + model=model_name, dep_files=dep_files, extra_params=extra, state=state, diff --git a/hokusai/cli.py b/hokusai/cli.py index a449caa..95462f6 100644 --- a/hokusai/cli.py +++ b/hokusai/cli.py @@ -257,18 +257,12 @@ def init() -> None: content = f"""\ # {name} - hokusai project -defaults: - image_model: flux-2-pro - targets: great_wave.png: prompt: >- A recreation of Hokusai's "The Great Wave off Kanagawa", but instead of boats and people, paint brushes, canvases, and framed paintings are - swimming and tumbling in the towering wave. Oil paint tubes burst open - and trail ribbons of colour through the spray. The iconic Mount Fuji - sits serenely in the background. Ukiyo-e woodblock print style with - vivid modern pigment colours. + swimming and tumbling in the towering wave. """ _ = dest.write_text(content) click.echo(click.style(" created ", fg="green") + click.style(filename, bold=True)) diff --git a/hokusai/config.py b/hokusai/config.py index 76c293f..ac397d8 100644 --- a/hokusai/config.py +++ b/hokusai/config.py @@ -4,10 +4,10 @@ from __future__ import annotations import enum from pathlib import Path -from typing import Self +from typing import Annotated, Self import yaml -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, Discriminator, Tag, model_validator from hokusai.providers.models import Capability @@ -29,8 +29,8 @@ class Defaults(BaseModel): image_model: str = "flux-2-pro" -class TargetConfig(BaseModel): - """Configuration for a single build target.""" +class GenerateTargetConfig(BaseModel): + """Configuration for a target that generates an artifact via an AI provider.""" prompt: str model: str | None = None @@ -41,6 +41,26 @@ class TargetConfig(BaseModel): height: int | None = None +class DownloadTargetConfig(BaseModel): + """Configuration for a target that downloads a file from a URL.""" + + download: str + + +def _target_discriminator(raw: object) -> str: + """Discriminate between generate and download target configs.""" + if isinstance(raw, dict) and "download" in raw: + return "download" + return "generate" + + +TargetConfig = Annotated[ + Annotated[GenerateTargetConfig, Tag("generate")] + | Annotated[DownloadTargetConfig, Tag("download")], + Discriminator(_target_discriminator), +] + + class ProjectConfig(BaseModel): """Top-level configuration parsed from ``.hokusai.yaml``.""" diff --git a/hokusai/graph.py b/hokusai/graph.py index 3d83a29..1194049 100644 --- a/hokusai/graph.py +++ b/hokusai/graph.py @@ -6,7 +6,7 @@ from pathlib import Path import networkx as nx -from hokusai.config import ProjectConfig +from hokusai.config import GenerateTargetConfig, ProjectConfig def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]: @@ -25,6 +25,9 @@ def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]: for target_name, target_cfg in config.targets.items(): graph.add_node(target_name) + if not isinstance(target_cfg, GenerateTargetConfig): + continue + deps: list[str] = list(target_cfg.inputs) deps.extend(target_cfg.reference_images) deps.extend(target_cfg.control_images) diff --git a/hokusai/providers/__init__.py b/hokusai/providers/__init__.py index b91dc3b..20390a5 100644 --- a/hokusai/providers/__init__.py +++ b/hokusai/providers/__init__.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING from hokusai.providers.models import ModelInfo if TYPE_CHECKING: - from hokusai.config import TargetConfig + from hokusai.config import GenerateTargetConfig class Provider(abc.ABC): @@ -24,7 +24,7 @@ class Provider(abc.ABC): async def generate( self, target_name: str, - target_config: TargetConfig, + target_config: GenerateTargetConfig, resolved_prompt: str, resolved_model: ModelInfo, project_dir: Path, diff --git a/hokusai/providers/blackforest.py b/hokusai/providers/blackforest.py index d139964..ba7d531 100644 --- a/hokusai/providers/blackforest.py +++ b/hokusai/providers/blackforest.py @@ -8,7 +8,7 @@ from typing import override import httpx -from hokusai.config import TargetConfig +from hokusai.config import GenerateTargetConfig from hokusai.providers import Provider from hokusai.providers.bfl import BFLClient from hokusai.providers.models import Capability, ModelInfo @@ -123,7 +123,7 @@ class BlackForestProvider(Provider): async def generate( self, target_name: str, - target_config: TargetConfig, + target_config: GenerateTargetConfig, resolved_prompt: str, resolved_model: ModelInfo, project_dir: Path, diff --git a/hokusai/providers/mistral.py b/hokusai/providers/mistral.py index 03a0de8..56afb37 100644 --- a/hokusai/providers/mistral.py +++ b/hokusai/providers/mistral.py @@ -9,7 +9,7 @@ from typing import override from mistralai import Mistral, models -from hokusai.config import IMAGE_EXTENSIONS, TargetConfig +from hokusai.config import IMAGE_EXTENSIONS, GenerateTargetConfig from hokusai.providers import Provider from hokusai.providers.models import Capability, ModelInfo @@ -63,7 +63,7 @@ class MistralProvider(Provider): async def generate( self, target_name: str, - target_config: TargetConfig, + target_config: GenerateTargetConfig, resolved_prompt: str, resolved_model: ModelInfo, project_dir: Path, diff --git a/hokusai/providers/openai_image.py b/hokusai/providers/openai_image.py index 398b64e..976096e 100644 --- a/hokusai/providers/openai_image.py +++ b/hokusai/providers/openai_image.py @@ -10,7 +10,7 @@ import httpx from openai import AsyncOpenAI from openai.types.images_response import ImagesResponse -from hokusai.config import TargetConfig +from hokusai.config import GenerateTargetConfig from hokusai.providers import Provider from hokusai.providers.models import Capability, ModelInfo @@ -109,7 +109,7 @@ class OpenAIImageProvider(Provider): async def generate( self, target_name: str, - target_config: TargetConfig, + target_config: GenerateTargetConfig, resolved_prompt: str, resolved_model: ModelInfo, project_dir: Path, diff --git a/hokusai/providers/openai_text.py b/hokusai/providers/openai_text.py index d205aa2..2e4c733 100644 --- a/hokusai/providers/openai_text.py +++ b/hokusai/providers/openai_text.py @@ -15,7 +15,7 @@ from openai.types.chat import ( ChatCompletionUserMessageParam, ) -from hokusai.config import IMAGE_EXTENSIONS, TargetConfig +from hokusai.config import IMAGE_EXTENSIONS, GenerateTargetConfig from hokusai.providers import Provider from hokusai.providers.models import Capability, ModelInfo @@ -120,7 +120,7 @@ class OpenAITextProvider(Provider): async def generate( self, target_name: str, - target_config: TargetConfig, + target_config: GenerateTargetConfig, resolved_prompt: str, resolved_model: ModelInfo, project_dir: Path, diff --git a/hokusai/resolve.py b/hokusai/resolve.py index 63ee2cc..8a9acd3 100644 --- a/hokusai/resolve.py +++ b/hokusai/resolve.py @@ -6,7 +6,7 @@ from hokusai.config import ( IMAGE_EXTENSIONS, TEXT_EXTENSIONS, Defaults, - TargetConfig, + GenerateTargetConfig, TargetType, target_type_from_capabilities, ) @@ -14,7 +14,7 @@ from hokusai.providers.models import Capability, ModelInfo def infer_required_capabilities( - target_name: str, target: TargetConfig + target_name: str, target: GenerateTargetConfig ) -> frozenset[Capability]: """Infer the capabilities a model must have based on filename and config. @@ -42,7 +42,7 @@ def infer_required_capabilities( def resolve_model( - target_name: str, target: TargetConfig, defaults: Defaults + target_name: str, target: GenerateTargetConfig, defaults: Defaults ) -> ModelInfo: """Return the effective model for a target, validated against required capabilities. diff --git a/tests/test_builder.py b/tests/test_builder.py index 70031a9..2db3517 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -5,7 +5,7 @@ from __future__ import annotations from collections.abc import Callable from pathlib import Path from typing import override -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -16,7 +16,7 @@ from hokusai.builder import ( _resolve_prompt, # pyright: ignore[reportPrivateUsage] run_build, ) -from hokusai.config import ProjectConfig, TargetConfig +from hokusai.config import GenerateTargetConfig, ProjectConfig from hokusai.providers import Provider from hokusai.providers.models import Capability, ModelInfo from hokusai.state import load_state @@ -57,7 +57,7 @@ class FakeTextProvider(Provider): async def generate( self, target_name: str, - target_config: TargetConfig, + target_config: GenerateTargetConfig, resolved_prompt: str, resolved_model: ModelInfo, project_dir: Path, @@ -78,7 +78,7 @@ class FakeImageProvider(Provider): async def generate( self, target_name: str, - target_config: TargetConfig, + target_config: GenerateTargetConfig, resolved_prompt: str, resolved_model: ModelInfo, project_dir: Path, @@ -99,7 +99,7 @@ class FailingTextProvider(Provider): async def generate( self, target_name: str, - target_config: TargetConfig, + target_config: GenerateTargetConfig, resolved_prompt: str, resolved_model: ModelInfo, project_dir: Path, @@ -295,7 +295,7 @@ class TestRunBuild: async def selective_generate( target_name: str, - target_config: TargetConfig, + target_config: GenerateTargetConfig, resolved_prompt: str, resolved_model: ModelInfo, project_dir: Path, @@ -426,3 +426,145 @@ class TestRunBuild: assert set(result.built) == {"left.md", "right.md", "merge.txt"} assert result.failed == {} + + +class TestDownloadTarget: + """Tests for download-type targets that fetch files from URLs.""" + + async def test_download_target_fetches_url( + self, project_dir: Path, write_config: WriteConfig + ) -> None: + config = write_config( + { + "targets": { + "fish.png": {"download": "https://example.com/fish.png"}, + } + } + ) + mock_response = MagicMock() + mock_response.content = b"fake image bytes" + mock_response.raise_for_status = MagicMock() + + with ( + patch("hokusai.builder._create_providers", return_value=_fake_providers()), + patch("hokusai.builder.httpx.AsyncClient") as mock_client_cls, + ): + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + result = await run_build(config, project_dir, _PROJECT) + + assert result.built == ["fish.png"] + assert (project_dir / "fish.png").read_bytes() == b"fake image bytes" + mock_client.get.assert_called_once_with( # pyright: ignore[reportAny] + "https://example.com/fish.png", follow_redirects=True + ) + + async def test_download_target_incremental_skip( + self, project_dir: Path, write_config: WriteConfig + ) -> None: + config = write_config( + { + "targets": { + "fish.png": {"download": "https://example.com/fish.png"}, + } + } + ) + mock_response = MagicMock() + mock_response.content = b"fake image bytes" + mock_response.raise_for_status = MagicMock() + + with ( + patch("hokusai.builder._create_providers", return_value=_fake_providers()), + patch("hokusai.builder.httpx.AsyncClient") as mock_client_cls, + ): + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + r1 = await run_build(config, project_dir, _PROJECT) + assert r1.built == ["fish.png"] + + r2 = await run_build(config, project_dir, _PROJECT) + assert r2.skipped == ["fish.png"] + assert r2.built == [] + + async def test_download_target_rebuild_on_url_change( + self, project_dir: Path, write_config: WriteConfig + ) -> None: + config1 = write_config( + { + "targets": { + "fish.png": {"download": "https://example.com/fish-v1.png"}, + } + } + ) + mock_response = MagicMock() + mock_response.content = b"v1 bytes" + mock_response.raise_for_status = MagicMock() + + with ( + patch("hokusai.builder._create_providers", return_value=_fake_providers()), + patch("hokusai.builder.httpx.AsyncClient") as mock_client_cls, + ): + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + r1 = await run_build(config1, project_dir, _PROJECT) + assert r1.built == ["fish.png"] + + config2 = write_config( + { + "targets": { + "fish.png": {"download": "https://example.com/fish-v2.png"}, + } + } + ) + mock_response.content = b"v2 bytes" + + r2 = await run_build(config2, project_dir, _PROJECT) + assert r2.built == ["fish.png"] + assert (project_dir / "fish.png").read_bytes() == b"v2 bytes" + + async def test_download_target_as_dependency( + self, project_dir: Path, write_config: WriteConfig + ) -> None: + config = write_config( + { + "targets": { + "fish.png": {"download": "https://example.com/fish.png"}, + "description.txt": { + "prompt": "Describe the fish", + "inputs": ["fish.png"], + }, + } + } + ) + mock_response = MagicMock() + mock_response.content = b"fake fish image" + mock_response.raise_for_status = MagicMock() + + with ( + patch("hokusai.builder._create_providers", return_value=_fake_providers()), + patch("hokusai.builder.httpx.AsyncClient") as mock_client_cls, + ): + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + result = await run_build(config, project_dir, _PROJECT) + + assert "fish.png" in result.built + assert "description.txt" in result.built + assert (project_dir / "fish.png").read_bytes() == b"fake fish image" + assert (project_dir / "description.txt").exists() diff --git a/tests/test_config.py b/tests/test_config.py index 612d1b8..83d25b4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,7 +7,7 @@ from pathlib import Path import pytest import yaml -from hokusai.config import load_config +from hokusai.config import GenerateTargetConfig, load_config class TestLoadConfig: @@ -21,7 +21,9 @@ class TestLoadConfig: config = load_config(config_path) assert "out.txt" in config.targets - assert config.targets["out.txt"].prompt == "hello" + target = config.targets["out.txt"] + assert isinstance(target, GenerateTargetConfig) + assert target.prompt == "hello" assert config.defaults.text_model == "pixtral-large-latest" assert config.defaults.image_model == "flux-2-pro" @@ -56,6 +58,7 @@ class TestLoadConfig: assert config.defaults.image_model == "custom-image" banner = config.targets["banner.png"] + assert isinstance(banner, GenerateTargetConfig) assert banner.model == "flux-dev" assert banner.width == 1920 assert banner.height == 480 @@ -63,6 +66,7 @@ class TestLoadConfig: assert banner.control_images == ["ctrl.png"] story = config.targets["story.md"] + assert isinstance(story, GenerateTargetConfig) assert story.model is None assert story.inputs == ["banner.png"] diff --git a/tests/test_providers.py b/tests/test_providers.py index d690c9a..1d63e01 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from hokusai.config import TargetConfig +from hokusai.config import GenerateTargetConfig from hokusai.providers.bfl import BFLResult from hokusai.providers.blackforest import BlackForestProvider from hokusai.providers.blackforest import ( @@ -81,7 +81,7 @@ class TestBlackForestProvider: async def test_basic_image_generation( self, project_dir: Path, image_bytes: bytes ) -> None: - target_config = TargetConfig(prompt="A red square") + target_config = GenerateTargetConfig(prompt="A red square") bfl_result, mock_http = _make_bfl_mocks(image_bytes) with ( @@ -107,7 +107,7 @@ class TestBlackForestProvider: async def test_image_with_dimensions( self, project_dir: Path, image_bytes: bytes ) -> None: - target_config = TargetConfig(prompt="A banner", width=1920, height=480) + target_config = GenerateTargetConfig(prompt="A banner", width=1920, height=480) bfl_result, mock_http = _make_bfl_mocks(image_bytes) with ( @@ -138,7 +138,9 @@ class TestBlackForestProvider: ref_path = project_dir / "ref.png" _ = ref_path.write_bytes(b"reference image data") - target_config = TargetConfig(prompt="Like this", reference_images=["ref.png"]) + target_config = GenerateTargetConfig( + prompt="Like this", reference_images=["ref.png"] + ) bfl_result, mock_http = _make_bfl_mocks(image_bytes) with ( @@ -171,7 +173,7 @@ class TestBlackForestProvider: _ = ref1.write_bytes(b"ref1 data") _ = ref2.write_bytes(b"ref2 data") - target_config = TargetConfig( + target_config = GenerateTargetConfig( prompt="Combine", reference_images=["ref1.png", "ref2.png"] ) bfl_result, mock_http = _make_bfl_mocks(image_bytes) @@ -204,7 +206,9 @@ class TestBlackForestProvider: ref_path = project_dir / "ref.png" _ = ref_path.write_bytes(b"reference image data") - target_config = TargetConfig(prompt="Edit", reference_images=["ref.png"]) + target_config = GenerateTargetConfig( + prompt="Edit", reference_images=["ref.png"] + ) bfl_result, mock_http = _make_bfl_mocks(image_bytes) with ( @@ -229,7 +233,7 @@ class TestBlackForestProvider: assert inputs["input_image"] == encode_image_b64(ref_path) async def test_image_no_sample_url_raises(self, project_dir: Path) -> None: - target_config = TargetConfig(prompt="x") + target_config = GenerateTargetConfig(prompt="x") with patch("hokusai.providers.blackforest.BFLClient") as mock_cls: from hokusai.providers.bfl import BFLError @@ -261,7 +265,7 @@ class TestMistralProvider: """Test MistralProvider with mocked Mistral client.""" async def test_basic_text_generation(self, project_dir: Path) -> None: - target_config = TargetConfig(prompt="Write a poem") + target_config = GenerateTargetConfig(prompt="Write a poem") response = _make_text_response("Roses are red...") with patch("hokusai.providers.mistral.Mistral") as mock_cls: @@ -282,7 +286,7 @@ class TestMistralProvider: async def test_text_with_text_input(self, project_dir: Path) -> None: _ = (project_dir / "source.txt").write_text("Source material here") - target_config = TargetConfig(prompt="Summarize", inputs=["source.txt"]) + target_config = GenerateTargetConfig(prompt="Summarize", inputs=["source.txt"]) response = _make_text_response("Summary: ...") with patch("hokusai.providers.mistral.Mistral") as mock_cls: @@ -306,7 +310,9 @@ class TestMistralProvider: async def test_text_with_image_input(self, project_dir: Path) -> None: _ = (project_dir / "photo.png").write_bytes(b"\x89PNG") - target_config = TargetConfig(prompt="Describe this image", inputs=["photo.png"]) + target_config = GenerateTargetConfig( + prompt="Describe this image", inputs=["photo.png"] + ) response = _make_text_response("A beautiful photo") with patch("hokusai.providers.mistral.Mistral") as mock_cls: @@ -330,7 +336,7 @@ class TestMistralProvider: assert chunks[1].image_url.url.startswith("data:image/png;base64,") async def test_text_no_choices_raises(self, project_dir: Path) -> None: - target_config = TargetConfig(prompt="x") + target_config = GenerateTargetConfig(prompt="x") response = MagicMock() response.choices = [] @@ -348,7 +354,7 @@ class TestMistralProvider: ) async def test_text_empty_content_raises(self, project_dir: Path) -> None: - target_config = TargetConfig(prompt="x") + target_config = GenerateTargetConfig(prompt="x") response = _make_text_response(None) with patch("hokusai.providers.mistral.Mistral") as mock_cls: @@ -369,7 +375,7 @@ class TestMistralProvider: _ = (project_dir / "b.txt").write_text("content B") _ = (project_dir / "c.png").write_bytes(b"\x89PNG") - target_config = TargetConfig( + target_config = GenerateTargetConfig( prompt="Combine all", inputs=["a.txt", "b.txt", "c.png"] ) response = _make_text_response("Combined") @@ -400,7 +406,7 @@ class TestMistralProvider: async def test_text_with_reference_images(self, project_dir: Path) -> None: _ = (project_dir / "ref.png").write_bytes(b"\x89PNG") - target_config = TargetConfig( + target_config = GenerateTargetConfig( prompt="Describe the style", reference_images=["ref.png"] ) response = _make_text_response("A stylized image") @@ -455,7 +461,9 @@ class TestOpenAIImageProvider: ref = project_dir / "ref.png" _ = ref.write_bytes(b"reference data") - target_config = TargetConfig(prompt="Edit this", reference_images=["ref.png"]) + target_config = GenerateTargetConfig( + prompt="Edit this", reference_images=["ref.png"] + ) b64 = base64.b64encode(image_bytes).decode() mock_client = _make_openai_mock(b64) @@ -487,7 +495,7 @@ class TestOpenAIImageProvider: _ = ref1.write_bytes(b"ref1 data") _ = ref2.write_bytes(b"ref2 data") - target_config = TargetConfig( + target_config = GenerateTargetConfig( prompt="Combine", reference_images=["ref1.png", "ref2.png"] ) b64 = base64.b64encode(image_bytes).decode() diff --git a/tests/test_resolve.py b/tests/test_resolve.py index ad27574..b949794 100644 --- a/tests/test_resolve.py +++ b/tests/test_resolve.py @@ -6,7 +6,7 @@ import pytest from hokusai.config import ( Defaults, - TargetConfig, + GenerateTargetConfig, ) from hokusai.providers.models import Capability from hokusai.resolve import infer_required_capabilities, resolve_model @@ -16,37 +16,37 @@ class TestInferRequiredCapabilities: """Test capability inference from file extensions and target config.""" def test_plain_image(self) -> None: - target = TargetConfig(prompt="x") + target = GenerateTargetConfig(prompt="x") assert infer_required_capabilities("out.png", target) == frozenset( {Capability.TEXT_TO_IMAGE} ) @pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"]) def test_image_extensions(self, name: str) -> None: - target = TargetConfig(prompt="x") + target = GenerateTargetConfig(prompt="x") caps = infer_required_capabilities(name, target) assert Capability.TEXT_TO_IMAGE in caps @pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"]) def test_case_insensitive(self, name: str) -> None: - target = TargetConfig(prompt="x") + target = GenerateTargetConfig(prompt="x") caps = infer_required_capabilities(name, target) assert Capability.TEXT_TO_IMAGE in caps def test_image_with_reference_images(self) -> None: - target = TargetConfig(prompt="x", reference_images=["ref.png"]) + target = GenerateTargetConfig(prompt="x", reference_images=["ref.png"]) assert infer_required_capabilities("out.png", target) == frozenset( {Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES} ) def test_image_with_control_images(self) -> None: - target = TargetConfig(prompt="x", control_images=["ctrl.png"]) + target = GenerateTargetConfig(prompt="x", control_images=["ctrl.png"]) assert infer_required_capabilities("out.png", target) == frozenset( {Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES} ) def test_image_with_both(self) -> None: - target = TargetConfig( + target = GenerateTargetConfig( prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"] ) assert infer_required_capabilities("out.png", target) == frozenset( @@ -59,35 +59,35 @@ class TestInferRequiredCapabilities: @pytest.mark.parametrize("name", ["doc.md", "doc.txt"]) def test_text_extensions(self, name: str) -> None: - target = TargetConfig(prompt="x") + target = GenerateTargetConfig(prompt="x") caps = infer_required_capabilities(name, target) assert caps == frozenset({Capability.TEXT_GENERATION}) def test_text_with_text_inputs(self) -> None: - target = TargetConfig(prompt="x", inputs=["data.txt"]) + target = GenerateTargetConfig(prompt="x", inputs=["data.txt"]) assert infer_required_capabilities("out.md", target) == frozenset( {Capability.TEXT_GENERATION} ) def test_text_with_image_input(self) -> None: - target = TargetConfig(prompt="x", inputs=["photo.png"]) + target = GenerateTargetConfig(prompt="x", inputs=["photo.png"]) assert infer_required_capabilities("out.txt", target) == frozenset( {Capability.TEXT_GENERATION, Capability.VISION} ) def test_text_with_image_reference(self) -> None: - target = TargetConfig(prompt="x", reference_images=["ref.jpg"]) + target = GenerateTargetConfig(prompt="x", reference_images=["ref.jpg"]) assert infer_required_capabilities("out.md", target) == frozenset( {Capability.TEXT_GENERATION, Capability.VISION} ) def test_unsupported_extension_raises(self) -> None: - target = TargetConfig(prompt="x") + target = GenerateTargetConfig(prompt="x") with pytest.raises(ValueError, match="unsupported extension"): _ = infer_required_capabilities("data.csv", target) def test_no_extension_raises(self) -> None: - target = TargetConfig(prompt="x") + target = GenerateTargetConfig(prompt="x") with pytest.raises(ValueError, match="unsupported extension"): _ = infer_required_capabilities("Makefile", target) @@ -96,50 +96,52 @@ class TestResolveModel: """Test model resolution with capability validation.""" def test_explicit_model_wins(self) -> None: - target = TargetConfig(prompt="x", model="mistral-small-latest") + target = GenerateTargetConfig(prompt="x", model="mistral-small-latest") result = resolve_model("out.txt", target, Defaults()) assert result.name == "mistral-small-latest" def test_default_text_model(self) -> None: - target = TargetConfig(prompt="x") + target = GenerateTargetConfig(prompt="x") defaults = Defaults(text_model="mistral-large-latest") result = resolve_model("out.md", target, defaults) assert result.name == "mistral-large-latest" def test_default_image_model(self) -> None: - target = TargetConfig(prompt="x") + target = GenerateTargetConfig(prompt="x") defaults = Defaults(image_model="flux-dev") result = resolve_model("out.png", target, defaults) assert result.name == "flux-dev" def test_unknown_model_raises(self) -> None: - target = TargetConfig(prompt="x", model="nonexistent-model") + target = GenerateTargetConfig(prompt="x", model="nonexistent-model") with pytest.raises(ValueError, match="Unknown model"): _ = resolve_model("out.txt", target, Defaults()) def test_explicit_model_missing_capability_raises(self) -> None: # flux-dev does not support reference images - target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"]) + target = GenerateTargetConfig( + prompt="x", model="flux-dev", reference_images=["r.png"] + ) with pytest.raises(ValueError, match="lacks required capabilities"): _ = resolve_model("out.png", target, Defaults()) def test_default_fallback_for_reference_images(self) -> None: # Default flux-dev lacks reference_images, should fall back to a capable model - target = TargetConfig(prompt="x", reference_images=["r.png"]) + target = GenerateTargetConfig(prompt="x", reference_images=["r.png"]) defaults = Defaults(image_model="flux-dev") result = resolve_model("out.png", target, defaults) assert Capability.REFERENCE_IMAGES in result.capabilities def test_default_fallback_for_vision(self) -> None: # Default mistral-large-latest lacks vision, should fall back to a pixtral model - target = TargetConfig(prompt="x", inputs=["photo.png"]) + target = GenerateTargetConfig(prompt="x", inputs=["photo.png"]) defaults = Defaults(text_model="mistral-large-latest") result = resolve_model("out.txt", target, defaults) assert Capability.VISION in result.capabilities def test_default_preferred_when_capable(self) -> None: # Default flux-2-pro already supports reference_images, should be used directly - target = TargetConfig(prompt="x", reference_images=["r.png"]) + target = GenerateTargetConfig(prompt="x", reference_images=["r.png"]) defaults = Defaults(image_model="flux-2-pro") result = resolve_model("out.png", target, defaults) assert result.name == "flux-2-pro"