hokusai/bulkgen/builder.py
Konstantin Fickel d0dac5b1bf
Some checks failed
Continuous Integration / Build Package (push) Successful in 30s
Continuous Integration / Lint, Check & Test (push) Failing after 38s
refactor: move model definitions into providers and extract resolve module
- Rename ImageProvider to BlackForestProvider, TextProvider to MistralProvider
- Add get_provided_models() abstract method to Provider base class
- Move model lists from models.py into each provider's get_provided_models()
- Add providers/registry.py to aggregate models from all providers
- Extract infer_required_capabilities and resolve_model from config.py to resolve.py
- Update tests to use new names and import paths
2026-02-15 11:03:57 +01:00

308 lines
10 KiB
Python

"""Build orchestrator: incremental builds with parallel execution."""
from __future__ import annotations
import asyncio
import enum
import os
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from bulkgen.config import (
ProjectConfig,
TargetType,
target_type_from_capabilities,
)
from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target
from bulkgen.providers import Provider
from bulkgen.providers.blackforest import BlackForestProvider
from bulkgen.providers.mistral import MistralProvider
from bulkgen.resolve import infer_required_capabilities, resolve_model
from bulkgen.state import (
BuildState,
is_target_dirty,
load_state,
record_target_state,
save_state,
)
class BuildEvent(enum.Enum):
    """Kinds of progress notification a build hands to its progress callback."""

    # The target was found up to date and was not rebuilt.
    TARGET_SKIPPED = "skipped"
    # The target is about to be built.
    TARGET_BUILDING = "building"
    # The target was built and its state recorded.
    TARGET_OK = "ok"
    # Building the target raised an exception.
    TARGET_FAILED = "failed"
    # The target was not attempted because one of its dependencies failed.
    TARGET_DEP_FAILED = "dep_failed"
    # The target was not attempted because no provider is configured for it.
    TARGET_NO_PROVIDER = "no_provider"
# Type of the hook that receives build progress notifications. It is called
# synchronously with the event kind, the name of the target the event is
# about, and a human-readable detail string (which may be empty).
ProgressCallback = Callable[[BuildEvent, str, str], None]
"""Signature: (event, target_name, detail_message)."""
def _noop_callback(_event: BuildEvent, _name: str, _detail: str) -> None:
"""Default no-op progress callback."""
@dataclass
class BuildResult:
    """Outcome of one build run, grouped by what happened to each target."""

    # Names of targets that were built successfully.
    built: list[str] = field(default_factory=list)
    # Names of targets that were already up to date.
    skipped: list[str] = field(default_factory=list)
    # Maps the name of each target that could not be built to the reason.
    failed: dict[str, str] = field(default_factory=dict)
    # Number of targets the run considered, for progress reporting.
    total_targets: int = 0
def _resolve_prompt(prompt_value: str, project_dir: Path) -> str:
"""Resolve a prompt: read from file if the path exists, otherwise use as-is."""
candidate = project_dir / prompt_value
if candidate.is_file():
return candidate.read_text()
return prompt_value
def _collect_dep_files(
    target_name: str, config: ProjectConfig, project_dir: Path
) -> list[Path]:
    """Collect all dependency file paths for a target.

    The dependencies are the target's inputs, reference images and control
    images, in that order, each joined onto *project_dir*. The paths are not
    checked for existence.
    """
    # Delegate to the name collector so the definition of "a dependency"
    # lives in one place and the two views cannot drift apart.
    return [project_dir / dep for dep in _collect_all_deps(target_name, config)]
def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, object]:
"""Collect extra parameters that affect rebuild decisions."""
target_cfg = config.targets[target_name]
params: dict[str, object] = {}
if target_cfg.width is not None:
params["width"] = target_cfg.width
if target_cfg.height is not None:
params["height"] = target_cfg.height
if target_cfg.reference_images:
params["reference_images"] = tuple(target_cfg.reference_images)
if target_cfg.control_images:
params["control_images"] = tuple(target_cfg.control_images)
return params
def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]:
"""Collect all dependency names (inputs + reference_images + control_images)."""
target_cfg = config.targets[target_name]
deps: list[str] = list(target_cfg.inputs)
deps.extend(target_cfg.reference_images)
deps.extend(target_cfg.control_images)
return deps
def _create_providers() -> dict[TargetType, Provider]:
"""Create provider instances from environment variables."""
providers: dict[TargetType, Provider] = {}
bfl_key = os.environ.get("BFL_API_KEY", "")
if bfl_key:
providers[TargetType.IMAGE] = BlackForestProvider(api_key=bfl_key)
mistral_key = os.environ.get("MISTRAL_API_KEY", "")
if mistral_key:
providers[TargetType.TEXT] = MistralProvider(api_key=mistral_key)
return providers
async def _build_single_target(
    target_name: str,
    config: ProjectConfig,
    project_dir: Path,
    providers: dict[TargetType, Provider],
) -> None:
    """Generate one target with the provider that matches its target type.

    The model, the required capabilities (and from them the target type) and
    the prompt are resolved first; the provider registered for that target
    type is then asked to generate the target. A missing provider surfaces
    as ``KeyError``.
    """
    cfg = config.targets[target_name]

    model = resolve_model(target_name, cfg, config.defaults)
    capabilities = infer_required_capabilities(target_name, cfg)
    kind = target_type_from_capabilities(capabilities)
    prompt = _resolve_prompt(cfg.prompt, project_dir)

    await providers[kind].generate(
        target_name=target_name,
        target_config=cfg,
        resolved_prompt=prompt,
        resolved_model=model,
        project_dir=project_dir,
    )
async def run_build(
    config: ProjectConfig,
    project_dir: Path,
    target: str | None = None,
    on_progress: ProgressCallback = _noop_callback,
) -> BuildResult:
    """Execute the build.

    If *target* is specified, only build that target and its transitive
    dependencies. Otherwise build all targets.

    Execution proceeds in topological generations — each generation is a
    set of independent targets that run concurrently via
    :func:`asyncio.gather`.

    Within a generation, a target is marked failed without being attempted
    when one of its dependencies failed or when no provider is available
    for it, and is skipped when it is not dirty. *on_progress* is called for
    every such decision and for every build outcome.

    The build state is saved when the run ends, including when the run is
    cut short by an exception or cancellation, so targets that were already
    built are not rebuilt needlessly next time.

    Raises:
        ValueError: If *target* is given but is not a configured target.
    """
    result = BuildResult()
    providers = _create_providers()

    graph = build_graph(config, project_dir)
    if target is not None:
        if target not in config.targets:
            msg = f"Unknown target: '{target}'"
            raise ValueError(msg)
        graph = get_subgraph_for_target(graph, target)

    state = load_state(project_dir)
    generations = get_build_order(graph)
    target_names = set(config.targets)

    # Count total buildable targets for progress reporting. Graph nodes that
    # are not configured targets are not counted.
    result.total_targets = sum(
        1 for gen in generations for n in gen if n in target_names
    )

    try:
        for generation in generations:
            targets_in_gen = [n for n in generation if n in target_names]
            dirty_targets: list[str] = []

            for name in targets_in_gen:
                if _should_skip_failed_dep(name, config, result):
                    # Recording the target as failed propagates the failure
                    # to its own dependents in later generations.
                    result.failed[name] = "Dependency failed"
                    on_progress(
                        BuildEvent.TARGET_DEP_FAILED, name, "Dependency failed"
                    )
                    continue

                if _is_dirty(name, config, project_dir, state):
                    # _has_provider records the failure and notifies itself.
                    if not _has_provider(
                        name, config, providers, result, on_progress
                    ):
                        continue
                    dirty_targets.append(name)
                else:
                    result.skipped.append(name)
                    on_progress(BuildEvent.TARGET_SKIPPED, name, "up to date")

            if not dirty_targets:
                continue

            for name in dirty_targets:
                on_progress(BuildEvent.TARGET_BUILDING, name, "")

            outcomes = await _build_generation(
                dirty_targets, config, project_dir, providers
            )
            _process_outcomes(
                outcomes, config, project_dir, state, result, on_progress
            )
    finally:
        # Persist whatever was recorded so far. Without this, an
        # interruption or unexpected error part-way through would throw
        # away the state of every target already built in this run.
        save_state(state, project_dir)

    return result
def _should_skip_failed_dep(
    target_name: str, config: ProjectConfig, result: BuildResult
) -> bool:
    """Tell whether the target depends on something already recorded as failed."""
    failed = result.failed
    for dep in _collect_all_deps(target_name, config):
        if dep in failed:
            return True
    return False
def _is_dirty(
    target_name: str,
    config: ProjectConfig,
    project_dir: Path,
    state: BuildState,
) -> bool:
    """Decide whether a target has to be rebuilt.

    The target's current prompt, model name, dependency files and extra
    parameters are gathered and handed to :func:`is_target_dirty` together
    with the stored build state.
    """
    cfg = config.targets[target_name]
    model = resolve_model(target_name, cfg, config.defaults)
    prompt = _resolve_prompt(cfg.prompt, project_dir)

    return is_target_dirty(
        target_name,
        resolved_prompt=prompt,
        model=model.name,
        dep_files=_collect_dep_files(target_name, config, project_dir),
        extra_params=_collect_extra_params(target_name, config),
        state=state,
        project_dir=project_dir,
    )
def _has_provider(
    target_name: str,
    config: ProjectConfig,
    providers: dict[TargetType, Provider],
    result: BuildResult,
    on_progress: ProgressCallback = _noop_callback,
) -> bool:
    """Report whether a provider exists for the target's type.

    When none exists, the target is recorded in ``result.failed`` with a
    message naming the missing API-key environment variable, *on_progress*
    is notified with ``TARGET_NO_PROVIDER``, and ``False`` is returned.
    """
    cfg = config.targets[target_name]
    kind = target_type_from_capabilities(
        infer_required_capabilities(target_name, cfg)
    )
    if kind in providers:
        return True

    # Every target type other than IMAGE is reported against the Mistral key.
    if kind is TargetType.IMAGE:
        env_var = "BFL_API_KEY"
    else:
        env_var = "MISTRAL_API_KEY"

    reason = f"Missing {env_var} environment variable"
    result.failed[target_name] = reason
    on_progress(BuildEvent.TARGET_NO_PROVIDER, target_name, reason)
    return False
async def _build_generation(
dirty_targets: list[str],
config: ProjectConfig,
project_dir: Path,
providers: dict[TargetType, Provider],
) -> list[tuple[str, Exception | None]]:
"""Build all dirty targets in a generation concurrently."""
async def _build_one(name: str) -> tuple[str, Exception | None]:
try:
await _build_single_target(name, config, project_dir, providers)
except Exception as exc: # noqa: BLE001
return (name, exc)
return (name, None)
return list(await asyncio.gather(*[_build_one(t) for t in dirty_targets]))
def _process_outcomes(
    outcomes: list[tuple[str, Exception | None]],
    config: ProjectConfig,
    project_dir: Path,
    state: BuildState,
    result: BuildResult,
    on_progress: ProgressCallback = _noop_callback,
) -> None:
    """Process build outcomes: record state for successes, report failures.

    For each ``(name, error)`` pair:

    * with an error, the target is added to ``result.failed`` with a
      description of the error and *on_progress* receives ``TARGET_FAILED``;
    * without one, the target's current prompt, model name, dependency files
      and extra parameters are recorded in *state*, the target is appended
      to ``result.built`` and *on_progress* receives ``TARGET_OK``.
    """
    for name, error in outcomes:
        if error is not None:
            # Exceptions raised without a message (e.g. a bare
            # ``TimeoutError()``) stringify to "". Fall back to the class
            # name so the failure reason is never blank.
            reason = str(error) or type(error).__name__
            result.failed[name] = reason
            on_progress(BuildEvent.TARGET_FAILED, name, reason)
            continue

        target_cfg = config.targets[name]
        model_info = resolve_model(name, target_cfg, config.defaults)
        resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
        dep_files = _collect_dep_files(name, config, project_dir)
        extra = _collect_extra_params(name, config)

        record_target_state(
            name,
            resolved_prompt=resolved_prompt,
            model=model_info.name,
            dep_files=dep_files,
            extra_params=extra,
            state=state,
            project_dir=project_dir,
        )
        result.built.append(name)
        on_progress(BuildEvent.TARGET_OK, name, "")