feat: add download target type for fetching files from URLs

This commit is contained in:
Konstantin Fickel 2026-02-20 21:02:15 +01:00
parent a4600df4d5
commit c1ad6e6e3c
Signed by: kfickel
GPG key ID: A793722F9933C1A5
14 changed files with 296 additions and 74 deletions

View file

@ -9,7 +9,9 @@ from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from hokusai.config import ProjectConfig import httpx
from hokusai.config import DownloadTargetConfig, GenerateTargetConfig, ProjectConfig
from hokusai.graph import build_graph, get_build_order, get_subgraph_for_target from hokusai.graph import build_graph, get_build_order, get_subgraph_for_target
from hokusai.providers import Provider from hokusai.providers import Provider
from hokusai.providers.blackforest import BlackForestProvider from hokusai.providers.blackforest import BlackForestProvider
@ -73,6 +75,8 @@ def _collect_dep_files(
) -> list[Path]: ) -> list[Path]:
"""Collect all dependency file paths for a target.""" """Collect all dependency file paths for a target."""
target_cfg = config.targets[target_name] target_cfg = config.targets[target_name]
if not isinstance(target_cfg, GenerateTargetConfig):
return []
deps: list[str] = list(target_cfg.inputs) deps: list[str] = list(target_cfg.inputs)
deps.extend(target_cfg.reference_images) deps.extend(target_cfg.reference_images)
deps.extend(target_cfg.control_images) deps.extend(target_cfg.control_images)
@ -82,6 +86,8 @@ def _collect_dep_files(
def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, object]: def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, object]:
"""Collect extra parameters that affect rebuild decisions.""" """Collect extra parameters that affect rebuild decisions."""
target_cfg = config.targets[target_name] target_cfg = config.targets[target_name]
if not isinstance(target_cfg, GenerateTargetConfig):
return {}
params: dict[str, object] = {} params: dict[str, object] = {}
if target_cfg.width is not None: if target_cfg.width is not None:
params["width"] = target_cfg.width params["width"] = target_cfg.width
@ -97,6 +103,8 @@ def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str,
def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]: def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]:
"""Collect all dependency names (inputs + reference_images + control_images).""" """Collect all dependency names (inputs + reference_images + control_images)."""
target_cfg = config.targets[target_name] target_cfg = config.targets[target_name]
if not isinstance(target_cfg, GenerateTargetConfig):
return []
deps: list[str] = list(target_cfg.inputs) deps: list[str] = list(target_cfg.inputs)
deps.extend(target_cfg.reference_images) deps.extend(target_cfg.reference_images)
deps.extend(target_cfg.control_images) deps.extend(target_cfg.control_images)
@ -128,6 +136,18 @@ def _build_provider_index(providers: list[Provider]) -> dict[str, Provider]:
return index return index
async def _download_target(
    target_name: str,
    target_cfg: DownloadTargetConfig,
    project_dir: Path,
) -> None:
    """Fetch *target_cfg.download* over HTTP and save the body under *project_dir*.

    The output path is ``project_dir / target_name``; redirects are followed
    and a non-2xx response raises via ``raise_for_status``.
    """
    destination = project_dir / target_name
    headers = {"User-Agent": "hokusai"}
    async with httpx.AsyncClient(headers=headers) as client:
        response = await client.get(target_cfg.download, follow_redirects=True)
        _ = response.raise_for_status()
        _ = destination.write_bytes(response.content)
async def _build_single_target( async def _build_single_target(
target_name: str, target_name: str,
config: ProjectConfig, config: ProjectConfig,
@ -136,6 +156,11 @@ async def _build_single_target(
) -> None: ) -> None:
"""Build a single target by dispatching to the appropriate provider.""" """Build a single target by dispatching to the appropriate provider."""
target_cfg = config.targets[target_name] target_cfg = config.targets[target_name]
if isinstance(target_cfg, DownloadTargetConfig):
await _download_target(target_name, target_cfg, project_dir)
return
model_info = resolve_model(target_name, target_cfg, config.defaults) model_info = resolve_model(target_name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
@ -227,6 +252,9 @@ def _should_skip_failed_dep(
return any(d in result.failed for d in _collect_all_deps(target_name, config)) return any(d in result.failed for d in _collect_all_deps(target_name, config))
# Sentinel stored in build state as the "model" of download targets, which
# have no AI model (used by _is_dirty and _process_outcomes).
_DOWNLOAD_MODEL_SENTINEL = "__download__"
def _is_dirty( def _is_dirty(
target_name: str, target_name: str,
config: ProjectConfig, config: ProjectConfig,
@ -235,6 +263,18 @@ def _is_dirty(
) -> bool: ) -> bool:
"""Check if a target needs rebuilding.""" """Check if a target needs rebuilding."""
target_cfg = config.targets[target_name] target_cfg = config.targets[target_name]
if isinstance(target_cfg, DownloadTargetConfig):
return is_target_dirty(
target_name,
resolved_prompt=target_cfg.download,
model=_DOWNLOAD_MODEL_SENTINEL,
dep_files=[],
extra_params={},
state=state,
project_dir=project_dir,
)
model_info = resolve_model(target_name, target_cfg, config.defaults) model_info = resolve_model(target_name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
dep_files = _collect_dep_files(target_name, config, project_dir) dep_files = _collect_dep_files(target_name, config, project_dir)
@ -260,6 +300,8 @@ def _has_provider(
) -> bool: ) -> bool:
"""Check that the required provider is available; record failure if not.""" """Check that the required provider is available; record failure if not."""
target_cfg = config.targets[target_name] target_cfg = config.targets[target_name]
if isinstance(target_cfg, DownloadTargetConfig):
return True
model_info = resolve_model(target_name, target_cfg, config.defaults) model_info = resolve_model(target_name, target_cfg, config.defaults)
if model_info.name not in provider_index: if model_info.name not in provider_index:
msg = f"No provider available for model '{model_info.name}' (provider: {model_info.provider}) — check API key environment variables" msg = f"No provider available for model '{model_info.name}' (provider: {model_info.provider}) — check API key environment variables"
@ -302,15 +344,22 @@ def _process_outcomes(
on_progress(BuildEvent.TARGET_FAILED, name, str(error)) on_progress(BuildEvent.TARGET_FAILED, name, str(error))
else: else:
target_cfg = config.targets[name] target_cfg = config.targets[name]
model_info = resolve_model(name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) if isinstance(target_cfg, DownloadTargetConfig):
resolved_prompt = target_cfg.download
model_name = _DOWNLOAD_MODEL_SENTINEL
else:
model_info = resolve_model(name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
model_name = model_info.name
dep_files = _collect_dep_files(name, config, project_dir) dep_files = _collect_dep_files(name, config, project_dir)
extra = _collect_extra_params(name, config) extra = _collect_extra_params(name, config)
record_target_state( record_target_state(
name, name,
resolved_prompt=resolved_prompt, resolved_prompt=resolved_prompt,
model=model_info.name, model=model_name,
dep_files=dep_files, dep_files=dep_files,
extra_params=extra, extra_params=extra,
state=state, state=state,

View file

@ -257,18 +257,12 @@ def init() -> None:
content = f"""\ content = f"""\
# {name} - hokusai project # {name} - hokusai project
defaults:
image_model: flux-2-pro
targets: targets:
great_wave.png: great_wave.png:
prompt: >- prompt: >-
A recreation of Hokusai's "The Great Wave off Kanagawa", but instead of A recreation of Hokusai's "The Great Wave off Kanagawa", but instead of
boats and people, paint brushes, canvases, and framed paintings are boats and people, paint brushes, canvases, and framed paintings are
swimming and tumbling in the towering wave. Oil paint tubes burst open swimming and tumbling in the towering wave.
and trail ribbons of colour through the spray. The iconic Mount Fuji
sits serenely in the background. Ukiyo-e woodblock print style with
vivid modern pigment colours.
""" """
_ = dest.write_text(content) _ = dest.write_text(content)
click.echo(click.style(" created ", fg="green") + click.style(filename, bold=True)) click.echo(click.style(" created ", fg="green") + click.style(filename, bold=True))

View file

@ -4,10 +4,10 @@ from __future__ import annotations
import enum import enum
from pathlib import Path from pathlib import Path
from typing import Self from typing import Annotated, Self
import yaml import yaml
from pydantic import BaseModel, model_validator from pydantic import BaseModel, Discriminator, Tag, model_validator
from hokusai.providers.models import Capability from hokusai.providers.models import Capability
@ -29,8 +29,8 @@ class Defaults(BaseModel):
image_model: str = "flux-2-pro" image_model: str = "flux-2-pro"
class TargetConfig(BaseModel): class GenerateTargetConfig(BaseModel):
"""Configuration for a single build target.""" """Configuration for a target that generates an artifact via an AI provider."""
prompt: str prompt: str
model: str | None = None model: str | None = None
@ -41,6 +41,26 @@ class TargetConfig(BaseModel):
height: int | None = None height: int | None = None
class DownloadTargetConfig(BaseModel):
    """Configuration for a target that downloads a file from a URL."""

    # Source URL; the fetched bytes become the target's output file.
    download: str
def _target_discriminator(raw: object) -> str:
"""Discriminate between generate and download target configs."""
if isinstance(raw, dict) and "download" in raw:
return "download"
return "generate"
# Union of all target flavours. Dispatch is keyed on the presence of a
# "download" entry (see _target_discriminator) rather than an explicit
# "type" field in the YAML.
TargetConfig = Annotated[
    Annotated[GenerateTargetConfig, Tag("generate")]
    | Annotated[DownloadTargetConfig, Tag("download")],
    Discriminator(_target_discriminator),
]
class ProjectConfig(BaseModel): class ProjectConfig(BaseModel):
"""Top-level configuration parsed from ``<name>.hokusai.yaml``.""" """Top-level configuration parsed from ``<name>.hokusai.yaml``."""

View file

@ -6,7 +6,7 @@ from pathlib import Path
import networkx as nx import networkx as nx
from hokusai.config import ProjectConfig from hokusai.config import GenerateTargetConfig, ProjectConfig
def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]: def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]:
@ -25,6 +25,9 @@ def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]:
for target_name, target_cfg in config.targets.items(): for target_name, target_cfg in config.targets.items():
graph.add_node(target_name) graph.add_node(target_name)
if not isinstance(target_cfg, GenerateTargetConfig):
continue
deps: list[str] = list(target_cfg.inputs) deps: list[str] = list(target_cfg.inputs)
deps.extend(target_cfg.reference_images) deps.extend(target_cfg.reference_images)
deps.extend(target_cfg.control_images) deps.extend(target_cfg.control_images)

View file

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
from hokusai.providers.models import ModelInfo from hokusai.providers.models import ModelInfo
if TYPE_CHECKING: if TYPE_CHECKING:
from hokusai.config import TargetConfig from hokusai.config import GenerateTargetConfig
class Provider(abc.ABC): class Provider(abc.ABC):
@ -24,7 +24,7 @@ class Provider(abc.ABC):
async def generate( async def generate(
self, self,
target_name: str, target_name: str,
target_config: TargetConfig, target_config: GenerateTargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: ModelInfo, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,

View file

@ -8,7 +8,7 @@ from typing import override
import httpx import httpx
from hokusai.config import TargetConfig from hokusai.config import GenerateTargetConfig
from hokusai.providers import Provider from hokusai.providers import Provider
from hokusai.providers.bfl import BFLClient from hokusai.providers.bfl import BFLClient
from hokusai.providers.models import Capability, ModelInfo from hokusai.providers.models import Capability, ModelInfo
@ -123,7 +123,7 @@ class BlackForestProvider(Provider):
async def generate( async def generate(
self, self,
target_name: str, target_name: str,
target_config: TargetConfig, target_config: GenerateTargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: ModelInfo, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,

View file

@ -9,7 +9,7 @@ from typing import override
from mistralai import Mistral, models from mistralai import Mistral, models
from hokusai.config import IMAGE_EXTENSIONS, TargetConfig from hokusai.config import IMAGE_EXTENSIONS, GenerateTargetConfig
from hokusai.providers import Provider from hokusai.providers import Provider
from hokusai.providers.models import Capability, ModelInfo from hokusai.providers.models import Capability, ModelInfo
@ -63,7 +63,7 @@ class MistralProvider(Provider):
async def generate( async def generate(
self, self,
target_name: str, target_name: str,
target_config: TargetConfig, target_config: GenerateTargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: ModelInfo, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,

View file

@ -10,7 +10,7 @@ import httpx
from openai import AsyncOpenAI from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse from openai.types.images_response import ImagesResponse
from hokusai.config import TargetConfig from hokusai.config import GenerateTargetConfig
from hokusai.providers import Provider from hokusai.providers import Provider
from hokusai.providers.models import Capability, ModelInfo from hokusai.providers.models import Capability, ModelInfo
@ -109,7 +109,7 @@ class OpenAIImageProvider(Provider):
async def generate( async def generate(
self, self,
target_name: str, target_name: str,
target_config: TargetConfig, target_config: GenerateTargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: ModelInfo, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,

View file

@ -15,7 +15,7 @@ from openai.types.chat import (
ChatCompletionUserMessageParam, ChatCompletionUserMessageParam,
) )
from hokusai.config import IMAGE_EXTENSIONS, TargetConfig from hokusai.config import IMAGE_EXTENSIONS, GenerateTargetConfig
from hokusai.providers import Provider from hokusai.providers import Provider
from hokusai.providers.models import Capability, ModelInfo from hokusai.providers.models import Capability, ModelInfo
@ -120,7 +120,7 @@ class OpenAITextProvider(Provider):
async def generate( async def generate(
self, self,
target_name: str, target_name: str,
target_config: TargetConfig, target_config: GenerateTargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: ModelInfo, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,

View file

@ -6,7 +6,7 @@ from hokusai.config import (
IMAGE_EXTENSIONS, IMAGE_EXTENSIONS,
TEXT_EXTENSIONS, TEXT_EXTENSIONS,
Defaults, Defaults,
TargetConfig, GenerateTargetConfig,
TargetType, TargetType,
target_type_from_capabilities, target_type_from_capabilities,
) )
@ -14,7 +14,7 @@ from hokusai.providers.models import Capability, ModelInfo
def infer_required_capabilities( def infer_required_capabilities(
target_name: str, target: TargetConfig target_name: str, target: GenerateTargetConfig
) -> frozenset[Capability]: ) -> frozenset[Capability]:
"""Infer the capabilities a model must have based on filename and config. """Infer the capabilities a model must have based on filename and config.
@ -42,7 +42,7 @@ def infer_required_capabilities(
def resolve_model( def resolve_model(
target_name: str, target: TargetConfig, defaults: Defaults target_name: str, target: GenerateTargetConfig, defaults: Defaults
) -> ModelInfo: ) -> ModelInfo:
"""Return the effective model for a target, validated against required capabilities. """Return the effective model for a target, validated against required capabilities.

View file

@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import override from typing import override
from unittest.mock import patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@ -16,7 +16,7 @@ from hokusai.builder import (
_resolve_prompt, # pyright: ignore[reportPrivateUsage] _resolve_prompt, # pyright: ignore[reportPrivateUsage]
run_build, run_build,
) )
from hokusai.config import ProjectConfig, TargetConfig from hokusai.config import GenerateTargetConfig, ProjectConfig
from hokusai.providers import Provider from hokusai.providers import Provider
from hokusai.providers.models import Capability, ModelInfo from hokusai.providers.models import Capability, ModelInfo
from hokusai.state import load_state from hokusai.state import load_state
@ -57,7 +57,7 @@ class FakeTextProvider(Provider):
async def generate( async def generate(
self, self,
target_name: str, target_name: str,
target_config: TargetConfig, target_config: GenerateTargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: ModelInfo, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,
@ -78,7 +78,7 @@ class FakeImageProvider(Provider):
async def generate( async def generate(
self, self,
target_name: str, target_name: str,
target_config: TargetConfig, target_config: GenerateTargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: ModelInfo, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,
@ -99,7 +99,7 @@ class FailingTextProvider(Provider):
async def generate( async def generate(
self, self,
target_name: str, target_name: str,
target_config: TargetConfig, target_config: GenerateTargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: ModelInfo, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,
@ -295,7 +295,7 @@ class TestRunBuild:
async def selective_generate( async def selective_generate(
target_name: str, target_name: str,
target_config: TargetConfig, target_config: GenerateTargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: ModelInfo, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,
@ -426,3 +426,145 @@ class TestRunBuild:
assert set(result.built) == {"left.md", "right.md", "merge.txt"} assert set(result.built) == {"left.md", "right.md", "merge.txt"}
assert result.failed == {} assert result.failed == {}
class TestDownloadTarget:
    """Tests for download-type targets that fetch files from URLs."""

    @staticmethod
    def _mock_http_client(content: bytes) -> tuple[AsyncMock, MagicMock]:
        """Build ``(client, response)`` mocks emulating ``httpx.AsyncClient``.

        The client works as an async context manager and its ``get`` returns
        a response whose ``content`` is *content*. Shared by every test below
        instead of repeating the scaffolding four times.
        """
        response = MagicMock()
        response.content = content
        response.raise_for_status = MagicMock()
        client = AsyncMock()
        client.get = AsyncMock(return_value=response)
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        return client, response

    async def test_download_target_fetches_url(
        self, project_dir: Path, write_config: WriteConfig
    ) -> None:
        config = write_config(
            {
                "targets": {
                    "fish.png": {"download": "https://example.com/fish.png"},
                }
            }
        )
        client, _ = self._mock_http_client(b"fake image bytes")
        with (
            patch("hokusai.builder._create_providers", return_value=_fake_providers()),
            patch("hokusai.builder.httpx.AsyncClient", return_value=client),
        ):
            result = await run_build(config, project_dir, _PROJECT)
        assert result.built == ["fish.png"]
        assert (project_dir / "fish.png").read_bytes() == b"fake image bytes"
        client.get.assert_called_once_with(  # pyright: ignore[reportAny]
            "https://example.com/fish.png", follow_redirects=True
        )

    async def test_download_target_incremental_skip(
        self, project_dir: Path, write_config: WriteConfig
    ) -> None:
        config = write_config(
            {
                "targets": {
                    "fish.png": {"download": "https://example.com/fish.png"},
                }
            }
        )
        client, _ = self._mock_http_client(b"fake image bytes")
        with (
            patch("hokusai.builder._create_providers", return_value=_fake_providers()),
            patch("hokusai.builder.httpx.AsyncClient", return_value=client),
        ):
            # First build downloads; an unchanged second build must skip.
            r1 = await run_build(config, project_dir, _PROJECT)
            assert r1.built == ["fish.png"]
            r2 = await run_build(config, project_dir, _PROJECT)
        assert r2.skipped == ["fish.png"]
        assert r2.built == []

    async def test_download_target_rebuild_on_url_change(
        self, project_dir: Path, write_config: WriteConfig
    ) -> None:
        config1 = write_config(
            {
                "targets": {
                    "fish.png": {"download": "https://example.com/fish-v1.png"},
                }
            }
        )
        client, response = self._mock_http_client(b"v1 bytes")
        with (
            patch("hokusai.builder._create_providers", return_value=_fake_providers()),
            patch("hokusai.builder.httpx.AsyncClient", return_value=client),
        ):
            r1 = await run_build(config1, project_dir, _PROJECT)
            assert r1.built == ["fish.png"]
            # Changing only the URL must invalidate the cached state.
            config2 = write_config(
                {
                    "targets": {
                        "fish.png": {"download": "https://example.com/fish-v2.png"},
                    }
                }
            )
            response.content = b"v2 bytes"
            r2 = await run_build(config2, project_dir, _PROJECT)
        assert r2.built == ["fish.png"]
        assert (project_dir / "fish.png").read_bytes() == b"v2 bytes"

    async def test_download_target_as_dependency(
        self, project_dir: Path, write_config: WriteConfig
    ) -> None:
        config = write_config(
            {
                "targets": {
                    "fish.png": {"download": "https://example.com/fish.png"},
                    "description.txt": {
                        "prompt": "Describe the fish",
                        "inputs": ["fish.png"],
                    },
                }
            }
        )
        client, _ = self._mock_http_client(b"fake fish image")
        with (
            patch("hokusai.builder._create_providers", return_value=_fake_providers()),
            patch("hokusai.builder.httpx.AsyncClient", return_value=client),
        ):
            result = await run_build(config, project_dir, _PROJECT)
        assert "fish.png" in result.built
        assert "description.txt" in result.built
        assert (project_dir / "fish.png").read_bytes() == b"fake fish image"
        assert (project_dir / "description.txt").exists()

View file

@ -7,7 +7,7 @@ from pathlib import Path
import pytest import pytest
import yaml import yaml
from hokusai.config import load_config from hokusai.config import GenerateTargetConfig, load_config
class TestLoadConfig: class TestLoadConfig:
@ -21,7 +21,9 @@ class TestLoadConfig:
config = load_config(config_path) config = load_config(config_path)
assert "out.txt" in config.targets assert "out.txt" in config.targets
assert config.targets["out.txt"].prompt == "hello" target = config.targets["out.txt"]
assert isinstance(target, GenerateTargetConfig)
assert target.prompt == "hello"
assert config.defaults.text_model == "pixtral-large-latest" assert config.defaults.text_model == "pixtral-large-latest"
assert config.defaults.image_model == "flux-2-pro" assert config.defaults.image_model == "flux-2-pro"
@ -56,6 +58,7 @@ class TestLoadConfig:
assert config.defaults.image_model == "custom-image" assert config.defaults.image_model == "custom-image"
banner = config.targets["banner.png"] banner = config.targets["banner.png"]
assert isinstance(banner, GenerateTargetConfig)
assert banner.model == "flux-dev" assert banner.model == "flux-dev"
assert banner.width == 1920 assert banner.width == 1920
assert banner.height == 480 assert banner.height == 480
@ -63,6 +66,7 @@ class TestLoadConfig:
assert banner.control_images == ["ctrl.png"] assert banner.control_images == ["ctrl.png"]
story = config.targets["story.md"] story = config.targets["story.md"]
assert isinstance(story, GenerateTargetConfig)
assert story.model is None assert story.model is None
assert story.inputs == ["banner.png"] assert story.inputs == ["banner.png"]

View file

@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from hokusai.config import TargetConfig from hokusai.config import GenerateTargetConfig
from hokusai.providers.bfl import BFLResult from hokusai.providers.bfl import BFLResult
from hokusai.providers.blackforest import BlackForestProvider from hokusai.providers.blackforest import BlackForestProvider
from hokusai.providers.blackforest import ( from hokusai.providers.blackforest import (
@ -81,7 +81,7 @@ class TestBlackForestProvider:
async def test_basic_image_generation( async def test_basic_image_generation(
self, project_dir: Path, image_bytes: bytes self, project_dir: Path, image_bytes: bytes
) -> None: ) -> None:
target_config = TargetConfig(prompt="A red square") target_config = GenerateTargetConfig(prompt="A red square")
bfl_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with ( with (
@ -107,7 +107,7 @@ class TestBlackForestProvider:
async def test_image_with_dimensions( async def test_image_with_dimensions(
self, project_dir: Path, image_bytes: bytes self, project_dir: Path, image_bytes: bytes
) -> None: ) -> None:
target_config = TargetConfig(prompt="A banner", width=1920, height=480) target_config = GenerateTargetConfig(prompt="A banner", width=1920, height=480)
bfl_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with ( with (
@ -138,7 +138,9 @@ class TestBlackForestProvider:
ref_path = project_dir / "ref.png" ref_path = project_dir / "ref.png"
_ = ref_path.write_bytes(b"reference image data") _ = ref_path.write_bytes(b"reference image data")
target_config = TargetConfig(prompt="Like this", reference_images=["ref.png"]) target_config = GenerateTargetConfig(
prompt="Like this", reference_images=["ref.png"]
)
bfl_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with ( with (
@ -171,7 +173,7 @@ class TestBlackForestProvider:
_ = ref1.write_bytes(b"ref1 data") _ = ref1.write_bytes(b"ref1 data")
_ = ref2.write_bytes(b"ref2 data") _ = ref2.write_bytes(b"ref2 data")
target_config = TargetConfig( target_config = GenerateTargetConfig(
prompt="Combine", reference_images=["ref1.png", "ref2.png"] prompt="Combine", reference_images=["ref1.png", "ref2.png"]
) )
bfl_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
@ -204,7 +206,9 @@ class TestBlackForestProvider:
ref_path = project_dir / "ref.png" ref_path = project_dir / "ref.png"
_ = ref_path.write_bytes(b"reference image data") _ = ref_path.write_bytes(b"reference image data")
target_config = TargetConfig(prompt="Edit", reference_images=["ref.png"]) target_config = GenerateTargetConfig(
prompt="Edit", reference_images=["ref.png"]
)
bfl_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with ( with (
@ -229,7 +233,7 @@ class TestBlackForestProvider:
assert inputs["input_image"] == encode_image_b64(ref_path) assert inputs["input_image"] == encode_image_b64(ref_path)
async def test_image_no_sample_url_raises(self, project_dir: Path) -> None: async def test_image_no_sample_url_raises(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="x") target_config = GenerateTargetConfig(prompt="x")
with patch("hokusai.providers.blackforest.BFLClient") as mock_cls: with patch("hokusai.providers.blackforest.BFLClient") as mock_cls:
from hokusai.providers.bfl import BFLError from hokusai.providers.bfl import BFLError
@ -261,7 +265,7 @@ class TestMistralProvider:
"""Test MistralProvider with mocked Mistral client.""" """Test MistralProvider with mocked Mistral client."""
async def test_basic_text_generation(self, project_dir: Path) -> None: async def test_basic_text_generation(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="Write a poem") target_config = GenerateTargetConfig(prompt="Write a poem")
response = _make_text_response("Roses are red...") response = _make_text_response("Roses are red...")
with patch("hokusai.providers.mistral.Mistral") as mock_cls: with patch("hokusai.providers.mistral.Mistral") as mock_cls:
@ -282,7 +286,7 @@ class TestMistralProvider:
async def test_text_with_text_input(self, project_dir: Path) -> None: async def test_text_with_text_input(self, project_dir: Path) -> None:
_ = (project_dir / "source.txt").write_text("Source material here") _ = (project_dir / "source.txt").write_text("Source material here")
target_config = TargetConfig(prompt="Summarize", inputs=["source.txt"]) target_config = GenerateTargetConfig(prompt="Summarize", inputs=["source.txt"])
response = _make_text_response("Summary: ...") response = _make_text_response("Summary: ...")
with patch("hokusai.providers.mistral.Mistral") as mock_cls: with patch("hokusai.providers.mistral.Mistral") as mock_cls:
@ -306,7 +310,9 @@ class TestMistralProvider:
async def test_text_with_image_input(self, project_dir: Path) -> None: async def test_text_with_image_input(self, project_dir: Path) -> None:
_ = (project_dir / "photo.png").write_bytes(b"\x89PNG") _ = (project_dir / "photo.png").write_bytes(b"\x89PNG")
target_config = TargetConfig(prompt="Describe this image", inputs=["photo.png"]) target_config = GenerateTargetConfig(
prompt="Describe this image", inputs=["photo.png"]
)
response = _make_text_response("A beautiful photo") response = _make_text_response("A beautiful photo")
with patch("hokusai.providers.mistral.Mistral") as mock_cls: with patch("hokusai.providers.mistral.Mistral") as mock_cls:
@ -330,7 +336,7 @@ class TestMistralProvider:
assert chunks[1].image_url.url.startswith("data:image/png;base64,") assert chunks[1].image_url.url.startswith("data:image/png;base64,")
async def test_text_no_choices_raises(self, project_dir: Path) -> None: async def test_text_no_choices_raises(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="x") target_config = GenerateTargetConfig(prompt="x")
response = MagicMock() response = MagicMock()
response.choices = [] response.choices = []
@ -348,7 +354,7 @@ class TestMistralProvider:
) )
async def test_text_empty_content_raises(self, project_dir: Path) -> None: async def test_text_empty_content_raises(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="x") target_config = GenerateTargetConfig(prompt="x")
response = _make_text_response(None) response = _make_text_response(None)
with patch("hokusai.providers.mistral.Mistral") as mock_cls: with patch("hokusai.providers.mistral.Mistral") as mock_cls:
@ -369,7 +375,7 @@ class TestMistralProvider:
_ = (project_dir / "b.txt").write_text("content B") _ = (project_dir / "b.txt").write_text("content B")
_ = (project_dir / "c.png").write_bytes(b"\x89PNG") _ = (project_dir / "c.png").write_bytes(b"\x89PNG")
target_config = TargetConfig( target_config = GenerateTargetConfig(
prompt="Combine all", inputs=["a.txt", "b.txt", "c.png"] prompt="Combine all", inputs=["a.txt", "b.txt", "c.png"]
) )
response = _make_text_response("Combined") response = _make_text_response("Combined")
@ -400,7 +406,7 @@ class TestMistralProvider:
async def test_text_with_reference_images(self, project_dir: Path) -> None: async def test_text_with_reference_images(self, project_dir: Path) -> None:
_ = (project_dir / "ref.png").write_bytes(b"\x89PNG") _ = (project_dir / "ref.png").write_bytes(b"\x89PNG")
target_config = TargetConfig( target_config = GenerateTargetConfig(
prompt="Describe the style", reference_images=["ref.png"] prompt="Describe the style", reference_images=["ref.png"]
) )
response = _make_text_response("A stylized image") response = _make_text_response("A stylized image")
@ -455,7 +461,9 @@ class TestOpenAIImageProvider:
ref = project_dir / "ref.png" ref = project_dir / "ref.png"
_ = ref.write_bytes(b"reference data") _ = ref.write_bytes(b"reference data")
target_config = TargetConfig(prompt="Edit this", reference_images=["ref.png"]) target_config = GenerateTargetConfig(
prompt="Edit this", reference_images=["ref.png"]
)
b64 = base64.b64encode(image_bytes).decode() b64 = base64.b64encode(image_bytes).decode()
mock_client = _make_openai_mock(b64) mock_client = _make_openai_mock(b64)
@ -487,7 +495,7 @@ class TestOpenAIImageProvider:
_ = ref1.write_bytes(b"ref1 data") _ = ref1.write_bytes(b"ref1 data")
_ = ref2.write_bytes(b"ref2 data") _ = ref2.write_bytes(b"ref2 data")
target_config = TargetConfig( target_config = GenerateTargetConfig(
prompt="Combine", reference_images=["ref1.png", "ref2.png"] prompt="Combine", reference_images=["ref1.png", "ref2.png"]
) )
b64 = base64.b64encode(image_bytes).decode() b64 = base64.b64encode(image_bytes).decode()

View file

@ -6,7 +6,7 @@ import pytest
from hokusai.config import ( from hokusai.config import (
Defaults, Defaults,
TargetConfig, GenerateTargetConfig,
) )
from hokusai.providers.models import Capability from hokusai.providers.models import Capability
from hokusai.resolve import infer_required_capabilities, resolve_model from hokusai.resolve import infer_required_capabilities, resolve_model
@ -16,37 +16,37 @@ class TestInferRequiredCapabilities:
"""Test capability inference from file extensions and target config.""" """Test capability inference from file extensions and target config."""
def test_plain_image(self) -> None: def test_plain_image(self) -> None:
target = TargetConfig(prompt="x") target = GenerateTargetConfig(prompt="x")
assert infer_required_capabilities("out.png", target) == frozenset( assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE} {Capability.TEXT_TO_IMAGE}
) )
@pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"]) @pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
def test_image_extensions(self, name: str) -> None: def test_image_extensions(self, name: str) -> None:
target = TargetConfig(prompt="x") target = GenerateTargetConfig(prompt="x")
caps = infer_required_capabilities(name, target) caps = infer_required_capabilities(name, target)
assert Capability.TEXT_TO_IMAGE in caps assert Capability.TEXT_TO_IMAGE in caps
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"]) @pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
def test_case_insensitive(self, name: str) -> None: def test_case_insensitive(self, name: str) -> None:
target = TargetConfig(prompt="x") target = GenerateTargetConfig(prompt="x")
caps = infer_required_capabilities(name, target) caps = infer_required_capabilities(name, target)
assert Capability.TEXT_TO_IMAGE in caps assert Capability.TEXT_TO_IMAGE in caps
def test_image_with_reference_images(self) -> None: def test_image_with_reference_images(self) -> None:
target = TargetConfig(prompt="x", reference_images=["ref.png"]) target = GenerateTargetConfig(prompt="x", reference_images=["ref.png"])
assert infer_required_capabilities("out.png", target) == frozenset( assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES} {Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
) )
def test_image_with_control_images(self) -> None: def test_image_with_control_images(self) -> None:
target = TargetConfig(prompt="x", control_images=["ctrl.png"]) target = GenerateTargetConfig(prompt="x", control_images=["ctrl.png"])
assert infer_required_capabilities("out.png", target) == frozenset( assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES} {Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
) )
def test_image_with_both(self) -> None: def test_image_with_both(self) -> None:
target = TargetConfig( target = GenerateTargetConfig(
prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"] prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
) )
assert infer_required_capabilities("out.png", target) == frozenset( assert infer_required_capabilities("out.png", target) == frozenset(
@ -59,35 +59,35 @@ class TestInferRequiredCapabilities:
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"]) @pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
def test_text_extensions(self, name: str) -> None: def test_text_extensions(self, name: str) -> None:
target = TargetConfig(prompt="x") target = GenerateTargetConfig(prompt="x")
caps = infer_required_capabilities(name, target) caps = infer_required_capabilities(name, target)
assert caps == frozenset({Capability.TEXT_GENERATION}) assert caps == frozenset({Capability.TEXT_GENERATION})
def test_text_with_text_inputs(self) -> None: def test_text_with_text_inputs(self) -> None:
target = TargetConfig(prompt="x", inputs=["data.txt"]) target = GenerateTargetConfig(prompt="x", inputs=["data.txt"])
assert infer_required_capabilities("out.md", target) == frozenset( assert infer_required_capabilities("out.md", target) == frozenset(
{Capability.TEXT_GENERATION} {Capability.TEXT_GENERATION}
) )
def test_text_with_image_input(self) -> None: def test_text_with_image_input(self) -> None:
target = TargetConfig(prompt="x", inputs=["photo.png"]) target = GenerateTargetConfig(prompt="x", inputs=["photo.png"])
assert infer_required_capabilities("out.txt", target) == frozenset( assert infer_required_capabilities("out.txt", target) == frozenset(
{Capability.TEXT_GENERATION, Capability.VISION} {Capability.TEXT_GENERATION, Capability.VISION}
) )
def test_text_with_image_reference(self) -> None: def test_text_with_image_reference(self) -> None:
target = TargetConfig(prompt="x", reference_images=["ref.jpg"]) target = GenerateTargetConfig(prompt="x", reference_images=["ref.jpg"])
assert infer_required_capabilities("out.md", target) == frozenset( assert infer_required_capabilities("out.md", target) == frozenset(
{Capability.TEXT_GENERATION, Capability.VISION} {Capability.TEXT_GENERATION, Capability.VISION}
) )
def test_unsupported_extension_raises(self) -> None: def test_unsupported_extension_raises(self) -> None:
target = TargetConfig(prompt="x") target = GenerateTargetConfig(prompt="x")
with pytest.raises(ValueError, match="unsupported extension"): with pytest.raises(ValueError, match="unsupported extension"):
_ = infer_required_capabilities("data.csv", target) _ = infer_required_capabilities("data.csv", target)
def test_no_extension_raises(self) -> None: def test_no_extension_raises(self) -> None:
target = TargetConfig(prompt="x") target = GenerateTargetConfig(prompt="x")
with pytest.raises(ValueError, match="unsupported extension"): with pytest.raises(ValueError, match="unsupported extension"):
_ = infer_required_capabilities("Makefile", target) _ = infer_required_capabilities("Makefile", target)
@ -96,50 +96,52 @@ class TestResolveModel:
"""Test model resolution with capability validation.""" """Test model resolution with capability validation."""
def test_explicit_model_wins(self) -> None: def test_explicit_model_wins(self) -> None:
target = TargetConfig(prompt="x", model="mistral-small-latest") target = GenerateTargetConfig(prompt="x", model="mistral-small-latest")
result = resolve_model("out.txt", target, Defaults()) result = resolve_model("out.txt", target, Defaults())
assert result.name == "mistral-small-latest" assert result.name == "mistral-small-latest"
def test_default_text_model(self) -> None: def test_default_text_model(self) -> None:
target = TargetConfig(prompt="x") target = GenerateTargetConfig(prompt="x")
defaults = Defaults(text_model="mistral-large-latest") defaults = Defaults(text_model="mistral-large-latest")
result = resolve_model("out.md", target, defaults) result = resolve_model("out.md", target, defaults)
assert result.name == "mistral-large-latest" assert result.name == "mistral-large-latest"
def test_default_image_model(self) -> None: def test_default_image_model(self) -> None:
target = TargetConfig(prompt="x") target = GenerateTargetConfig(prompt="x")
defaults = Defaults(image_model="flux-dev") defaults = Defaults(image_model="flux-dev")
result = resolve_model("out.png", target, defaults) result = resolve_model("out.png", target, defaults)
assert result.name == "flux-dev" assert result.name == "flux-dev"
def test_unknown_model_raises(self) -> None: def test_unknown_model_raises(self) -> None:
target = TargetConfig(prompt="x", model="nonexistent-model") target = GenerateTargetConfig(prompt="x", model="nonexistent-model")
with pytest.raises(ValueError, match="Unknown model"): with pytest.raises(ValueError, match="Unknown model"):
_ = resolve_model("out.txt", target, Defaults()) _ = resolve_model("out.txt", target, Defaults())
def test_explicit_model_missing_capability_raises(self) -> None: def test_explicit_model_missing_capability_raises(self) -> None:
# flux-dev does not support reference images # flux-dev does not support reference images
target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"]) target = GenerateTargetConfig(
prompt="x", model="flux-dev", reference_images=["r.png"]
)
with pytest.raises(ValueError, match="lacks required capabilities"): with pytest.raises(ValueError, match="lacks required capabilities"):
_ = resolve_model("out.png", target, Defaults()) _ = resolve_model("out.png", target, Defaults())
def test_default_fallback_for_reference_images(self) -> None: def test_default_fallback_for_reference_images(self) -> None:
# Default flux-dev lacks reference_images, should fall back to a capable model # Default flux-dev lacks reference_images, should fall back to a capable model
target = TargetConfig(prompt="x", reference_images=["r.png"]) target = GenerateTargetConfig(prompt="x", reference_images=["r.png"])
defaults = Defaults(image_model="flux-dev") defaults = Defaults(image_model="flux-dev")
result = resolve_model("out.png", target, defaults) result = resolve_model("out.png", target, defaults)
assert Capability.REFERENCE_IMAGES in result.capabilities assert Capability.REFERENCE_IMAGES in result.capabilities
def test_default_fallback_for_vision(self) -> None: def test_default_fallback_for_vision(self) -> None:
# Default mistral-large-latest lacks vision, should fall back to a pixtral model # Default mistral-large-latest lacks vision, should fall back to a pixtral model
target = TargetConfig(prompt="x", inputs=["photo.png"]) target = GenerateTargetConfig(prompt="x", inputs=["photo.png"])
defaults = Defaults(text_model="mistral-large-latest") defaults = Defaults(text_model="mistral-large-latest")
result = resolve_model("out.txt", target, defaults) result = resolve_model("out.txt", target, defaults)
assert Capability.VISION in result.capabilities assert Capability.VISION in result.capabilities
def test_default_preferred_when_capable(self) -> None: def test_default_preferred_when_capable(self) -> None:
# Default flux-2-pro already supports reference_images, should be used directly # Default flux-2-pro already supports reference_images, should be used directly
target = TargetConfig(prompt="x", reference_images=["r.png"]) target = GenerateTargetConfig(prompt="x", reference_images=["r.png"])
defaults = Defaults(image_model="flux-2-pro") defaults = Defaults(image_model="flux-2-pro")
result = resolve_model("out.png", target, defaults) result = resolve_model("out.png", target, defaults)
assert result.name == "flux-2-pro" assert result.name == "flux-2-pro"