"""Build orchestrator: incremental builds with parallel execution.""" from __future__ import annotations import asyncio import enum import os from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path from bulkgen.config import ProjectConfig from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target from bulkgen.providers import Provider from bulkgen.providers.blackforest import BlackForestProvider from bulkgen.providers.mistral import MistralProvider from bulkgen.providers.openai_image import OpenAIImageProvider from bulkgen.providers.openai_text import OpenAITextProvider from bulkgen.resolve import resolve_model from bulkgen.state import ( BuildState, is_target_dirty, load_state, record_target_state, save_state, ) class BuildEvent(enum.Enum): """Events emitted during a build for progress reporting.""" TARGET_SKIPPED = "skipped" TARGET_BUILDING = "building" TARGET_OK = "ok" TARGET_FAILED = "failed" TARGET_DEP_FAILED = "dep_failed" TARGET_NO_PROVIDER = "no_provider" ProgressCallback = Callable[[BuildEvent, str, str], None] """Signature: (event, target_name, detail_message).""" def _noop_callback(_event: BuildEvent, _name: str, _detail: str) -> None: """Default no-op progress callback.""" @dataclass class BuildResult: """Summary of a build run.""" built: list[str] = field(default_factory=list) skipped: list[str] = field(default_factory=list) failed: dict[str, str] = field(default_factory=dict) total_targets: int = 0 def _resolve_prompt(prompt_value: str, project_dir: Path) -> str: """Resolve a prompt: read from file if the path exists, otherwise use as-is.""" candidate = project_dir / prompt_value if candidate.is_file(): return candidate.read_text() return prompt_value def _collect_dep_files( target_name: str, config: ProjectConfig, project_dir: Path ) -> list[Path]: """Collect all dependency file paths for a target.""" target_cfg = config.targets[target_name] deps: list[str] = list(target_cfg.inputs) deps.extend(target_cfg.reference_images) deps.extend(target_cfg.control_images) return [project_dir / d for d in deps] def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, object]: """Collect extra parameters that affect rebuild decisions.""" target_cfg = config.targets[target_name] params: dict[str, object] = {} if target_cfg.width is not None: params["width"] = target_cfg.width if target_cfg.height is not None: params["height"] = target_cfg.height if target_cfg.reference_images: params["reference_images"] = tuple(target_cfg.reference_images) if target_cfg.control_images: params["control_images"] = tuple(target_cfg.control_images) return params def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]: """Collect all dependency names (inputs + reference_images + control_images).""" target_cfg = config.targets[target_name] deps: list[str] = list(target_cfg.inputs) deps.extend(target_cfg.reference_images) deps.extend(target_cfg.control_images) return deps def _create_providers() -> list[Provider]: """Create provider instances from environment variables.""" providers: list[Provider] = [] bfl_key = os.environ.get("BFL_API_KEY", "") if bfl_key: providers.append(BlackForestProvider(api_key=bfl_key)) mistral_key = os.environ.get("MISTRAL_API_KEY", "") if mistral_key: providers.append(MistralProvider(api_key=mistral_key)) openai_key = os.environ.get("OPENAI_API_KEY", "") if openai_key: providers.append(OpenAITextProvider(api_key=openai_key)) providers.append(OpenAIImageProvider(api_key=openai_key)) return providers def _build_provider_index(providers: list[Provider]) -> dict[str, Provider]: """Build a model-name → provider lookup from a list of providers.""" index: dict[str, Provider] = {} for provider in providers: for model in provider.get_provided_models(): index[model.name] = provider return index async def _build_single_target( target_name: str, config: ProjectConfig, project_dir: Path, provider_index: dict[str, Provider], ) -> None: """Build a single target by dispatching to the appropriate provider.""" target_cfg = config.targets[target_name] model_info = resolve_model(target_name, target_cfg, config.defaults) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) provider = provider_index[model_info.name] await provider.generate( target_name=target_name, target_config=target_cfg, resolved_prompt=resolved_prompt, resolved_model=model_info, project_dir=project_dir, ) async def run_build( config: ProjectConfig, project_dir: Path, target: str | None = None, on_progress: ProgressCallback = _noop_callback, ) -> BuildResult: """Execute the build. If *target* is specified, only build that target and its transitive dependencies. Otherwise build all targets. Execution proceeds in topological generations — each generation is a set of independent targets that run concurrently via :func:`asyncio.gather`. """ result = BuildResult() providers = _create_providers() provider_index = _build_provider_index(providers) graph = build_graph(config, project_dir) if target is not None: if target not in config.targets: msg = f"Unknown target: '{target}'" raise ValueError(msg) graph = get_subgraph_for_target(graph, target) state = load_state(project_dir) generations = get_build_order(graph) target_names = set(config.targets) # Count total buildable targets for progress reporting. result.total_targets = sum( 1 for gen in generations for n in gen if n in target_names ) for generation in generations: targets_in_gen = [n for n in generation if n in target_names] dirty_targets: list[str] = [] for name in targets_in_gen: if _should_skip_failed_dep(name, config, result): result.failed[name] = "Dependency failed" on_progress(BuildEvent.TARGET_DEP_FAILED, name, "Dependency failed") continue if _is_dirty(name, config, project_dir, state): if not _has_provider(name, config, provider_index, result, on_progress): continue dirty_targets.append(name) else: result.skipped.append(name) on_progress(BuildEvent.TARGET_SKIPPED, name, "up to date") if not dirty_targets: continue for name in dirty_targets: on_progress(BuildEvent.TARGET_BUILDING, name, "") outcomes = await _build_generation( dirty_targets, config, project_dir, provider_index ) _process_outcomes(outcomes, config, project_dir, state, result, on_progress) save_state(state, project_dir) return result def _should_skip_failed_dep( target_name: str, config: ProjectConfig, result: BuildResult ) -> bool: """Check if any dependency of a target has already failed.""" return any(d in result.failed for d in _collect_all_deps(target_name, config)) def _is_dirty( target_name: str, config: ProjectConfig, project_dir: Path, state: BuildState, ) -> bool: """Check if a target needs rebuilding.""" target_cfg = config.targets[target_name] model_info = resolve_model(target_name, target_cfg, config.defaults) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) dep_files = _collect_dep_files(target_name, config, project_dir) extra = _collect_extra_params(target_name, config) return is_target_dirty( target_name, resolved_prompt=resolved_prompt, model=model_info.name, dep_files=dep_files, extra_params=extra, state=state, project_dir=project_dir, ) def _has_provider( target_name: str, config: ProjectConfig, provider_index: dict[str, Provider], result: BuildResult, on_progress: ProgressCallback = _noop_callback, ) -> bool: """Check that the required provider is available; record failure if not.""" target_cfg = config.targets[target_name] model_info = resolve_model(target_name, target_cfg, config.defaults) if model_info.name not in provider_index: msg = f"No provider available for model '{model_info.name}' (provider: {model_info.provider}) — check API key environment variables" result.failed[target_name] = msg on_progress(BuildEvent.TARGET_NO_PROVIDER, target_name, msg) return False return True async def _build_generation( dirty_targets: list[str], config: ProjectConfig, project_dir: Path, provider_index: dict[str, Provider], ) -> list[tuple[str, Exception | None]]: """Build all dirty targets in a generation concurrently.""" async def _build_one(name: str) -> tuple[str, Exception | None]: try: await _build_single_target(name, config, project_dir, provider_index) except Exception as exc: # noqa: BLE001 return (name, exc) return (name, None) return list(await asyncio.gather(*[_build_one(t) for t in dirty_targets])) def _process_outcomes( outcomes: list[tuple[str, Exception | None]], config: ProjectConfig, project_dir: Path, state: BuildState, result: BuildResult, on_progress: ProgressCallback = _noop_callback, ) -> None: """Process build outcomes: record state for successes, log failures.""" for name, error in outcomes: if error is not None: result.failed[name] = str(error) on_progress(BuildEvent.TARGET_FAILED, name, str(error)) else: target_cfg = config.targets[name] model_info = resolve_model(name, target_cfg, config.defaults) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) dep_files = _collect_dep_files(name, config, project_dir) extra = _collect_extra_params(name, config) record_target_state( name, resolved_prompt=resolved_prompt, model=model_info.name, dep_files=dep_files, extra_params=extra, state=state, project_dir=project_dir, ) result.built.append(name) on_progress(BuildEvent.TARGET_OK, name, "")