Download targets now store only 'download: <url>' in the state file
instead of using 'prompt' and 'model: __download__' as a workaround.
Also use exclude_defaults=True when serializing state to omit empty
fields like input_hashes: {} and extra_params: {}.
366 lines
12 KiB
Python
366 lines
12 KiB
Python
"""Build orchestrator: incremental builds with parallel execution."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import enum
|
|
import os
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
|
|
from hokusai.archive import archive_file
|
|
from hokusai.config import DownloadTargetConfig, GenerateTargetConfig, ProjectConfig
|
|
from hokusai.graph import build_graph, get_build_order, get_subgraph_for_target
|
|
from hokusai.prompt import extract_placeholder_files, resolve_prompt
|
|
from hokusai.providers import Provider
|
|
from hokusai.providers.blackforest import BlackForestProvider
|
|
from hokusai.providers.mistral import MistralProvider
|
|
from hokusai.providers.openai_image import OpenAIImageProvider
|
|
from hokusai.providers.openai_text import OpenAITextProvider
|
|
from hokusai.resolve import resolve_model
|
|
from hokusai.state import (
|
|
BuildState,
|
|
is_target_dirty,
|
|
load_state,
|
|
record_target_state,
|
|
save_state,
|
|
)
|
|
|
|
|
|
class BuildEvent(enum.Enum):
    """Kinds of progress notification sent to a progress callback.

    The value of each member is a short status string.
    """

    # Target was found up to date and left untouched.
    TARGET_SKIPPED = "skipped"
    # Target is about to be built as part of the current generation.
    TARGET_BUILDING = "building"
    # Target was built and its state recorded.
    TARGET_OK = "ok"
    # Building the target raised an exception.
    TARGET_FAILED = "failed"
    # Target was not attempted because one of its dependencies failed.
    TARGET_DEP_FAILED = "dep_failed"
    # No provider is registered for the model the target resolves to.
    TARGET_NO_PROVIDER = "no_provider"
|
|
|
|
|
|
# Progress listeners are plain synchronous callables; whatever they return is ignored.
ProgressCallback = Callable[[BuildEvent, str, str], None]
"""Signature: (event, target_name, detail_message)."""
|
|
|
|
|
|
def _noop_callback(_event: BuildEvent, _name: str, _detail: str) -> None:
|
|
"""Default no-op progress callback."""
|
|
|
|
|
|
@dataclass
class BuildResult:
    """Outcome of one build run, grouped by what happened to each target."""

    # Names of targets that were built successfully.
    built: list[str] = field(default_factory=list)
    # Names of targets that were already up to date.
    skipped: list[str] = field(default_factory=list)
    # Failure message per target name, for every target that could not be built.
    failed: dict[str, str] = field(default_factory=dict)
    # How many configured targets the run looked at.
    total_targets: int = 0
|
|
|
|
|
|
def _collect_dep_files(
    target_name: str, config: ProjectConfig, project_dir: Path
) -> list[Path]:
    """Return the dependency paths of *target_name*, rooted at *project_dir*.

    The dependency names are the target's inputs, reference images, control
    images and prompt placeholder files, in that order. Targets that are not
    generate targets have no dependency files, so the result is empty.
    """
    # The name list is assembled by _collect_all_deps; this only anchors each
    # name under the project directory.
    return [project_dir / dep for dep in _collect_all_deps(target_name, config)]
|
|
|
|
|
|
def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, object]:
    """Gather the optional settings of a target that feed the rebuild check.

    Returns an empty dict for targets that are not generate targets.
    Otherwise the dict holds, in this order and only when set: ``width``,
    ``height``, ``reference_images`` and ``control_images``. The two image
    lists are stored as tuples.
    """
    cfg = config.targets[target_name]
    if not isinstance(cfg, GenerateTargetConfig):
        return {}

    dimensions = {"width": cfg.width, "height": cfg.height}
    image_lists = {
        "reference_images": cfg.reference_images,
        "control_images": cfg.control_images,
    }

    # A dimension counts only when explicitly configured (0 would still count).
    extras: dict[str, object] = {
        key: value for key, value in dimensions.items() if value is not None
    }
    # Empty image lists are left out entirely.
    # NOTE(review): tuples may come back as lists after a state round-trip,
    # which would not compare equal - confirm how the state module compares them.
    extras.update({key: tuple(images) for key, images in image_lists.items() if images})
    return extras
|
|
|
|
|
|
def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]:
    """List every dependency name of a target.

    The result is the concatenation of the target's inputs, reference images,
    control images and the files referenced by placeholders in its prompt.
    Targets that are not generate targets yield an empty list.
    """
    cfg = config.targets[target_name]
    if not isinstance(cfg, GenerateTargetConfig):
        return []
    return [
        *cfg.inputs,
        *cfg.reference_images,
        *cfg.control_images,
        *extract_placeholder_files(cfg.prompt),
    ]
|
|
|
|
|
|
def _create_providers() -> list[Provider]:
|
|
"""Create provider instances from environment variables."""
|
|
providers: list[Provider] = []
|
|
bfl_key = os.environ.get("BFL_API_KEY", "")
|
|
if bfl_key:
|
|
providers.append(BlackForestProvider(api_key=bfl_key))
|
|
mistral_key = os.environ.get("MISTRAL_API_KEY", "")
|
|
if mistral_key:
|
|
providers.append(MistralProvider(api_key=mistral_key))
|
|
openai_key = os.environ.get("OPENAI_API_KEY", "")
|
|
if openai_key:
|
|
providers.append(OpenAITextProvider(api_key=openai_key))
|
|
providers.append(OpenAIImageProvider(api_key=openai_key))
|
|
return providers
|
|
|
|
|
|
def _build_provider_index(providers: list[Provider]) -> dict[str, Provider]:
|
|
"""Build a model-name → provider lookup from a list of providers."""
|
|
index: dict[str, Provider] = {}
|
|
for provider in providers:
|
|
for model in provider.get_provided_models():
|
|
index[model.name] = provider
|
|
return index
|
|
|
|
|
|
async def _download_target(
    target_name: str,
    target_cfg: DownloadTargetConfig,
    project_dir: Path,
) -> None:
    """Fetch ``target_cfg.download`` and store the body at *project_dir / target_name*.

    The request follows redirects and is sent with a ``hokusai`` User-Agent.
    ``raise_for_status`` runs before the write, so an error it raises leaves
    the destination untouched.
    """
    destination = project_dir / target_name
    request_headers = {"User-Agent": "hokusai"}

    async with httpx.AsyncClient(headers=request_headers) as client:
        response = await client.get(target_cfg.download, follow_redirects=True)
        _ = response.raise_for_status()
        # The whole body is held in memory and written in one call.
        _ = destination.write_bytes(response.content)
|
|
|
|
|
|
async def _build_single_target(
    target_name: str,
    config: ProjectConfig,
    project_dir: Path,
    provider_index: dict[str, Provider],
) -> None:
    """Produce one target, by download or by asking its provider to generate it.

    The output's parent directory is created first. If an archive folder is
    configured, the output path is passed to ``archive_file`` before anything
    is written. Download targets are then fetched directly; every other
    target is generated by the provider registered for its resolved model.
    """
    destination = project_dir / target_name

    # Targets may live in subfolders, so create the folder chain up front.
    destination.parent.mkdir(parents=True, exist_ok=True)

    # Give the archiver a chance to keep the previous artifact before it is replaced.
    archive_dir = config.archive_folder
    if archive_dir is not None:
        _ = archive_file(destination, project_dir, archive_dir)

    cfg = config.targets[target_name]
    if isinstance(cfg, DownloadTargetConfig):
        # Downloads need neither a model nor a prompt.
        await _download_target(target_name, cfg, project_dir)
        return

    model = resolve_model(target_name, cfg, config.defaults)
    prompt_text = resolve_prompt(cfg.prompt, project_dir)

    # A model without a registered provider surfaces here as KeyError.
    await provider_index[model.name].generate(
        target_name=target_name,
        target_config=cfg,
        resolved_prompt=prompt_text,
        resolved_model=model,
        project_dir=project_dir,
    )
|
|
|
|
|
|
async def run_build(
    config: ProjectConfig,
    project_dir: Path,
    project_name: str,
    target: str | None = None,
    on_progress: ProgressCallback = _noop_callback,
) -> BuildResult:
    """Run an incremental build and return a summary of what happened.

    With *target* set, only that target and the part of the graph it needs
    are considered; otherwise every configured target is. Work proceeds one
    topological generation at a time: each target in a generation is failed
    (a dependency failed, or no provider is available), skipped (up to date)
    or queued, and the queued targets are then built concurrently. State is
    saved after every generation in which a build was attempted.

    Args:
        config: Project configuration holding the target definitions.
        project_dir: Directory that target paths are relative to.
        project_name: Name used when loading and saving the build state.
        target: Optional single target to build.
        on_progress: Callback invoked with ``(event, target_name, detail)``.

    Returns:
        A :class:`BuildResult` listing built, skipped and failed targets.

    Raises:
        ValueError: If *target* is given but is not a configured target.
    """
    result = BuildResult()
    provider_index = _build_provider_index(_create_providers())

    graph = build_graph(config, project_dir)
    if target is not None:
        if target not in config.targets:
            msg = f"Unknown target: '{target}'"
            raise ValueError(msg)
        graph = get_subgraph_for_target(graph, target)

    state = load_state(project_dir, project_name)
    generations = get_build_order(graph)
    known_targets = set(config.targets)

    # Only graph nodes that are configured targets take part in the build.
    target_generations = [
        [name for name in generation if name in known_targets]
        for generation in generations
    ]
    result.total_targets = sum(len(members) for members in target_generations)

    for members in target_generations:
        to_build: list[str] = []
        for name in members:
            if _should_skip_failed_dep(name, config, result):
                # Recording this as a failure makes targets that depend on it fail too.
                result.failed[name] = "Dependency failed"
                on_progress(BuildEvent.TARGET_DEP_FAILED, name, "Dependency failed")
                continue

            if not _is_dirty(name, config, project_dir, state):
                result.skipped.append(name)
                on_progress(BuildEvent.TARGET_SKIPPED, name, "up to date")
                continue

            # _has_provider records and reports the failure itself when it returns False.
            if _has_provider(name, config, provider_index, result, on_progress):
                to_build.append(name)

        if not to_build:
            continue

        for name in to_build:
            on_progress(BuildEvent.TARGET_BUILDING, name, "")

        outcomes = await _build_generation(
            to_build, config, project_dir, provider_index
        )
        _process_outcomes(outcomes, config, project_dir, state, result, on_progress)
        # Persist after each generation so finished work is kept if a later one breaks.
        save_state(state, project_dir, project_name)

    return result
|
|
|
|
|
|
def _should_skip_failed_dep(
    target_name: str, config: ProjectConfig, result: BuildResult
) -> bool:
    """Return True when at least one dependency of the target is recorded as failed."""
    already_failed = result.failed
    for dep in _collect_all_deps(target_name, config):
        if dep in already_failed:
            return True
    return False
|
|
|
|
|
|
def _is_dirty(
    target_name: str,
    config: ProjectConfig,
    project_dir: Path,
    state: BuildState,
) -> bool:
    """Decide whether a target has to be rebuilt.

    The decision itself is made by ``is_target_dirty``; this function only
    gathers what to hand it. Download targets are checked with their URL
    alone, all other targets with their resolved prompt, resolved model name,
    dependency files and extra parameters.
    """
    cfg = config.targets[target_name]

    if isinstance(cfg, DownloadTargetConfig):
        # No prompt, model or dependency files are involved in a download.
        return is_target_dirty(
            target_name,
            download=cfg.download,
            state=state,
            project_dir=project_dir,
        )

    model = resolve_model(target_name, cfg, config.defaults)
    prompt_text = resolve_prompt(cfg.prompt, project_dir)
    dependency_paths = _collect_dep_files(target_name, config, project_dir)
    extra_params = _collect_extra_params(target_name, config)

    return is_target_dirty(
        target_name,
        resolved_prompt=prompt_text,
        model=model.name,
        dep_files=dependency_paths,
        extra_params=extra_params,
        state=state,
        project_dir=project_dir,
    )
|
|
|
|
|
|
def _has_provider(
    target_name: str,
    config: ProjectConfig,
    provider_index: dict[str, Provider],
    result: BuildResult,
    on_progress: ProgressCallback = _noop_callback,
) -> bool:
    """Tell whether a target can be built with the providers at hand.

    Download targets need no provider and always pass. For any other target
    the resolved model name must be a key of *provider_index*; if it is not,
    the failure is stored in ``result.failed``, reported through
    *on_progress* as ``TARGET_NO_PROVIDER``, and False is returned.
    """
    target_cfg = config.targets[target_name]
    if isinstance(target_cfg, DownloadTargetConfig):
        return True

    model_info = resolve_model(target_name, target_cfg, config.defaults)
    if model_info.name in provider_index:
        return True

    msg = f"No provider available for model '{model_info.name}' (provider: {model_info.provider}) — check API key environment variables"
    result.failed[target_name] = msg
    on_progress(BuildEvent.TARGET_NO_PROVIDER, target_name, msg)
    return False
|
|
|
|
|
|
async def _build_generation(
|
|
dirty_targets: list[str],
|
|
config: ProjectConfig,
|
|
project_dir: Path,
|
|
provider_index: dict[str, Provider],
|
|
) -> list[tuple[str, Exception | None]]:
|
|
"""Build all dirty targets in a generation concurrently."""
|
|
|
|
async def _build_one(name: str) -> tuple[str, Exception | None]:
|
|
try:
|
|
await _build_single_target(name, config, project_dir, provider_index)
|
|
except Exception as exc: # noqa: BLE001
|
|
return (name, exc)
|
|
return (name, None)
|
|
|
|
return list(await asyncio.gather(*[_build_one(t) for t in dirty_targets]))
|
|
|
|
|
|
def _process_outcomes(
    outcomes: list[tuple[str, Exception | None]],
    config: ProjectConfig,
    project_dir: Path,
    state: BuildState,
    result: BuildResult,
    on_progress: ProgressCallback = _noop_callback,
) -> None:
    """Fold the outcomes of one generation into *result* and *state*.

    A failed target gets its error text stored in ``result.failed`` and is
    reported as ``TARGET_FAILED``. A successful target has its state recorded
    through ``record_target_state``, is appended to ``result.built`` and is
    reported as ``TARGET_OK``.
    """
    for name, error in outcomes:
        if error is not None:
            detail = str(error)
            result.failed[name] = detail
            on_progress(BuildEvent.TARGET_FAILED, name, detail)
            continue

        cfg = config.targets[name]
        if isinstance(cfg, DownloadTargetConfig):
            # Only the URL is recorded for a download.
            record_target_state(
                name,
                download=cfg.download,
                state=state,
                project_dir=project_dir,
            )
        else:
            # These inputs are recomputed here, after the build has finished.
            model = resolve_model(name, cfg, config.defaults)
            prompt_text = resolve_prompt(cfg.prompt, project_dir)
            dependency_paths = _collect_dep_files(name, config, project_dir)
            extra_params = _collect_extra_params(name, config)
            record_target_state(
                name,
                resolved_prompt=prompt_text,
                model=model.name,
                dep_files=dependency_paths,
                extra_params=extra_params,
                state=state,
                project_dir=project_dir,
            )

        result.built.append(name)
        on_progress(BuildEvent.TARGET_OK, name, "")
|