feat: add download target type for fetching files from URLs
This commit is contained in:
parent
a4600df4d5
commit
c1ad6e6e3c
14 changed files with 296 additions and 74 deletions
|
|
@ -9,7 +9,9 @@ from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from hokusai.config import ProjectConfig
|
import httpx
|
||||||
|
|
||||||
|
from hokusai.config import DownloadTargetConfig, GenerateTargetConfig, ProjectConfig
|
||||||
from hokusai.graph import build_graph, get_build_order, get_subgraph_for_target
|
from hokusai.graph import build_graph, get_build_order, get_subgraph_for_target
|
||||||
from hokusai.providers import Provider
|
from hokusai.providers import Provider
|
||||||
from hokusai.providers.blackforest import BlackForestProvider
|
from hokusai.providers.blackforest import BlackForestProvider
|
||||||
|
|
@ -73,6 +75,8 @@ def _collect_dep_files(
|
||||||
) -> list[Path]:
|
) -> list[Path]:
|
||||||
"""Collect all dependency file paths for a target."""
|
"""Collect all dependency file paths for a target."""
|
||||||
target_cfg = config.targets[target_name]
|
target_cfg = config.targets[target_name]
|
||||||
|
if not isinstance(target_cfg, GenerateTargetConfig):
|
||||||
|
return []
|
||||||
deps: list[str] = list(target_cfg.inputs)
|
deps: list[str] = list(target_cfg.inputs)
|
||||||
deps.extend(target_cfg.reference_images)
|
deps.extend(target_cfg.reference_images)
|
||||||
deps.extend(target_cfg.control_images)
|
deps.extend(target_cfg.control_images)
|
||||||
|
|
@ -82,6 +86,8 @@ def _collect_dep_files(
|
||||||
def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, object]:
|
def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, object]:
|
||||||
"""Collect extra parameters that affect rebuild decisions."""
|
"""Collect extra parameters that affect rebuild decisions."""
|
||||||
target_cfg = config.targets[target_name]
|
target_cfg = config.targets[target_name]
|
||||||
|
if not isinstance(target_cfg, GenerateTargetConfig):
|
||||||
|
return {}
|
||||||
params: dict[str, object] = {}
|
params: dict[str, object] = {}
|
||||||
if target_cfg.width is not None:
|
if target_cfg.width is not None:
|
||||||
params["width"] = target_cfg.width
|
params["width"] = target_cfg.width
|
||||||
|
|
@ -97,6 +103,8 @@ def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str,
|
||||||
def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]:
|
def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]:
|
||||||
"""Collect all dependency names (inputs + reference_images + control_images)."""
|
"""Collect all dependency names (inputs + reference_images + control_images)."""
|
||||||
target_cfg = config.targets[target_name]
|
target_cfg = config.targets[target_name]
|
||||||
|
if not isinstance(target_cfg, GenerateTargetConfig):
|
||||||
|
return []
|
||||||
deps: list[str] = list(target_cfg.inputs)
|
deps: list[str] = list(target_cfg.inputs)
|
||||||
deps.extend(target_cfg.reference_images)
|
deps.extend(target_cfg.reference_images)
|
||||||
deps.extend(target_cfg.control_images)
|
deps.extend(target_cfg.control_images)
|
||||||
|
|
@ -128,6 +136,18 @@ def _build_provider_index(providers: list[Provider]) -> dict[str, Provider]:
|
||||||
return index
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
async def _download_target(
|
||||||
|
target_name: str,
|
||||||
|
target_cfg: DownloadTargetConfig,
|
||||||
|
project_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Download a file from a URL and write it to *project_dir / target_name*."""
|
||||||
|
async with httpx.AsyncClient(headers={"User-Agent": "hokusai"}) as client:
|
||||||
|
response = await client.get(target_cfg.download, follow_redirects=True)
|
||||||
|
_ = response.raise_for_status()
|
||||||
|
_ = (project_dir / target_name).write_bytes(response.content)
|
||||||
|
|
||||||
|
|
||||||
async def _build_single_target(
|
async def _build_single_target(
|
||||||
target_name: str,
|
target_name: str,
|
||||||
config: ProjectConfig,
|
config: ProjectConfig,
|
||||||
|
|
@ -136,6 +156,11 @@ async def _build_single_target(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Build a single target by dispatching to the appropriate provider."""
|
"""Build a single target by dispatching to the appropriate provider."""
|
||||||
target_cfg = config.targets[target_name]
|
target_cfg = config.targets[target_name]
|
||||||
|
|
||||||
|
if isinstance(target_cfg, DownloadTargetConfig):
|
||||||
|
await _download_target(target_name, target_cfg, project_dir)
|
||||||
|
return
|
||||||
|
|
||||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||||
|
|
||||||
|
|
@ -227,6 +252,9 @@ def _should_skip_failed_dep(
|
||||||
return any(d in result.failed for d in _collect_all_deps(target_name, config))
|
return any(d in result.failed for d in _collect_all_deps(target_name, config))
|
||||||
|
|
||||||
|
|
||||||
|
_DOWNLOAD_MODEL_SENTINEL = "__download__"
|
||||||
|
|
||||||
|
|
||||||
def _is_dirty(
|
def _is_dirty(
|
||||||
target_name: str,
|
target_name: str,
|
||||||
config: ProjectConfig,
|
config: ProjectConfig,
|
||||||
|
|
@ -235,6 +263,18 @@ def _is_dirty(
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if a target needs rebuilding."""
|
"""Check if a target needs rebuilding."""
|
||||||
target_cfg = config.targets[target_name]
|
target_cfg = config.targets[target_name]
|
||||||
|
|
||||||
|
if isinstance(target_cfg, DownloadTargetConfig):
|
||||||
|
return is_target_dirty(
|
||||||
|
target_name,
|
||||||
|
resolved_prompt=target_cfg.download,
|
||||||
|
model=_DOWNLOAD_MODEL_SENTINEL,
|
||||||
|
dep_files=[],
|
||||||
|
extra_params={},
|
||||||
|
state=state,
|
||||||
|
project_dir=project_dir,
|
||||||
|
)
|
||||||
|
|
||||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||||
dep_files = _collect_dep_files(target_name, config, project_dir)
|
dep_files = _collect_dep_files(target_name, config, project_dir)
|
||||||
|
|
@ -260,6 +300,8 @@ def _has_provider(
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check that the required provider is available; record failure if not."""
|
"""Check that the required provider is available; record failure if not."""
|
||||||
target_cfg = config.targets[target_name]
|
target_cfg = config.targets[target_name]
|
||||||
|
if isinstance(target_cfg, DownloadTargetConfig):
|
||||||
|
return True
|
||||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||||
if model_info.name not in provider_index:
|
if model_info.name not in provider_index:
|
||||||
msg = f"No provider available for model '{model_info.name}' (provider: {model_info.provider}) — check API key environment variables"
|
msg = f"No provider available for model '{model_info.name}' (provider: {model_info.provider}) — check API key environment variables"
|
||||||
|
|
@ -302,15 +344,22 @@ def _process_outcomes(
|
||||||
on_progress(BuildEvent.TARGET_FAILED, name, str(error))
|
on_progress(BuildEvent.TARGET_FAILED, name, str(error))
|
||||||
else:
|
else:
|
||||||
target_cfg = config.targets[name]
|
target_cfg = config.targets[name]
|
||||||
|
|
||||||
|
if isinstance(target_cfg, DownloadTargetConfig):
|
||||||
|
resolved_prompt = target_cfg.download
|
||||||
|
model_name = _DOWNLOAD_MODEL_SENTINEL
|
||||||
|
else:
|
||||||
model_info = resolve_model(name, target_cfg, config.defaults)
|
model_info = resolve_model(name, target_cfg, config.defaults)
|
||||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||||
|
model_name = model_info.name
|
||||||
|
|
||||||
dep_files = _collect_dep_files(name, config, project_dir)
|
dep_files = _collect_dep_files(name, config, project_dir)
|
||||||
extra = _collect_extra_params(name, config)
|
extra = _collect_extra_params(name, config)
|
||||||
|
|
||||||
record_target_state(
|
record_target_state(
|
||||||
name,
|
name,
|
||||||
resolved_prompt=resolved_prompt,
|
resolved_prompt=resolved_prompt,
|
||||||
model=model_info.name,
|
model=model_name,
|
||||||
dep_files=dep_files,
|
dep_files=dep_files,
|
||||||
extra_params=extra,
|
extra_params=extra,
|
||||||
state=state,
|
state=state,
|
||||||
|
|
|
||||||
|
|
@ -257,18 +257,12 @@ def init() -> None:
|
||||||
|
|
||||||
content = f"""\
|
content = f"""\
|
||||||
# {name} - hokusai project
|
# {name} - hokusai project
|
||||||
defaults:
|
|
||||||
image_model: flux-2-pro
|
|
||||||
|
|
||||||
targets:
|
targets:
|
||||||
great_wave.png:
|
great_wave.png:
|
||||||
prompt: >-
|
prompt: >-
|
||||||
A recreation of Hokusai's "The Great Wave off Kanagawa", but instead of
|
A recreation of Hokusai's "The Great Wave off Kanagawa", but instead of
|
||||||
boats and people, paint brushes, canvases, and framed paintings are
|
boats and people, paint brushes, canvases, and framed paintings are
|
||||||
swimming and tumbling in the towering wave. Oil paint tubes burst open
|
swimming and tumbling in the towering wave.
|
||||||
and trail ribbons of colour through the spray. The iconic Mount Fuji
|
|
||||||
sits serenely in the background. Ukiyo-e woodblock print style with
|
|
||||||
vivid modern pigment colours.
|
|
||||||
"""
|
"""
|
||||||
_ = dest.write_text(content)
|
_ = dest.write_text(content)
|
||||||
click.echo(click.style(" created ", fg="green") + click.style(filename, bold=True))
|
click.echo(click.style(" created ", fg="green") + click.style(filename, bold=True))
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,10 @@ from __future__ import annotations
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Self
|
from typing import Annotated, Self
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, Discriminator, Tag, model_validator
|
||||||
|
|
||||||
from hokusai.providers.models import Capability
|
from hokusai.providers.models import Capability
|
||||||
|
|
||||||
|
|
@ -29,8 +29,8 @@ class Defaults(BaseModel):
|
||||||
image_model: str = "flux-2-pro"
|
image_model: str = "flux-2-pro"
|
||||||
|
|
||||||
|
|
||||||
class TargetConfig(BaseModel):
|
class GenerateTargetConfig(BaseModel):
|
||||||
"""Configuration for a single build target."""
|
"""Configuration for a target that generates an artifact via an AI provider."""
|
||||||
|
|
||||||
prompt: str
|
prompt: str
|
||||||
model: str | None = None
|
model: str | None = None
|
||||||
|
|
@ -41,6 +41,26 @@ class TargetConfig(BaseModel):
|
||||||
height: int | None = None
|
height: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadTargetConfig(BaseModel):
|
||||||
|
"""Configuration for a target that downloads a file from a URL."""
|
||||||
|
|
||||||
|
download: str
|
||||||
|
|
||||||
|
|
||||||
|
def _target_discriminator(raw: object) -> str:
|
||||||
|
"""Discriminate between generate and download target configs."""
|
||||||
|
if isinstance(raw, dict) and "download" in raw:
|
||||||
|
return "download"
|
||||||
|
return "generate"
|
||||||
|
|
||||||
|
|
||||||
|
TargetConfig = Annotated[
|
||||||
|
Annotated[GenerateTargetConfig, Tag("generate")]
|
||||||
|
| Annotated[DownloadTargetConfig, Tag("download")],
|
||||||
|
Discriminator(_target_discriminator),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ProjectConfig(BaseModel):
|
class ProjectConfig(BaseModel):
|
||||||
"""Top-level configuration parsed from ``<name>.hokusai.yaml``."""
|
"""Top-level configuration parsed from ``<name>.hokusai.yaml``."""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
from hokusai.config import ProjectConfig
|
from hokusai.config import GenerateTargetConfig, ProjectConfig
|
||||||
|
|
||||||
|
|
||||||
def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]:
|
def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]:
|
||||||
|
|
@ -25,6 +25,9 @@ def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]:
|
||||||
for target_name, target_cfg in config.targets.items():
|
for target_name, target_cfg in config.targets.items():
|
||||||
graph.add_node(target_name)
|
graph.add_node(target_name)
|
||||||
|
|
||||||
|
if not isinstance(target_cfg, GenerateTargetConfig):
|
||||||
|
continue
|
||||||
|
|
||||||
deps: list[str] = list(target_cfg.inputs)
|
deps: list[str] = list(target_cfg.inputs)
|
||||||
deps.extend(target_cfg.reference_images)
|
deps.extend(target_cfg.reference_images)
|
||||||
deps.extend(target_cfg.control_images)
|
deps.extend(target_cfg.control_images)
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
|
||||||
from hokusai.providers.models import ModelInfo
|
from hokusai.providers.models import ModelInfo
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from hokusai.config import TargetConfig
|
from hokusai.config import GenerateTargetConfig
|
||||||
|
|
||||||
|
|
||||||
class Provider(abc.ABC):
|
class Provider(abc.ABC):
|
||||||
|
|
@ -24,7 +24,7 @@ class Provider(abc.ABC):
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
target_name: str,
|
target_name: str,
|
||||||
target_config: TargetConfig,
|
target_config: GenerateTargetConfig,
|
||||||
resolved_prompt: str,
|
resolved_prompt: str,
|
||||||
resolved_model: ModelInfo,
|
resolved_model: ModelInfo,
|
||||||
project_dir: Path,
|
project_dir: Path,
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from typing import override
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from hokusai.config import TargetConfig
|
from hokusai.config import GenerateTargetConfig
|
||||||
from hokusai.providers import Provider
|
from hokusai.providers import Provider
|
||||||
from hokusai.providers.bfl import BFLClient
|
from hokusai.providers.bfl import BFLClient
|
||||||
from hokusai.providers.models import Capability, ModelInfo
|
from hokusai.providers.models import Capability, ModelInfo
|
||||||
|
|
@ -123,7 +123,7 @@ class BlackForestProvider(Provider):
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
target_name: str,
|
target_name: str,
|
||||||
target_config: TargetConfig,
|
target_config: GenerateTargetConfig,
|
||||||
resolved_prompt: str,
|
resolved_prompt: str,
|
||||||
resolved_model: ModelInfo,
|
resolved_model: ModelInfo,
|
||||||
project_dir: Path,
|
project_dir: Path,
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from typing import override
|
||||||
|
|
||||||
from mistralai import Mistral, models
|
from mistralai import Mistral, models
|
||||||
|
|
||||||
from hokusai.config import IMAGE_EXTENSIONS, TargetConfig
|
from hokusai.config import IMAGE_EXTENSIONS, GenerateTargetConfig
|
||||||
from hokusai.providers import Provider
|
from hokusai.providers import Provider
|
||||||
from hokusai.providers.models import Capability, ModelInfo
|
from hokusai.providers.models import Capability, ModelInfo
|
||||||
|
|
||||||
|
|
@ -63,7 +63,7 @@ class MistralProvider(Provider):
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
target_name: str,
|
target_name: str,
|
||||||
target_config: TargetConfig,
|
target_config: GenerateTargetConfig,
|
||||||
resolved_prompt: str,
|
resolved_prompt: str,
|
||||||
resolved_model: ModelInfo,
|
resolved_model: ModelInfo,
|
||||||
project_dir: Path,
|
project_dir: Path,
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import httpx
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from openai.types.images_response import ImagesResponse
|
from openai.types.images_response import ImagesResponse
|
||||||
|
|
||||||
from hokusai.config import TargetConfig
|
from hokusai.config import GenerateTargetConfig
|
||||||
from hokusai.providers import Provider
|
from hokusai.providers import Provider
|
||||||
from hokusai.providers.models import Capability, ModelInfo
|
from hokusai.providers.models import Capability, ModelInfo
|
||||||
|
|
||||||
|
|
@ -109,7 +109,7 @@ class OpenAIImageProvider(Provider):
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
target_name: str,
|
target_name: str,
|
||||||
target_config: TargetConfig,
|
target_config: GenerateTargetConfig,
|
||||||
resolved_prompt: str,
|
resolved_prompt: str,
|
||||||
resolved_model: ModelInfo,
|
resolved_model: ModelInfo,
|
||||||
project_dir: Path,
|
project_dir: Path,
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from openai.types.chat import (
|
||||||
ChatCompletionUserMessageParam,
|
ChatCompletionUserMessageParam,
|
||||||
)
|
)
|
||||||
|
|
||||||
from hokusai.config import IMAGE_EXTENSIONS, TargetConfig
|
from hokusai.config import IMAGE_EXTENSIONS, GenerateTargetConfig
|
||||||
from hokusai.providers import Provider
|
from hokusai.providers import Provider
|
||||||
from hokusai.providers.models import Capability, ModelInfo
|
from hokusai.providers.models import Capability, ModelInfo
|
||||||
|
|
||||||
|
|
@ -120,7 +120,7 @@ class OpenAITextProvider(Provider):
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
target_name: str,
|
target_name: str,
|
||||||
target_config: TargetConfig,
|
target_config: GenerateTargetConfig,
|
||||||
resolved_prompt: str,
|
resolved_prompt: str,
|
||||||
resolved_model: ModelInfo,
|
resolved_model: ModelInfo,
|
||||||
project_dir: Path,
|
project_dir: Path,
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from hokusai.config import (
|
||||||
IMAGE_EXTENSIONS,
|
IMAGE_EXTENSIONS,
|
||||||
TEXT_EXTENSIONS,
|
TEXT_EXTENSIONS,
|
||||||
Defaults,
|
Defaults,
|
||||||
TargetConfig,
|
GenerateTargetConfig,
|
||||||
TargetType,
|
TargetType,
|
||||||
target_type_from_capabilities,
|
target_type_from_capabilities,
|
||||||
)
|
)
|
||||||
|
|
@ -14,7 +14,7 @@ from hokusai.providers.models import Capability, ModelInfo
|
||||||
|
|
||||||
|
|
||||||
def infer_required_capabilities(
|
def infer_required_capabilities(
|
||||||
target_name: str, target: TargetConfig
|
target_name: str, target: GenerateTargetConfig
|
||||||
) -> frozenset[Capability]:
|
) -> frozenset[Capability]:
|
||||||
"""Infer the capabilities a model must have based on filename and config.
|
"""Infer the capabilities a model must have based on filename and config.
|
||||||
|
|
||||||
|
|
@ -42,7 +42,7 @@ def infer_required_capabilities(
|
||||||
|
|
||||||
|
|
||||||
def resolve_model(
|
def resolve_model(
|
||||||
target_name: str, target: TargetConfig, defaults: Defaults
|
target_name: str, target: GenerateTargetConfig, defaults: Defaults
|
||||||
) -> ModelInfo:
|
) -> ModelInfo:
|
||||||
"""Return the effective model for a target, validated against required capabilities.
|
"""Return the effective model for a target, validated against required capabilities.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import override
|
from typing import override
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -16,7 +16,7 @@ from hokusai.builder import (
|
||||||
_resolve_prompt, # pyright: ignore[reportPrivateUsage]
|
_resolve_prompt, # pyright: ignore[reportPrivateUsage]
|
||||||
run_build,
|
run_build,
|
||||||
)
|
)
|
||||||
from hokusai.config import ProjectConfig, TargetConfig
|
from hokusai.config import GenerateTargetConfig, ProjectConfig
|
||||||
from hokusai.providers import Provider
|
from hokusai.providers import Provider
|
||||||
from hokusai.providers.models import Capability, ModelInfo
|
from hokusai.providers.models import Capability, ModelInfo
|
||||||
from hokusai.state import load_state
|
from hokusai.state import load_state
|
||||||
|
|
@ -57,7 +57,7 @@ class FakeTextProvider(Provider):
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
target_name: str,
|
target_name: str,
|
||||||
target_config: TargetConfig,
|
target_config: GenerateTargetConfig,
|
||||||
resolved_prompt: str,
|
resolved_prompt: str,
|
||||||
resolved_model: ModelInfo,
|
resolved_model: ModelInfo,
|
||||||
project_dir: Path,
|
project_dir: Path,
|
||||||
|
|
@ -78,7 +78,7 @@ class FakeImageProvider(Provider):
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
target_name: str,
|
target_name: str,
|
||||||
target_config: TargetConfig,
|
target_config: GenerateTargetConfig,
|
||||||
resolved_prompt: str,
|
resolved_prompt: str,
|
||||||
resolved_model: ModelInfo,
|
resolved_model: ModelInfo,
|
||||||
project_dir: Path,
|
project_dir: Path,
|
||||||
|
|
@ -99,7 +99,7 @@ class FailingTextProvider(Provider):
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
target_name: str,
|
target_name: str,
|
||||||
target_config: TargetConfig,
|
target_config: GenerateTargetConfig,
|
||||||
resolved_prompt: str,
|
resolved_prompt: str,
|
||||||
resolved_model: ModelInfo,
|
resolved_model: ModelInfo,
|
||||||
project_dir: Path,
|
project_dir: Path,
|
||||||
|
|
@ -295,7 +295,7 @@ class TestRunBuild:
|
||||||
|
|
||||||
async def selective_generate(
|
async def selective_generate(
|
||||||
target_name: str,
|
target_name: str,
|
||||||
target_config: TargetConfig,
|
target_config: GenerateTargetConfig,
|
||||||
resolved_prompt: str,
|
resolved_prompt: str,
|
||||||
resolved_model: ModelInfo,
|
resolved_model: ModelInfo,
|
||||||
project_dir: Path,
|
project_dir: Path,
|
||||||
|
|
@ -426,3 +426,145 @@ class TestRunBuild:
|
||||||
|
|
||||||
assert set(result.built) == {"left.md", "right.md", "merge.txt"}
|
assert set(result.built) == {"left.md", "right.md", "merge.txt"}
|
||||||
assert result.failed == {}
|
assert result.failed == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestDownloadTarget:
|
||||||
|
"""Tests for download-type targets that fetch files from URLs."""
|
||||||
|
|
||||||
|
async def test_download_target_fetches_url(
|
||||||
|
self, project_dir: Path, write_config: WriteConfig
|
||||||
|
) -> None:
|
||||||
|
config = write_config(
|
||||||
|
{
|
||||||
|
"targets": {
|
||||||
|
"fish.png": {"download": "https://example.com/fish.png"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = b"fake image bytes"
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("hokusai.builder._create_providers", return_value=_fake_providers()),
|
||||||
|
patch("hokusai.builder.httpx.AsyncClient") as mock_client_cls,
|
||||||
|
):
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client_cls.return_value = mock_client
|
||||||
|
|
||||||
|
result = await run_build(config, project_dir, _PROJECT)
|
||||||
|
|
||||||
|
assert result.built == ["fish.png"]
|
||||||
|
assert (project_dir / "fish.png").read_bytes() == b"fake image bytes"
|
||||||
|
mock_client.get.assert_called_once_with( # pyright: ignore[reportAny]
|
||||||
|
"https://example.com/fish.png", follow_redirects=True
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_download_target_incremental_skip(
|
||||||
|
self, project_dir: Path, write_config: WriteConfig
|
||||||
|
) -> None:
|
||||||
|
config = write_config(
|
||||||
|
{
|
||||||
|
"targets": {
|
||||||
|
"fish.png": {"download": "https://example.com/fish.png"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = b"fake image bytes"
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("hokusai.builder._create_providers", return_value=_fake_providers()),
|
||||||
|
patch("hokusai.builder.httpx.AsyncClient") as mock_client_cls,
|
||||||
|
):
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client_cls.return_value = mock_client
|
||||||
|
|
||||||
|
r1 = await run_build(config, project_dir, _PROJECT)
|
||||||
|
assert r1.built == ["fish.png"]
|
||||||
|
|
||||||
|
r2 = await run_build(config, project_dir, _PROJECT)
|
||||||
|
assert r2.skipped == ["fish.png"]
|
||||||
|
assert r2.built == []
|
||||||
|
|
||||||
|
async def test_download_target_rebuild_on_url_change(
|
||||||
|
self, project_dir: Path, write_config: WriteConfig
|
||||||
|
) -> None:
|
||||||
|
config1 = write_config(
|
||||||
|
{
|
||||||
|
"targets": {
|
||||||
|
"fish.png": {"download": "https://example.com/fish-v1.png"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = b"v1 bytes"
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("hokusai.builder._create_providers", return_value=_fake_providers()),
|
||||||
|
patch("hokusai.builder.httpx.AsyncClient") as mock_client_cls,
|
||||||
|
):
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client_cls.return_value = mock_client
|
||||||
|
|
||||||
|
r1 = await run_build(config1, project_dir, _PROJECT)
|
||||||
|
assert r1.built == ["fish.png"]
|
||||||
|
|
||||||
|
config2 = write_config(
|
||||||
|
{
|
||||||
|
"targets": {
|
||||||
|
"fish.png": {"download": "https://example.com/fish-v2.png"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_response.content = b"v2 bytes"
|
||||||
|
|
||||||
|
r2 = await run_build(config2, project_dir, _PROJECT)
|
||||||
|
assert r2.built == ["fish.png"]
|
||||||
|
assert (project_dir / "fish.png").read_bytes() == b"v2 bytes"
|
||||||
|
|
||||||
|
async def test_download_target_as_dependency(
|
||||||
|
self, project_dir: Path, write_config: WriteConfig
|
||||||
|
) -> None:
|
||||||
|
config = write_config(
|
||||||
|
{
|
||||||
|
"targets": {
|
||||||
|
"fish.png": {"download": "https://example.com/fish.png"},
|
||||||
|
"description.txt": {
|
||||||
|
"prompt": "Describe the fish",
|
||||||
|
"inputs": ["fish.png"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = b"fake fish image"
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("hokusai.builder._create_providers", return_value=_fake_providers()),
|
||||||
|
patch("hokusai.builder.httpx.AsyncClient") as mock_client_cls,
|
||||||
|
):
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_client_cls.return_value = mock_client
|
||||||
|
|
||||||
|
result = await run_build(config, project_dir, _PROJECT)
|
||||||
|
|
||||||
|
assert "fish.png" in result.built
|
||||||
|
assert "description.txt" in result.built
|
||||||
|
assert (project_dir / "fish.png").read_bytes() == b"fake fish image"
|
||||||
|
assert (project_dir / "description.txt").exists()
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from pathlib import Path
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from hokusai.config import load_config
|
from hokusai.config import GenerateTargetConfig, load_config
|
||||||
|
|
||||||
|
|
||||||
class TestLoadConfig:
|
class TestLoadConfig:
|
||||||
|
|
@ -21,7 +21,9 @@ class TestLoadConfig:
|
||||||
config = load_config(config_path)
|
config = load_config(config_path)
|
||||||
|
|
||||||
assert "out.txt" in config.targets
|
assert "out.txt" in config.targets
|
||||||
assert config.targets["out.txt"].prompt == "hello"
|
target = config.targets["out.txt"]
|
||||||
|
assert isinstance(target, GenerateTargetConfig)
|
||||||
|
assert target.prompt == "hello"
|
||||||
assert config.defaults.text_model == "pixtral-large-latest"
|
assert config.defaults.text_model == "pixtral-large-latest"
|
||||||
assert config.defaults.image_model == "flux-2-pro"
|
assert config.defaults.image_model == "flux-2-pro"
|
||||||
|
|
||||||
|
|
@ -56,6 +58,7 @@ class TestLoadConfig:
|
||||||
assert config.defaults.image_model == "custom-image"
|
assert config.defaults.image_model == "custom-image"
|
||||||
|
|
||||||
banner = config.targets["banner.png"]
|
banner = config.targets["banner.png"]
|
||||||
|
assert isinstance(banner, GenerateTargetConfig)
|
||||||
assert banner.model == "flux-dev"
|
assert banner.model == "flux-dev"
|
||||||
assert banner.width == 1920
|
assert banner.width == 1920
|
||||||
assert banner.height == 480
|
assert banner.height == 480
|
||||||
|
|
@ -63,6 +66,7 @@ class TestLoadConfig:
|
||||||
assert banner.control_images == ["ctrl.png"]
|
assert banner.control_images == ["ctrl.png"]
|
||||||
|
|
||||||
story = config.targets["story.md"]
|
story = config.targets["story.md"]
|
||||||
|
assert isinstance(story, GenerateTargetConfig)
|
||||||
assert story.model is None
|
assert story.model is None
|
||||||
assert story.inputs == ["banner.png"]
|
assert story.inputs == ["banner.png"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from hokusai.config import TargetConfig
|
from hokusai.config import GenerateTargetConfig
|
||||||
from hokusai.providers.bfl import BFLResult
|
from hokusai.providers.bfl import BFLResult
|
||||||
from hokusai.providers.blackforest import BlackForestProvider
|
from hokusai.providers.blackforest import BlackForestProvider
|
||||||
from hokusai.providers.blackforest import (
|
from hokusai.providers.blackforest import (
|
||||||
|
|
@ -81,7 +81,7 @@ class TestBlackForestProvider:
|
||||||
async def test_basic_image_generation(
|
async def test_basic_image_generation(
|
||||||
self, project_dir: Path, image_bytes: bytes
|
self, project_dir: Path, image_bytes: bytes
|
||||||
) -> None:
|
) -> None:
|
||||||
target_config = TargetConfig(prompt="A red square")
|
target_config = GenerateTargetConfig(prompt="A red square")
|
||||||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
|
|
@ -107,7 +107,7 @@ class TestBlackForestProvider:
|
||||||
async def test_image_with_dimensions(
|
async def test_image_with_dimensions(
|
||||||
self, project_dir: Path, image_bytes: bytes
|
self, project_dir: Path, image_bytes: bytes
|
||||||
) -> None:
|
) -> None:
|
||||||
target_config = TargetConfig(prompt="A banner", width=1920, height=480)
|
target_config = GenerateTargetConfig(prompt="A banner", width=1920, height=480)
|
||||||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
|
|
@ -138,7 +138,9 @@ class TestBlackForestProvider:
|
||||||
ref_path = project_dir / "ref.png"
|
ref_path = project_dir / "ref.png"
|
||||||
_ = ref_path.write_bytes(b"reference image data")
|
_ = ref_path.write_bytes(b"reference image data")
|
||||||
|
|
||||||
target_config = TargetConfig(prompt="Like this", reference_images=["ref.png"])
|
target_config = GenerateTargetConfig(
|
||||||
|
prompt="Like this", reference_images=["ref.png"]
|
||||||
|
)
|
||||||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
|
|
@ -171,7 +173,7 @@ class TestBlackForestProvider:
|
||||||
_ = ref1.write_bytes(b"ref1 data")
|
_ = ref1.write_bytes(b"ref1 data")
|
||||||
_ = ref2.write_bytes(b"ref2 data")
|
_ = ref2.write_bytes(b"ref2 data")
|
||||||
|
|
||||||
target_config = TargetConfig(
|
target_config = GenerateTargetConfig(
|
||||||
prompt="Combine", reference_images=["ref1.png", "ref2.png"]
|
prompt="Combine", reference_images=["ref1.png", "ref2.png"]
|
||||||
)
|
)
|
||||||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||||
|
|
@ -204,7 +206,9 @@ class TestBlackForestProvider:
|
||||||
ref_path = project_dir / "ref.png"
|
ref_path = project_dir / "ref.png"
|
||||||
_ = ref_path.write_bytes(b"reference image data")
|
_ = ref_path.write_bytes(b"reference image data")
|
||||||
|
|
||||||
target_config = TargetConfig(prompt="Edit", reference_images=["ref.png"])
|
target_config = GenerateTargetConfig(
|
||||||
|
prompt="Edit", reference_images=["ref.png"]
|
||||||
|
)
|
||||||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
|
|
@ -229,7 +233,7 @@ class TestBlackForestProvider:
|
||||||
assert inputs["input_image"] == encode_image_b64(ref_path)
|
assert inputs["input_image"] == encode_image_b64(ref_path)
|
||||||
|
|
||||||
async def test_image_no_sample_url_raises(self, project_dir: Path) -> None:
|
async def test_image_no_sample_url_raises(self, project_dir: Path) -> None:
|
||||||
target_config = TargetConfig(prompt="x")
|
target_config = GenerateTargetConfig(prompt="x")
|
||||||
|
|
||||||
with patch("hokusai.providers.blackforest.BFLClient") as mock_cls:
|
with patch("hokusai.providers.blackforest.BFLClient") as mock_cls:
|
||||||
from hokusai.providers.bfl import BFLError
|
from hokusai.providers.bfl import BFLError
|
||||||
|
|
@ -261,7 +265,7 @@ class TestMistralProvider:
|
||||||
"""Test MistralProvider with mocked Mistral client."""
|
"""Test MistralProvider with mocked Mistral client."""
|
||||||
|
|
||||||
async def test_basic_text_generation(self, project_dir: Path) -> None:
|
async def test_basic_text_generation(self, project_dir: Path) -> None:
|
||||||
target_config = TargetConfig(prompt="Write a poem")
|
target_config = GenerateTargetConfig(prompt="Write a poem")
|
||||||
response = _make_text_response("Roses are red...")
|
response = _make_text_response("Roses are red...")
|
||||||
|
|
||||||
with patch("hokusai.providers.mistral.Mistral") as mock_cls:
|
with patch("hokusai.providers.mistral.Mistral") as mock_cls:
|
||||||
|
|
@ -282,7 +286,7 @@ class TestMistralProvider:
|
||||||
|
|
||||||
async def test_text_with_text_input(self, project_dir: Path) -> None:
|
async def test_text_with_text_input(self, project_dir: Path) -> None:
|
||||||
_ = (project_dir / "source.txt").write_text("Source material here")
|
_ = (project_dir / "source.txt").write_text("Source material here")
|
||||||
target_config = TargetConfig(prompt="Summarize", inputs=["source.txt"])
|
target_config = GenerateTargetConfig(prompt="Summarize", inputs=["source.txt"])
|
||||||
response = _make_text_response("Summary: ...")
|
response = _make_text_response("Summary: ...")
|
||||||
|
|
||||||
with patch("hokusai.providers.mistral.Mistral") as mock_cls:
|
with patch("hokusai.providers.mistral.Mistral") as mock_cls:
|
||||||
|
|
@ -306,7 +310,9 @@ class TestMistralProvider:
|
||||||
|
|
||||||
async def test_text_with_image_input(self, project_dir: Path) -> None:
|
async def test_text_with_image_input(self, project_dir: Path) -> None:
|
||||||
_ = (project_dir / "photo.png").write_bytes(b"\x89PNG")
|
_ = (project_dir / "photo.png").write_bytes(b"\x89PNG")
|
||||||
target_config = TargetConfig(prompt="Describe this image", inputs=["photo.png"])
|
target_config = GenerateTargetConfig(
|
||||||
|
prompt="Describe this image", inputs=["photo.png"]
|
||||||
|
)
|
||||||
response = _make_text_response("A beautiful photo")
|
response = _make_text_response("A beautiful photo")
|
||||||
|
|
||||||
with patch("hokusai.providers.mistral.Mistral") as mock_cls:
|
with patch("hokusai.providers.mistral.Mistral") as mock_cls:
|
||||||
|
|
@ -330,7 +336,7 @@ class TestMistralProvider:
|
||||||
assert chunks[1].image_url.url.startswith("data:image/png;base64,")
|
assert chunks[1].image_url.url.startswith("data:image/png;base64,")
|
||||||
|
|
||||||
async def test_text_no_choices_raises(self, project_dir: Path) -> None:
|
async def test_text_no_choices_raises(self, project_dir: Path) -> None:
|
||||||
target_config = TargetConfig(prompt="x")
|
target_config = GenerateTargetConfig(prompt="x")
|
||||||
response = MagicMock()
|
response = MagicMock()
|
||||||
response.choices = []
|
response.choices = []
|
||||||
|
|
||||||
|
|
@ -348,7 +354,7 @@ class TestMistralProvider:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def test_text_empty_content_raises(self, project_dir: Path) -> None:
|
async def test_text_empty_content_raises(self, project_dir: Path) -> None:
|
||||||
target_config = TargetConfig(prompt="x")
|
target_config = GenerateTargetConfig(prompt="x")
|
||||||
response = _make_text_response(None)
|
response = _make_text_response(None)
|
||||||
|
|
||||||
with patch("hokusai.providers.mistral.Mistral") as mock_cls:
|
with patch("hokusai.providers.mistral.Mistral") as mock_cls:
|
||||||
|
|
@ -369,7 +375,7 @@ class TestMistralProvider:
|
||||||
_ = (project_dir / "b.txt").write_text("content B")
|
_ = (project_dir / "b.txt").write_text("content B")
|
||||||
_ = (project_dir / "c.png").write_bytes(b"\x89PNG")
|
_ = (project_dir / "c.png").write_bytes(b"\x89PNG")
|
||||||
|
|
||||||
target_config = TargetConfig(
|
target_config = GenerateTargetConfig(
|
||||||
prompt="Combine all", inputs=["a.txt", "b.txt", "c.png"]
|
prompt="Combine all", inputs=["a.txt", "b.txt", "c.png"]
|
||||||
)
|
)
|
||||||
response = _make_text_response("Combined")
|
response = _make_text_response("Combined")
|
||||||
|
|
@ -400,7 +406,7 @@ class TestMistralProvider:
|
||||||
async def test_text_with_reference_images(self, project_dir: Path) -> None:
|
async def test_text_with_reference_images(self, project_dir: Path) -> None:
|
||||||
_ = (project_dir / "ref.png").write_bytes(b"\x89PNG")
|
_ = (project_dir / "ref.png").write_bytes(b"\x89PNG")
|
||||||
|
|
||||||
target_config = TargetConfig(
|
target_config = GenerateTargetConfig(
|
||||||
prompt="Describe the style", reference_images=["ref.png"]
|
prompt="Describe the style", reference_images=["ref.png"]
|
||||||
)
|
)
|
||||||
response = _make_text_response("A stylized image")
|
response = _make_text_response("A stylized image")
|
||||||
|
|
@ -455,7 +461,9 @@ class TestOpenAIImageProvider:
|
||||||
ref = project_dir / "ref.png"
|
ref = project_dir / "ref.png"
|
||||||
_ = ref.write_bytes(b"reference data")
|
_ = ref.write_bytes(b"reference data")
|
||||||
|
|
||||||
target_config = TargetConfig(prompt="Edit this", reference_images=["ref.png"])
|
target_config = GenerateTargetConfig(
|
||||||
|
prompt="Edit this", reference_images=["ref.png"]
|
||||||
|
)
|
||||||
b64 = base64.b64encode(image_bytes).decode()
|
b64 = base64.b64encode(image_bytes).decode()
|
||||||
mock_client = _make_openai_mock(b64)
|
mock_client = _make_openai_mock(b64)
|
||||||
|
|
||||||
|
|
@ -487,7 +495,7 @@ class TestOpenAIImageProvider:
|
||||||
_ = ref1.write_bytes(b"ref1 data")
|
_ = ref1.write_bytes(b"ref1 data")
|
||||||
_ = ref2.write_bytes(b"ref2 data")
|
_ = ref2.write_bytes(b"ref2 data")
|
||||||
|
|
||||||
target_config = TargetConfig(
|
target_config = GenerateTargetConfig(
|
||||||
prompt="Combine", reference_images=["ref1.png", "ref2.png"]
|
prompt="Combine", reference_images=["ref1.png", "ref2.png"]
|
||||||
)
|
)
|
||||||
b64 = base64.b64encode(image_bytes).decode()
|
b64 = base64.b64encode(image_bytes).decode()
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import pytest
|
||||||
|
|
||||||
from hokusai.config import (
|
from hokusai.config import (
|
||||||
Defaults,
|
Defaults,
|
||||||
TargetConfig,
|
GenerateTargetConfig,
|
||||||
)
|
)
|
||||||
from hokusai.providers.models import Capability
|
from hokusai.providers.models import Capability
|
||||||
from hokusai.resolve import infer_required_capabilities, resolve_model
|
from hokusai.resolve import infer_required_capabilities, resolve_model
|
||||||
|
|
@ -16,37 +16,37 @@ class TestInferRequiredCapabilities:
|
||||||
"""Test capability inference from file extensions and target config."""
|
"""Test capability inference from file extensions and target config."""
|
||||||
|
|
||||||
def test_plain_image(self) -> None:
|
def test_plain_image(self) -> None:
|
||||||
target = TargetConfig(prompt="x")
|
target = GenerateTargetConfig(prompt="x")
|
||||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||||
{Capability.TEXT_TO_IMAGE}
|
{Capability.TEXT_TO_IMAGE}
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
|
@pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
|
||||||
def test_image_extensions(self, name: str) -> None:
|
def test_image_extensions(self, name: str) -> None:
|
||||||
target = TargetConfig(prompt="x")
|
target = GenerateTargetConfig(prompt="x")
|
||||||
caps = infer_required_capabilities(name, target)
|
caps = infer_required_capabilities(name, target)
|
||||||
assert Capability.TEXT_TO_IMAGE in caps
|
assert Capability.TEXT_TO_IMAGE in caps
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
|
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
|
||||||
def test_case_insensitive(self, name: str) -> None:
|
def test_case_insensitive(self, name: str) -> None:
|
||||||
target = TargetConfig(prompt="x")
|
target = GenerateTargetConfig(prompt="x")
|
||||||
caps = infer_required_capabilities(name, target)
|
caps = infer_required_capabilities(name, target)
|
||||||
assert Capability.TEXT_TO_IMAGE in caps
|
assert Capability.TEXT_TO_IMAGE in caps
|
||||||
|
|
||||||
def test_image_with_reference_images(self) -> None:
|
def test_image_with_reference_images(self) -> None:
|
||||||
target = TargetConfig(prompt="x", reference_images=["ref.png"])
|
target = GenerateTargetConfig(prompt="x", reference_images=["ref.png"])
|
||||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||||
{Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
|
{Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_image_with_control_images(self) -> None:
|
def test_image_with_control_images(self) -> None:
|
||||||
target = TargetConfig(prompt="x", control_images=["ctrl.png"])
|
target = GenerateTargetConfig(prompt="x", control_images=["ctrl.png"])
|
||||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||||
{Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
|
{Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_image_with_both(self) -> None:
|
def test_image_with_both(self) -> None:
|
||||||
target = TargetConfig(
|
target = GenerateTargetConfig(
|
||||||
prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
|
prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
|
||||||
)
|
)
|
||||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||||
|
|
@ -59,35 +59,35 @@ class TestInferRequiredCapabilities:
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
|
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
|
||||||
def test_text_extensions(self, name: str) -> None:
|
def test_text_extensions(self, name: str) -> None:
|
||||||
target = TargetConfig(prompt="x")
|
target = GenerateTargetConfig(prompt="x")
|
||||||
caps = infer_required_capabilities(name, target)
|
caps = infer_required_capabilities(name, target)
|
||||||
assert caps == frozenset({Capability.TEXT_GENERATION})
|
assert caps == frozenset({Capability.TEXT_GENERATION})
|
||||||
|
|
||||||
def test_text_with_text_inputs(self) -> None:
|
def test_text_with_text_inputs(self) -> None:
|
||||||
target = TargetConfig(prompt="x", inputs=["data.txt"])
|
target = GenerateTargetConfig(prompt="x", inputs=["data.txt"])
|
||||||
assert infer_required_capabilities("out.md", target) == frozenset(
|
assert infer_required_capabilities("out.md", target) == frozenset(
|
||||||
{Capability.TEXT_GENERATION}
|
{Capability.TEXT_GENERATION}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_text_with_image_input(self) -> None:
|
def test_text_with_image_input(self) -> None:
|
||||||
target = TargetConfig(prompt="x", inputs=["photo.png"])
|
target = GenerateTargetConfig(prompt="x", inputs=["photo.png"])
|
||||||
assert infer_required_capabilities("out.txt", target) == frozenset(
|
assert infer_required_capabilities("out.txt", target) == frozenset(
|
||||||
{Capability.TEXT_GENERATION, Capability.VISION}
|
{Capability.TEXT_GENERATION, Capability.VISION}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_text_with_image_reference(self) -> None:
|
def test_text_with_image_reference(self) -> None:
|
||||||
target = TargetConfig(prompt="x", reference_images=["ref.jpg"])
|
target = GenerateTargetConfig(prompt="x", reference_images=["ref.jpg"])
|
||||||
assert infer_required_capabilities("out.md", target) == frozenset(
|
assert infer_required_capabilities("out.md", target) == frozenset(
|
||||||
{Capability.TEXT_GENERATION, Capability.VISION}
|
{Capability.TEXT_GENERATION, Capability.VISION}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_unsupported_extension_raises(self) -> None:
|
def test_unsupported_extension_raises(self) -> None:
|
||||||
target = TargetConfig(prompt="x")
|
target = GenerateTargetConfig(prompt="x")
|
||||||
with pytest.raises(ValueError, match="unsupported extension"):
|
with pytest.raises(ValueError, match="unsupported extension"):
|
||||||
_ = infer_required_capabilities("data.csv", target)
|
_ = infer_required_capabilities("data.csv", target)
|
||||||
|
|
||||||
def test_no_extension_raises(self) -> None:
|
def test_no_extension_raises(self) -> None:
|
||||||
target = TargetConfig(prompt="x")
|
target = GenerateTargetConfig(prompt="x")
|
||||||
with pytest.raises(ValueError, match="unsupported extension"):
|
with pytest.raises(ValueError, match="unsupported extension"):
|
||||||
_ = infer_required_capabilities("Makefile", target)
|
_ = infer_required_capabilities("Makefile", target)
|
||||||
|
|
||||||
|
|
@ -96,50 +96,52 @@ class TestResolveModel:
|
||||||
"""Test model resolution with capability validation."""
|
"""Test model resolution with capability validation."""
|
||||||
|
|
||||||
def test_explicit_model_wins(self) -> None:
|
def test_explicit_model_wins(self) -> None:
|
||||||
target = TargetConfig(prompt="x", model="mistral-small-latest")
|
target = GenerateTargetConfig(prompt="x", model="mistral-small-latest")
|
||||||
result = resolve_model("out.txt", target, Defaults())
|
result = resolve_model("out.txt", target, Defaults())
|
||||||
assert result.name == "mistral-small-latest"
|
assert result.name == "mistral-small-latest"
|
||||||
|
|
||||||
def test_default_text_model(self) -> None:
|
def test_default_text_model(self) -> None:
|
||||||
target = TargetConfig(prompt="x")
|
target = GenerateTargetConfig(prompt="x")
|
||||||
defaults = Defaults(text_model="mistral-large-latest")
|
defaults = Defaults(text_model="mistral-large-latest")
|
||||||
result = resolve_model("out.md", target, defaults)
|
result = resolve_model("out.md", target, defaults)
|
||||||
assert result.name == "mistral-large-latest"
|
assert result.name == "mistral-large-latest"
|
||||||
|
|
||||||
def test_default_image_model(self) -> None:
|
def test_default_image_model(self) -> None:
|
||||||
target = TargetConfig(prompt="x")
|
target = GenerateTargetConfig(prompt="x")
|
||||||
defaults = Defaults(image_model="flux-dev")
|
defaults = Defaults(image_model="flux-dev")
|
||||||
result = resolve_model("out.png", target, defaults)
|
result = resolve_model("out.png", target, defaults)
|
||||||
assert result.name == "flux-dev"
|
assert result.name == "flux-dev"
|
||||||
|
|
||||||
def test_unknown_model_raises(self) -> None:
|
def test_unknown_model_raises(self) -> None:
|
||||||
target = TargetConfig(prompt="x", model="nonexistent-model")
|
target = GenerateTargetConfig(prompt="x", model="nonexistent-model")
|
||||||
with pytest.raises(ValueError, match="Unknown model"):
|
with pytest.raises(ValueError, match="Unknown model"):
|
||||||
_ = resolve_model("out.txt", target, Defaults())
|
_ = resolve_model("out.txt", target, Defaults())
|
||||||
|
|
||||||
def test_explicit_model_missing_capability_raises(self) -> None:
|
def test_explicit_model_missing_capability_raises(self) -> None:
|
||||||
# flux-dev does not support reference images
|
# flux-dev does not support reference images
|
||||||
target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"])
|
target = GenerateTargetConfig(
|
||||||
|
prompt="x", model="flux-dev", reference_images=["r.png"]
|
||||||
|
)
|
||||||
with pytest.raises(ValueError, match="lacks required capabilities"):
|
with pytest.raises(ValueError, match="lacks required capabilities"):
|
||||||
_ = resolve_model("out.png", target, Defaults())
|
_ = resolve_model("out.png", target, Defaults())
|
||||||
|
|
||||||
def test_default_fallback_for_reference_images(self) -> None:
|
def test_default_fallback_for_reference_images(self) -> None:
|
||||||
# Default flux-dev lacks reference_images, should fall back to a capable model
|
# Default flux-dev lacks reference_images, should fall back to a capable model
|
||||||
target = TargetConfig(prompt="x", reference_images=["r.png"])
|
target = GenerateTargetConfig(prompt="x", reference_images=["r.png"])
|
||||||
defaults = Defaults(image_model="flux-dev")
|
defaults = Defaults(image_model="flux-dev")
|
||||||
result = resolve_model("out.png", target, defaults)
|
result = resolve_model("out.png", target, defaults)
|
||||||
assert Capability.REFERENCE_IMAGES in result.capabilities
|
assert Capability.REFERENCE_IMAGES in result.capabilities
|
||||||
|
|
||||||
def test_default_fallback_for_vision(self) -> None:
|
def test_default_fallback_for_vision(self) -> None:
|
||||||
# Default mistral-large-latest lacks vision, should fall back to a pixtral model
|
# Default mistral-large-latest lacks vision, should fall back to a pixtral model
|
||||||
target = TargetConfig(prompt="x", inputs=["photo.png"])
|
target = GenerateTargetConfig(prompt="x", inputs=["photo.png"])
|
||||||
defaults = Defaults(text_model="mistral-large-latest")
|
defaults = Defaults(text_model="mistral-large-latest")
|
||||||
result = resolve_model("out.txt", target, defaults)
|
result = resolve_model("out.txt", target, defaults)
|
||||||
assert Capability.VISION in result.capabilities
|
assert Capability.VISION in result.capabilities
|
||||||
|
|
||||||
def test_default_preferred_when_capable(self) -> None:
|
def test_default_preferred_when_capable(self) -> None:
|
||||||
# Default flux-2-pro already supports reference_images, should be used directly
|
# Default flux-2-pro already supports reference_images, should be used directly
|
||||||
target = TargetConfig(prompt="x", reference_images=["r.png"])
|
target = GenerateTargetConfig(prompt="x", reference_images=["r.png"])
|
||||||
defaults = Defaults(image_model="flux-2-pro")
|
defaults = Defaults(image_model="flux-2-pro")
|
||||||
result = resolve_model("out.png", target, defaults)
|
result = resolve_model("out.png", target, defaults)
|
||||||
assert result.name == "flux-2-pro"
|
assert result.name == "flux-2-pro"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue