feat: add download target type for fetching files from URLs

This commit is contained in:
Konstantin Fickel 2026-02-20 21:02:15 +01:00
parent a4600df4d5
commit c1ad6e6e3c
Signed by: kfickel
GPG key ID: A793722F9933C1A5
14 changed files with 296 additions and 74 deletions

View file

@ -9,7 +9,9 @@ from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from hokusai.config import ProjectConfig
import httpx
from hokusai.config import DownloadTargetConfig, GenerateTargetConfig, ProjectConfig
from hokusai.graph import build_graph, get_build_order, get_subgraph_for_target
from hokusai.providers import Provider
from hokusai.providers.blackforest import BlackForestProvider
@ -73,6 +75,8 @@ def _collect_dep_files(
) -> list[Path]:
"""Collect all dependency file paths for a target."""
target_cfg = config.targets[target_name]
if not isinstance(target_cfg, GenerateTargetConfig):
return []
deps: list[str] = list(target_cfg.inputs)
deps.extend(target_cfg.reference_images)
deps.extend(target_cfg.control_images)
@ -82,6 +86,8 @@ def _collect_dep_files(
def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, object]:
"""Collect extra parameters that affect rebuild decisions."""
target_cfg = config.targets[target_name]
if not isinstance(target_cfg, GenerateTargetConfig):
return {}
params: dict[str, object] = {}
if target_cfg.width is not None:
params["width"] = target_cfg.width
@ -97,6 +103,8 @@ def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str,
def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]:
"""Collect all dependency names (inputs + reference_images + control_images)."""
target_cfg = config.targets[target_name]
if not isinstance(target_cfg, GenerateTargetConfig):
return []
deps: list[str] = list(target_cfg.inputs)
deps.extend(target_cfg.reference_images)
deps.extend(target_cfg.control_images)
@ -128,6 +136,18 @@ def _build_provider_index(providers: list[Provider]) -> dict[str, Provider]:
return index
async def _download_target(
    target_name: str,
    target_cfg: DownloadTargetConfig,
    project_dir: Path,
) -> None:
    """Download a file from a URL and write it to *project_dir / target_name*.

    The response body is streamed in chunks to a sibling ``<name>.part`` file
    and moved onto the destination only after the whole transfer succeeded.
    This avoids holding the complete body in memory and keeps an interrupted
    download from leaving a truncated file at the target path.

    Args:
        target_name: Target name; also the output path relative to
            *project_dir*.
        target_cfg: Download target configuration; its ``download`` attribute
            is the URL to fetch.
        project_dir: Directory the output path is resolved against.

    Raises:
        httpx.HTTPStatusError: If the final response (after redirects) has an
            error status.
        httpx.HTTPError: On other request/transport failures raised by httpx.
        OSError: If the output file cannot be written or moved into place.
    """
    destination = project_dir / target_name
    partial = destination.with_name(destination.name + ".part")
    try:
        async with httpx.AsyncClient(headers={"User-Agent": "hokusai"}) as client:
            async with client.stream(
                "GET", target_cfg.download, follow_redirects=True
            ) as response:
                # Check the status before touching the filesystem so an error
                # response never creates or modifies any file.
                _ = response.raise_for_status()
                with partial.open("wb") as out_file:
                    async for chunk in response.aiter_bytes():
                        _ = out_file.write(chunk)
        # Only a fully written download replaces the destination.
        _ = partial.replace(destination)
    finally:
        # No-op after a successful replace; removes the leftover on failure.
        partial.unlink(missing_ok=True)
async def _build_single_target(
target_name: str,
config: ProjectConfig,
@ -136,6 +156,11 @@ async def _build_single_target(
) -> None:
"""Build a single target by dispatching to the appropriate provider."""
target_cfg = config.targets[target_name]
if isinstance(target_cfg, DownloadTargetConfig):
await _download_target(target_name, target_cfg, project_dir)
return
model_info = resolve_model(target_name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
@ -227,6 +252,9 @@ def _should_skip_failed_dep(
return any(d in result.failed for d in _collect_all_deps(target_name, config))
# Placeholder passed as the ``model`` value when checking or recording build
# state for download targets, which are handled without calling
# ``resolve_model``.
_DOWNLOAD_MODEL_SENTINEL = "__download__"
def _is_dirty(
target_name: str,
config: ProjectConfig,
@ -235,6 +263,18 @@ def _is_dirty(
) -> bool:
"""Check if a target needs rebuilding."""
target_cfg = config.targets[target_name]
if isinstance(target_cfg, DownloadTargetConfig):
return is_target_dirty(
target_name,
resolved_prompt=target_cfg.download,
model=_DOWNLOAD_MODEL_SENTINEL,
dep_files=[],
extra_params={},
state=state,
project_dir=project_dir,
)
model_info = resolve_model(target_name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
dep_files = _collect_dep_files(target_name, config, project_dir)
@ -260,6 +300,8 @@ def _has_provider(
) -> bool:
"""Check that the required provider is available; record failure if not."""
target_cfg = config.targets[target_name]
if isinstance(target_cfg, DownloadTargetConfig):
return True
model_info = resolve_model(target_name, target_cfg, config.defaults)
if model_info.name not in provider_index:
msg = f"No provider available for model '{model_info.name}' (provider: {model_info.provider}) — check API key environment variables"
@ -302,15 +344,22 @@ def _process_outcomes(
on_progress(BuildEvent.TARGET_FAILED, name, str(error))
else:
target_cfg = config.targets[name]
model_info = resolve_model(name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
if isinstance(target_cfg, DownloadTargetConfig):
resolved_prompt = target_cfg.download
model_name = _DOWNLOAD_MODEL_SENTINEL
else:
model_info = resolve_model(name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
model_name = model_info.name
dep_files = _collect_dep_files(name, config, project_dir)
extra = _collect_extra_params(name, config)
record_target_state(
name,
resolved_prompt=resolved_prompt,
model=model_info.name,
model=model_name,
dep_files=dep_files,
extra_params=extra,
state=state,