feat: add download target type for fetching files from URLs
This commit is contained in:
parent
a4600df4d5
commit
c1ad6e6e3c
14 changed files with 296 additions and 74 deletions
|
|
@ -9,7 +9,9 @@ from collections.abc import Callable
|
|||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from hokusai.config import ProjectConfig
|
||||
import httpx
|
||||
|
||||
from hokusai.config import DownloadTargetConfig, GenerateTargetConfig, ProjectConfig
|
||||
from hokusai.graph import build_graph, get_build_order, get_subgraph_for_target
|
||||
from hokusai.providers import Provider
|
||||
from hokusai.providers.blackforest import BlackForestProvider
|
||||
|
|
@ -73,6 +75,8 @@ def _collect_dep_files(
|
|||
) -> list[Path]:
|
||||
"""Collect all dependency file paths for a target."""
|
||||
target_cfg = config.targets[target_name]
|
||||
if not isinstance(target_cfg, GenerateTargetConfig):
|
||||
return []
|
||||
deps: list[str] = list(target_cfg.inputs)
|
||||
deps.extend(target_cfg.reference_images)
|
||||
deps.extend(target_cfg.control_images)
|
||||
|
|
@ -82,6 +86,8 @@ def _collect_dep_files(
|
|||
def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, object]:
|
||||
"""Collect extra parameters that affect rebuild decisions."""
|
||||
target_cfg = config.targets[target_name]
|
||||
if not isinstance(target_cfg, GenerateTargetConfig):
|
||||
return {}
|
||||
params: dict[str, object] = {}
|
||||
if target_cfg.width is not None:
|
||||
params["width"] = target_cfg.width
|
||||
|
|
@ -97,6 +103,8 @@ def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str,
|
|||
def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]:
|
||||
"""Collect all dependency names (inputs + reference_images + control_images)."""
|
||||
target_cfg = config.targets[target_name]
|
||||
if not isinstance(target_cfg, GenerateTargetConfig):
|
||||
return []
|
||||
deps: list[str] = list(target_cfg.inputs)
|
||||
deps.extend(target_cfg.reference_images)
|
||||
deps.extend(target_cfg.control_images)
|
||||
|
|
@ -128,6 +136,18 @@ def _build_provider_index(providers: list[Provider]) -> dict[str, Provider]:
|
|||
return index
|
||||
|
||||
|
||||
async def _download_target(
|
||||
target_name: str,
|
||||
target_cfg: DownloadTargetConfig,
|
||||
project_dir: Path,
|
||||
) -> None:
|
||||
"""Download a file from a URL and write it to *project_dir / target_name*."""
|
||||
async with httpx.AsyncClient(headers={"User-Agent": "hokusai"}) as client:
|
||||
response = await client.get(target_cfg.download, follow_redirects=True)
|
||||
_ = response.raise_for_status()
|
||||
_ = (project_dir / target_name).write_bytes(response.content)
|
||||
|
||||
|
||||
async def _build_single_target(
|
||||
target_name: str,
|
||||
config: ProjectConfig,
|
||||
|
|
@ -136,6 +156,11 @@ async def _build_single_target(
|
|||
) -> None:
|
||||
"""Build a single target by dispatching to the appropriate provider."""
|
||||
target_cfg = config.targets[target_name]
|
||||
|
||||
if isinstance(target_cfg, DownloadTargetConfig):
|
||||
await _download_target(target_name, target_cfg, project_dir)
|
||||
return
|
||||
|
||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
|
||||
|
|
@ -227,6 +252,9 @@ def _should_skip_failed_dep(
|
|||
return any(d in result.failed for d in _collect_all_deps(target_name, config))
|
||||
|
||||
|
||||
_DOWNLOAD_MODEL_SENTINEL = "__download__"
|
||||
|
||||
|
||||
def _is_dirty(
|
||||
target_name: str,
|
||||
config: ProjectConfig,
|
||||
|
|
@ -235,6 +263,18 @@ def _is_dirty(
|
|||
) -> bool:
|
||||
"""Check if a target needs rebuilding."""
|
||||
target_cfg = config.targets[target_name]
|
||||
|
||||
if isinstance(target_cfg, DownloadTargetConfig):
|
||||
return is_target_dirty(
|
||||
target_name,
|
||||
resolved_prompt=target_cfg.download,
|
||||
model=_DOWNLOAD_MODEL_SENTINEL,
|
||||
dep_files=[],
|
||||
extra_params={},
|
||||
state=state,
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
dep_files = _collect_dep_files(target_name, config, project_dir)
|
||||
|
|
@ -260,6 +300,8 @@ def _has_provider(
|
|||
) -> bool:
|
||||
"""Check that the required provider is available; record failure if not."""
|
||||
target_cfg = config.targets[target_name]
|
||||
if isinstance(target_cfg, DownloadTargetConfig):
|
||||
return True
|
||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||
if model_info.name not in provider_index:
|
||||
msg = f"No provider available for model '{model_info.name}' (provider: {model_info.provider}) — check API key environment variables"
|
||||
|
|
@ -302,15 +344,22 @@ def _process_outcomes(
|
|||
on_progress(BuildEvent.TARGET_FAILED, name, str(error))
|
||||
else:
|
||||
target_cfg = config.targets[name]
|
||||
model_info = resolve_model(name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
|
||||
if isinstance(target_cfg, DownloadTargetConfig):
|
||||
resolved_prompt = target_cfg.download
|
||||
model_name = _DOWNLOAD_MODEL_SENTINEL
|
||||
else:
|
||||
model_info = resolve_model(name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
model_name = model_info.name
|
||||
|
||||
dep_files = _collect_dep_files(name, config, project_dir)
|
||||
extra = _collect_extra_params(name, config)
|
||||
|
||||
record_target_state(
|
||||
name,
|
||||
resolved_prompt=resolved_prompt,
|
||||
model=model_info.name,
|
||||
model=model_name,
|
||||
dep_files=dep_files,
|
||||
extra_params=extra,
|
||||
state=state,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue