feat: add download target type for fetching files from URLs

This commit is contained in:
Konstantin Fickel 2026-02-20 21:02:15 +01:00
parent a4600df4d5
commit c1ad6e6e3c
Signed by: kfickel
GPG key ID: A793722F9933C1A5
14 changed files with 296 additions and 74 deletions

View file

@ -9,7 +9,9 @@ from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from hokusai.config import ProjectConfig
import httpx
from hokusai.config import DownloadTargetConfig, GenerateTargetConfig, ProjectConfig
from hokusai.graph import build_graph, get_build_order, get_subgraph_for_target
from hokusai.providers import Provider
from hokusai.providers.blackforest import BlackForestProvider
@ -73,6 +75,8 @@ def _collect_dep_files(
) -> list[Path]:
"""Collect all dependency file paths for a target."""
target_cfg = config.targets[target_name]
if not isinstance(target_cfg, GenerateTargetConfig):
return []
deps: list[str] = list(target_cfg.inputs)
deps.extend(target_cfg.reference_images)
deps.extend(target_cfg.control_images)
@ -82,6 +86,8 @@ def _collect_dep_files(
def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, object]:
"""Collect extra parameters that affect rebuild decisions."""
target_cfg = config.targets[target_name]
if not isinstance(target_cfg, GenerateTargetConfig):
return {}
params: dict[str, object] = {}
if target_cfg.width is not None:
params["width"] = target_cfg.width
@ -97,6 +103,8 @@ def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str,
def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]:
"""Collect all dependency names (inputs + reference_images + control_images)."""
target_cfg = config.targets[target_name]
if not isinstance(target_cfg, GenerateTargetConfig):
return []
deps: list[str] = list(target_cfg.inputs)
deps.extend(target_cfg.reference_images)
deps.extend(target_cfg.control_images)
@ -128,6 +136,18 @@ def _build_provider_index(providers: list[Provider]) -> dict[str, Provider]:
return index
async def _download_target(
    target_name: str,
    target_cfg: DownloadTargetConfig,
    project_dir: Path,
) -> None:
    """Download a file from a URL and write it to *project_dir / target_name*.

    The response body is written verbatim (bytes) to the target path.

    Raises:
        httpx.HTTPStatusError: if the server answers with a 4xx/5xx status.
        httpx.TimeoutException: if the request exceeds the timeout.
    """
    # Explicit timeout: httpx's default (5 s) is too tight for large artifact
    # downloads over slow links; 30 s is a more forgiving per-operation bound.
    async with httpx.AsyncClient(
        headers={"User-Agent": "hokusai"}, timeout=30.0
    ) as client:
        response = await client.get(target_cfg.download, follow_redirects=True)
        _ = response.raise_for_status()
        _ = (project_dir / target_name).write_bytes(response.content)
async def _build_single_target(
target_name: str,
config: ProjectConfig,
@ -136,6 +156,11 @@ async def _build_single_target(
) -> None:
"""Build a single target by dispatching to the appropriate provider."""
target_cfg = config.targets[target_name]
if isinstance(target_cfg, DownloadTargetConfig):
await _download_target(target_name, target_cfg, project_dir)
return
model_info = resolve_model(target_name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
@ -227,6 +252,9 @@ def _should_skip_failed_dep(
return any(d in result.failed for d in _collect_all_deps(target_name, config))
_DOWNLOAD_MODEL_SENTINEL = "__download__"
def _is_dirty(
target_name: str,
config: ProjectConfig,
@ -235,6 +263,18 @@ def _is_dirty(
) -> bool:
"""Check if a target needs rebuilding."""
target_cfg = config.targets[target_name]
if isinstance(target_cfg, DownloadTargetConfig):
return is_target_dirty(
target_name,
resolved_prompt=target_cfg.download,
model=_DOWNLOAD_MODEL_SENTINEL,
dep_files=[],
extra_params={},
state=state,
project_dir=project_dir,
)
model_info = resolve_model(target_name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
dep_files = _collect_dep_files(target_name, config, project_dir)
@ -260,6 +300,8 @@ def _has_provider(
) -> bool:
"""Check that the required provider is available; record failure if not."""
target_cfg = config.targets[target_name]
if isinstance(target_cfg, DownloadTargetConfig):
return True
model_info = resolve_model(target_name, target_cfg, config.defaults)
if model_info.name not in provider_index:
msg = f"No provider available for model '{model_info.name}' (provider: {model_info.provider}) — check API key environment variables"
@ -302,15 +344,22 @@ def _process_outcomes(
on_progress(BuildEvent.TARGET_FAILED, name, str(error))
else:
target_cfg = config.targets[name]
model_info = resolve_model(name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
if isinstance(target_cfg, DownloadTargetConfig):
resolved_prompt = target_cfg.download
model_name = _DOWNLOAD_MODEL_SENTINEL
else:
model_info = resolve_model(name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
model_name = model_info.name
dep_files = _collect_dep_files(name, config, project_dir)
extra = _collect_extra_params(name, config)
record_target_state(
name,
resolved_prompt=resolved_prompt,
model=model_info.name,
model=model_name,
dep_files=dep_files,
extra_params=extra,
state=state,

View file

@ -257,18 +257,12 @@ def init() -> None:
content = f"""\
# {name} - hokusai project
defaults:
image_model: flux-2-pro
targets:
great_wave.png:
prompt: >-
A recreation of Hokusai's "The Great Wave off Kanagawa", but instead of
boats and people, paint brushes, canvases, and framed paintings are
swimming and tumbling in the towering wave. Oil paint tubes burst open
and trail ribbons of colour through the spray. The iconic Mount Fuji
sits serenely in the background. Ukiyo-e woodblock print style with
vivid modern pigment colours.
swimming and tumbling in the towering wave.
"""
_ = dest.write_text(content)
click.echo(click.style(" created ", fg="green") + click.style(filename, bold=True))

View file

@ -4,10 +4,10 @@ from __future__ import annotations
import enum
from pathlib import Path
from typing import Self
from typing import Annotated, Self
import yaml
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, Discriminator, Tag, model_validator
from hokusai.providers.models import Capability
@ -29,8 +29,8 @@ class Defaults(BaseModel):
image_model: str = "flux-2-pro"
class TargetConfig(BaseModel):
"""Configuration for a single build target."""
class GenerateTargetConfig(BaseModel):
"""Configuration for a target that generates an artifact via an AI provider."""
prompt: str
model: str | None = None
@ -41,6 +41,26 @@ class TargetConfig(BaseModel):
height: int | None = None
class DownloadTargetConfig(BaseModel):
    """Configuration for a target that downloads a file from a URL."""

    # Source URL; the HTTP response body is written verbatim to the target path.
    download: str
def _target_discriminator(raw: object) -> str:
"""Discriminate between generate and download target configs."""
if isinstance(raw, dict) and "download" in raw:
return "download"
return "generate"
# Discriminated union of the target config variants: the presence of a
# "download" key in the raw mapping selects DownloadTargetConfig, anything
# else parses as GenerateTargetConfig (see _target_discriminator).
TargetConfig = Annotated[
    Annotated[GenerateTargetConfig, Tag("generate")]
    | Annotated[DownloadTargetConfig, Tag("download")],
    Discriminator(_target_discriminator),
]
class ProjectConfig(BaseModel):
"""Top-level configuration parsed from ``<name>.hokusai.yaml``."""

View file

@ -6,7 +6,7 @@ from pathlib import Path
import networkx as nx
from hokusai.config import ProjectConfig
from hokusai.config import GenerateTargetConfig, ProjectConfig
def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]:
@ -25,6 +25,9 @@ def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]:
for target_name, target_cfg in config.targets.items():
graph.add_node(target_name)
if not isinstance(target_cfg, GenerateTargetConfig):
continue
deps: list[str] = list(target_cfg.inputs)
deps.extend(target_cfg.reference_images)
deps.extend(target_cfg.control_images)

View file

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
from hokusai.providers.models import ModelInfo
if TYPE_CHECKING:
from hokusai.config import TargetConfig
from hokusai.config import GenerateTargetConfig
class Provider(abc.ABC):
@ -24,7 +24,7 @@ class Provider(abc.ABC):
async def generate(
self,
target_name: str,
target_config: TargetConfig,
target_config: GenerateTargetConfig,
resolved_prompt: str,
resolved_model: ModelInfo,
project_dir: Path,

View file

@ -8,7 +8,7 @@ from typing import override
import httpx
from hokusai.config import TargetConfig
from hokusai.config import GenerateTargetConfig
from hokusai.providers import Provider
from hokusai.providers.bfl import BFLClient
from hokusai.providers.models import Capability, ModelInfo
@ -123,7 +123,7 @@ class BlackForestProvider(Provider):
async def generate(
self,
target_name: str,
target_config: TargetConfig,
target_config: GenerateTargetConfig,
resolved_prompt: str,
resolved_model: ModelInfo,
project_dir: Path,

View file

@ -9,7 +9,7 @@ from typing import override
from mistralai import Mistral, models
from hokusai.config import IMAGE_EXTENSIONS, TargetConfig
from hokusai.config import IMAGE_EXTENSIONS, GenerateTargetConfig
from hokusai.providers import Provider
from hokusai.providers.models import Capability, ModelInfo
@ -63,7 +63,7 @@ class MistralProvider(Provider):
async def generate(
self,
target_name: str,
target_config: TargetConfig,
target_config: GenerateTargetConfig,
resolved_prompt: str,
resolved_model: ModelInfo,
project_dir: Path,

View file

@ -10,7 +10,7 @@ import httpx
from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse
from hokusai.config import TargetConfig
from hokusai.config import GenerateTargetConfig
from hokusai.providers import Provider
from hokusai.providers.models import Capability, ModelInfo
@ -109,7 +109,7 @@ class OpenAIImageProvider(Provider):
async def generate(
self,
target_name: str,
target_config: TargetConfig,
target_config: GenerateTargetConfig,
resolved_prompt: str,
resolved_model: ModelInfo,
project_dir: Path,

View file

@ -15,7 +15,7 @@ from openai.types.chat import (
ChatCompletionUserMessageParam,
)
from hokusai.config import IMAGE_EXTENSIONS, TargetConfig
from hokusai.config import IMAGE_EXTENSIONS, GenerateTargetConfig
from hokusai.providers import Provider
from hokusai.providers.models import Capability, ModelInfo
@ -120,7 +120,7 @@ class OpenAITextProvider(Provider):
async def generate(
self,
target_name: str,
target_config: TargetConfig,
target_config: GenerateTargetConfig,
resolved_prompt: str,
resolved_model: ModelInfo,
project_dir: Path,

View file

@ -6,7 +6,7 @@ from hokusai.config import (
IMAGE_EXTENSIONS,
TEXT_EXTENSIONS,
Defaults,
TargetConfig,
GenerateTargetConfig,
TargetType,
target_type_from_capabilities,
)
@ -14,7 +14,7 @@ from hokusai.providers.models import Capability, ModelInfo
def infer_required_capabilities(
target_name: str, target: TargetConfig
target_name: str, target: GenerateTargetConfig
) -> frozenset[Capability]:
"""Infer the capabilities a model must have based on filename and config.
@ -42,7 +42,7 @@ def infer_required_capabilities(
def resolve_model(
target_name: str, target: TargetConfig, defaults: Defaults
target_name: str, target: GenerateTargetConfig, defaults: Defaults
) -> ModelInfo:
"""Return the effective model for a target, validated against required capabilities.

View file

@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Callable
from pathlib import Path
from typing import override
from unittest.mock import patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -16,7 +16,7 @@ from hokusai.builder import (
_resolve_prompt, # pyright: ignore[reportPrivateUsage]
run_build,
)
from hokusai.config import ProjectConfig, TargetConfig
from hokusai.config import GenerateTargetConfig, ProjectConfig
from hokusai.providers import Provider
from hokusai.providers.models import Capability, ModelInfo
from hokusai.state import load_state
@ -57,7 +57,7 @@ class FakeTextProvider(Provider):
async def generate(
self,
target_name: str,
target_config: TargetConfig,
target_config: GenerateTargetConfig,
resolved_prompt: str,
resolved_model: ModelInfo,
project_dir: Path,
@ -78,7 +78,7 @@ class FakeImageProvider(Provider):
async def generate(
self,
target_name: str,
target_config: TargetConfig,
target_config: GenerateTargetConfig,
resolved_prompt: str,
resolved_model: ModelInfo,
project_dir: Path,
@ -99,7 +99,7 @@ class FailingTextProvider(Provider):
async def generate(
self,
target_name: str,
target_config: TargetConfig,
target_config: GenerateTargetConfig,
resolved_prompt: str,
resolved_model: ModelInfo,
project_dir: Path,
@ -295,7 +295,7 @@ class TestRunBuild:
async def selective_generate(
target_name: str,
target_config: TargetConfig,
target_config: GenerateTargetConfig,
resolved_prompt: str,
resolved_model: ModelInfo,
project_dir: Path,
@ -426,3 +426,145 @@ class TestRunBuild:
assert set(result.built) == {"left.md", "right.md", "merge.txt"}
assert result.failed == {}
class TestDownloadTarget:
    """Tests for download-type targets that fetch files from URLs."""

    @staticmethod
    def _mock_http(content: bytes) -> tuple[MagicMock, AsyncMock]:
        """Build a ``(response, client)`` mock pair emulating httpx.AsyncClient.

        The client works as an async context manager and returns *response*
        (with the given body bytes) from ``get``.  Shared by all tests below
        to avoid repeating the same five-line mock setup.
        """
        response = MagicMock()
        response.content = content
        response.raise_for_status = MagicMock()
        client = AsyncMock()
        client.get = AsyncMock(return_value=response)
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        return response, client

    async def test_download_target_fetches_url(
        self, project_dir: Path, write_config: WriteConfig
    ) -> None:
        config = write_config(
            {
                "targets": {
                    "fish.png": {"download": "https://example.com/fish.png"},
                }
            }
        )
        _, mock_client = self._mock_http(b"fake image bytes")

        with (
            patch("hokusai.builder._create_providers", return_value=_fake_providers()),
            patch("hokusai.builder.httpx.AsyncClient", return_value=mock_client),
        ):
            result = await run_build(config, project_dir, _PROJECT)

        assert result.built == ["fish.png"]
        assert (project_dir / "fish.png").read_bytes() == b"fake image bytes"
        mock_client.get.assert_called_once_with(  # pyright: ignore[reportAny]
            "https://example.com/fish.png", follow_redirects=True
        )

    async def test_download_target_incremental_skip(
        self, project_dir: Path, write_config: WriteConfig
    ) -> None:
        config = write_config(
            {
                "targets": {
                    "fish.png": {"download": "https://example.com/fish.png"},
                }
            }
        )
        _, mock_client = self._mock_http(b"fake image bytes")

        with (
            patch("hokusai.builder._create_providers", return_value=_fake_providers()),
            patch("hokusai.builder.httpx.AsyncClient", return_value=mock_client),
        ):
            # First build fetches; second build must be a no-op skip.
            r1 = await run_build(config, project_dir, _PROJECT)
            assert r1.built == ["fish.png"]

            r2 = await run_build(config, project_dir, _PROJECT)
            assert r2.skipped == ["fish.png"]
            assert r2.built == []

    async def test_download_target_rebuild_on_url_change(
        self, project_dir: Path, write_config: WriteConfig
    ) -> None:
        config_v1 = write_config(
            {
                "targets": {
                    "fish.png": {"download": "https://example.com/fish-v1.png"},
                }
            }
        )
        mock_response, mock_client = self._mock_http(b"v1 bytes")

        with (
            patch("hokusai.builder._create_providers", return_value=_fake_providers()),
            patch("hokusai.builder.httpx.AsyncClient", return_value=mock_client),
        ):
            r1 = await run_build(config_v1, project_dir, _PROJECT)
            assert r1.built == ["fish.png"]

            # Changing only the URL must dirty the target and refetch.
            config_v2 = write_config(
                {
                    "targets": {
                        "fish.png": {"download": "https://example.com/fish-v2.png"},
                    }
                }
            )
            mock_response.content = b"v2 bytes"
            r2 = await run_build(config_v2, project_dir, _PROJECT)

        assert r2.built == ["fish.png"]
        assert (project_dir / "fish.png").read_bytes() == b"v2 bytes"

    async def test_download_target_as_dependency(
        self, project_dir: Path, write_config: WriteConfig
    ) -> None:
        config = write_config(
            {
                "targets": {
                    "fish.png": {"download": "https://example.com/fish.png"},
                    "description.txt": {
                        "prompt": "Describe the fish",
                        "inputs": ["fish.png"],
                    },
                }
            }
        )
        _, mock_client = self._mock_http(b"fake fish image")

        with (
            patch("hokusai.builder._create_providers", return_value=_fake_providers()),
            patch("hokusai.builder.httpx.AsyncClient", return_value=mock_client),
        ):
            result = await run_build(config, project_dir, _PROJECT)

        # The downloaded file must exist and feed the generated dependent.
        assert "fish.png" in result.built
        assert "description.txt" in result.built
        assert (project_dir / "fish.png").read_bytes() == b"fake fish image"
        assert (project_dir / "description.txt").exists()

View file

@ -7,7 +7,7 @@ from pathlib import Path
import pytest
import yaml
from hokusai.config import load_config
from hokusai.config import GenerateTargetConfig, load_config
class TestLoadConfig:
@ -21,7 +21,9 @@ class TestLoadConfig:
config = load_config(config_path)
assert "out.txt" in config.targets
assert config.targets["out.txt"].prompt == "hello"
target = config.targets["out.txt"]
assert isinstance(target, GenerateTargetConfig)
assert target.prompt == "hello"
assert config.defaults.text_model == "pixtral-large-latest"
assert config.defaults.image_model == "flux-2-pro"
@ -56,6 +58,7 @@ class TestLoadConfig:
assert config.defaults.image_model == "custom-image"
banner = config.targets["banner.png"]
assert isinstance(banner, GenerateTargetConfig)
assert banner.model == "flux-dev"
assert banner.width == 1920
assert banner.height == 480
@ -63,6 +66,7 @@ class TestLoadConfig:
assert banner.control_images == ["ctrl.png"]
story = config.targets["story.md"]
assert isinstance(story, GenerateTargetConfig)
assert story.model is None
assert story.inputs == ["banner.png"]

View file

@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from hokusai.config import TargetConfig
from hokusai.config import GenerateTargetConfig
from hokusai.providers.bfl import BFLResult
from hokusai.providers.blackforest import BlackForestProvider
from hokusai.providers.blackforest import (
@ -81,7 +81,7 @@ class TestBlackForestProvider:
async def test_basic_image_generation(
self, project_dir: Path, image_bytes: bytes
) -> None:
target_config = TargetConfig(prompt="A red square")
target_config = GenerateTargetConfig(prompt="A red square")
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with (
@ -107,7 +107,7 @@ class TestBlackForestProvider:
async def test_image_with_dimensions(
self, project_dir: Path, image_bytes: bytes
) -> None:
target_config = TargetConfig(prompt="A banner", width=1920, height=480)
target_config = GenerateTargetConfig(prompt="A banner", width=1920, height=480)
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with (
@ -138,7 +138,9 @@ class TestBlackForestProvider:
ref_path = project_dir / "ref.png"
_ = ref_path.write_bytes(b"reference image data")
target_config = TargetConfig(prompt="Like this", reference_images=["ref.png"])
target_config = GenerateTargetConfig(
prompt="Like this", reference_images=["ref.png"]
)
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with (
@ -171,7 +173,7 @@ class TestBlackForestProvider:
_ = ref1.write_bytes(b"ref1 data")
_ = ref2.write_bytes(b"ref2 data")
target_config = TargetConfig(
target_config = GenerateTargetConfig(
prompt="Combine", reference_images=["ref1.png", "ref2.png"]
)
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
@ -204,7 +206,9 @@ class TestBlackForestProvider:
ref_path = project_dir / "ref.png"
_ = ref_path.write_bytes(b"reference image data")
target_config = TargetConfig(prompt="Edit", reference_images=["ref.png"])
target_config = GenerateTargetConfig(
prompt="Edit", reference_images=["ref.png"]
)
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with (
@ -229,7 +233,7 @@ class TestBlackForestProvider:
assert inputs["input_image"] == encode_image_b64(ref_path)
async def test_image_no_sample_url_raises(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="x")
target_config = GenerateTargetConfig(prompt="x")
with patch("hokusai.providers.blackforest.BFLClient") as mock_cls:
from hokusai.providers.bfl import BFLError
@ -261,7 +265,7 @@ class TestMistralProvider:
"""Test MistralProvider with mocked Mistral client."""
async def test_basic_text_generation(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="Write a poem")
target_config = GenerateTargetConfig(prompt="Write a poem")
response = _make_text_response("Roses are red...")
with patch("hokusai.providers.mistral.Mistral") as mock_cls:
@ -282,7 +286,7 @@ class TestMistralProvider:
async def test_text_with_text_input(self, project_dir: Path) -> None:
_ = (project_dir / "source.txt").write_text("Source material here")
target_config = TargetConfig(prompt="Summarize", inputs=["source.txt"])
target_config = GenerateTargetConfig(prompt="Summarize", inputs=["source.txt"])
response = _make_text_response("Summary: ...")
with patch("hokusai.providers.mistral.Mistral") as mock_cls:
@ -306,7 +310,9 @@ class TestMistralProvider:
async def test_text_with_image_input(self, project_dir: Path) -> None:
_ = (project_dir / "photo.png").write_bytes(b"\x89PNG")
target_config = TargetConfig(prompt="Describe this image", inputs=["photo.png"])
target_config = GenerateTargetConfig(
prompt="Describe this image", inputs=["photo.png"]
)
response = _make_text_response("A beautiful photo")
with patch("hokusai.providers.mistral.Mistral") as mock_cls:
@ -330,7 +336,7 @@ class TestMistralProvider:
assert chunks[1].image_url.url.startswith("data:image/png;base64,")
async def test_text_no_choices_raises(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="x")
target_config = GenerateTargetConfig(prompt="x")
response = MagicMock()
response.choices = []
@ -348,7 +354,7 @@ class TestMistralProvider:
)
async def test_text_empty_content_raises(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="x")
target_config = GenerateTargetConfig(prompt="x")
response = _make_text_response(None)
with patch("hokusai.providers.mistral.Mistral") as mock_cls:
@ -369,7 +375,7 @@ class TestMistralProvider:
_ = (project_dir / "b.txt").write_text("content B")
_ = (project_dir / "c.png").write_bytes(b"\x89PNG")
target_config = TargetConfig(
target_config = GenerateTargetConfig(
prompt="Combine all", inputs=["a.txt", "b.txt", "c.png"]
)
response = _make_text_response("Combined")
@ -400,7 +406,7 @@ class TestMistralProvider:
async def test_text_with_reference_images(self, project_dir: Path) -> None:
_ = (project_dir / "ref.png").write_bytes(b"\x89PNG")
target_config = TargetConfig(
target_config = GenerateTargetConfig(
prompt="Describe the style", reference_images=["ref.png"]
)
response = _make_text_response("A stylized image")
@ -455,7 +461,9 @@ class TestOpenAIImageProvider:
ref = project_dir / "ref.png"
_ = ref.write_bytes(b"reference data")
target_config = TargetConfig(prompt="Edit this", reference_images=["ref.png"])
target_config = GenerateTargetConfig(
prompt="Edit this", reference_images=["ref.png"]
)
b64 = base64.b64encode(image_bytes).decode()
mock_client = _make_openai_mock(b64)
@ -487,7 +495,7 @@ class TestOpenAIImageProvider:
_ = ref1.write_bytes(b"ref1 data")
_ = ref2.write_bytes(b"ref2 data")
target_config = TargetConfig(
target_config = GenerateTargetConfig(
prompt="Combine", reference_images=["ref1.png", "ref2.png"]
)
b64 = base64.b64encode(image_bytes).decode()

View file

@ -6,7 +6,7 @@ import pytest
from hokusai.config import (
Defaults,
TargetConfig,
GenerateTargetConfig,
)
from hokusai.providers.models import Capability
from hokusai.resolve import infer_required_capabilities, resolve_model
@ -16,37 +16,37 @@ class TestInferRequiredCapabilities:
"""Test capability inference from file extensions and target config."""
def test_plain_image(self) -> None:
target = TargetConfig(prompt="x")
target = GenerateTargetConfig(prompt="x")
assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE}
)
@pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
def test_image_extensions(self, name: str) -> None:
target = TargetConfig(prompt="x")
target = GenerateTargetConfig(prompt="x")
caps = infer_required_capabilities(name, target)
assert Capability.TEXT_TO_IMAGE in caps
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
def test_case_insensitive(self, name: str) -> None:
target = TargetConfig(prompt="x")
target = GenerateTargetConfig(prompt="x")
caps = infer_required_capabilities(name, target)
assert Capability.TEXT_TO_IMAGE in caps
def test_image_with_reference_images(self) -> None:
target = TargetConfig(prompt="x", reference_images=["ref.png"])
target = GenerateTargetConfig(prompt="x", reference_images=["ref.png"])
assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
)
def test_image_with_control_images(self) -> None:
target = TargetConfig(prompt="x", control_images=["ctrl.png"])
target = GenerateTargetConfig(prompt="x", control_images=["ctrl.png"])
assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
)
def test_image_with_both(self) -> None:
target = TargetConfig(
target = GenerateTargetConfig(
prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
)
assert infer_required_capabilities("out.png", target) == frozenset(
@ -59,35 +59,35 @@ class TestInferRequiredCapabilities:
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
def test_text_extensions(self, name: str) -> None:
target = TargetConfig(prompt="x")
target = GenerateTargetConfig(prompt="x")
caps = infer_required_capabilities(name, target)
assert caps == frozenset({Capability.TEXT_GENERATION})
def test_text_with_text_inputs(self) -> None:
target = TargetConfig(prompt="x", inputs=["data.txt"])
target = GenerateTargetConfig(prompt="x", inputs=["data.txt"])
assert infer_required_capabilities("out.md", target) == frozenset(
{Capability.TEXT_GENERATION}
)
def test_text_with_image_input(self) -> None:
target = TargetConfig(prompt="x", inputs=["photo.png"])
target = GenerateTargetConfig(prompt="x", inputs=["photo.png"])
assert infer_required_capabilities("out.txt", target) == frozenset(
{Capability.TEXT_GENERATION, Capability.VISION}
)
def test_text_with_image_reference(self) -> None:
target = TargetConfig(prompt="x", reference_images=["ref.jpg"])
target = GenerateTargetConfig(prompt="x", reference_images=["ref.jpg"])
assert infer_required_capabilities("out.md", target) == frozenset(
{Capability.TEXT_GENERATION, Capability.VISION}
)
def test_unsupported_extension_raises(self) -> None:
target = TargetConfig(prompt="x")
target = GenerateTargetConfig(prompt="x")
with pytest.raises(ValueError, match="unsupported extension"):
_ = infer_required_capabilities("data.csv", target)
def test_no_extension_raises(self) -> None:
target = TargetConfig(prompt="x")
target = GenerateTargetConfig(prompt="x")
with pytest.raises(ValueError, match="unsupported extension"):
_ = infer_required_capabilities("Makefile", target)
@ -96,50 +96,52 @@ class TestResolveModel:
"""Test model resolution with capability validation."""
def test_explicit_model_wins(self) -> None:
target = TargetConfig(prompt="x", model="mistral-small-latest")
target = GenerateTargetConfig(prompt="x", model="mistral-small-latest")
result = resolve_model("out.txt", target, Defaults())
assert result.name == "mistral-small-latest"
def test_default_text_model(self) -> None:
target = TargetConfig(prompt="x")
target = GenerateTargetConfig(prompt="x")
defaults = Defaults(text_model="mistral-large-latest")
result = resolve_model("out.md", target, defaults)
assert result.name == "mistral-large-latest"
def test_default_image_model(self) -> None:
target = TargetConfig(prompt="x")
target = GenerateTargetConfig(prompt="x")
defaults = Defaults(image_model="flux-dev")
result = resolve_model("out.png", target, defaults)
assert result.name == "flux-dev"
def test_unknown_model_raises(self) -> None:
target = TargetConfig(prompt="x", model="nonexistent-model")
target = GenerateTargetConfig(prompt="x", model="nonexistent-model")
with pytest.raises(ValueError, match="Unknown model"):
_ = resolve_model("out.txt", target, Defaults())
def test_explicit_model_missing_capability_raises(self) -> None:
# flux-dev does not support reference images
target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"])
target = GenerateTargetConfig(
prompt="x", model="flux-dev", reference_images=["r.png"]
)
with pytest.raises(ValueError, match="lacks required capabilities"):
_ = resolve_model("out.png", target, Defaults())
def test_default_fallback_for_reference_images(self) -> None:
# Default flux-dev lacks reference_images, should fall back to a capable model
target = TargetConfig(prompt="x", reference_images=["r.png"])
target = GenerateTargetConfig(prompt="x", reference_images=["r.png"])
defaults = Defaults(image_model="flux-dev")
result = resolve_model("out.png", target, defaults)
assert Capability.REFERENCE_IMAGES in result.capabilities
def test_default_fallback_for_vision(self) -> None:
# Default mistral-large-latest lacks vision, should fall back to a pixtral model
target = TargetConfig(prompt="x", inputs=["photo.png"])
target = GenerateTargetConfig(prompt="x", inputs=["photo.png"])
defaults = Defaults(text_model="mistral-large-latest")
result = resolve_model("out.txt", target, defaults)
assert Capability.VISION in result.capabilities
def test_default_preferred_when_capable(self) -> None:
# Default flux-2-pro already supports reference_images, should be used directly
target = TargetConfig(prompt="x", reference_images=["r.png"])
target = GenerateTargetConfig(prompt="x", reference_images=["r.png"])
defaults = Defaults(image_model="flux-2-pro")
result = resolve_model("out.png", target, defaults)
assert result.name == "flux-2-pro"