# Reconstructed from a "new file" diff hunk for bulkgen/builder.py ("+" markers removed).
"""Build orchestrator: incremental builds with parallel execution."""

from __future__ import annotations

import asyncio
import os
from dataclasses import dataclass, field
from pathlib import Path

import typer

from bulkgen.config import (
    ProjectConfig,
    TargetType,
    infer_target_type,
    resolve_model,
)
from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target
from bulkgen.providers import Provider
from bulkgen.providers.image import ImageProvider
from bulkgen.providers.text import TextProvider
from bulkgen.state import (
    BuildState,
    is_target_dirty,
    load_state,
    record_target_state,
    save_state,
)


@dataclass
class BuildResult:
    """Summary of a build run."""

    # Names of targets whose provider call completed without raising.
    built: list[str] = field(default_factory=list)
    # Names of targets that were not dirty and therefore not rebuilt.
    skipped: list[str] = field(default_factory=list)
    # Target name -> human-readable reason the target was not built
    # (provider error text, missing API key, or a failed dependency).
    failed: dict[str, str] = field(default_factory=dict)


def _resolve_prompt(prompt_value: str, project_dir: Path) -> str:
    """Resolve a prompt: read from file if the path exists, otherwise use as-is.

    Args:
        prompt_value: Either literal prompt text or a path, relative to
            *project_dir*, of a file holding the prompt.
        project_dir: Directory against which *prompt_value* is resolved.

    Returns:
        The file contents when ``project_dir / prompt_value`` is a regular
        file, otherwise *prompt_value* unchanged.
    """
    candidate = project_dir / prompt_value
    try:
        is_prompt_file = candidate.is_file()
    except (OSError, ValueError):
        # Inline prompts are arbitrary text: they may exceed the OS file-name
        # limit (ENAMETOOLONG) or contain characters the OS rejects. Such a
        # value cannot name an existing file, so treat it as literal text
        # instead of aborting the whole build.
        return prompt_value
    if is_prompt_file:
        return candidate.read_text()
    return prompt_value


def _collect_dep_files(
    target_name: str, config: ProjectConfig, project_dir: Path
) -> list[Path]:
    """Collect all dependency file paths for a target.

    Returns the same entries as :func:`_collect_all_deps`, in the same order,
    each joined onto *project_dir*.
    """
    # Delegate so the name list and the path list can never drift apart.
    return [project_dir / dep for dep in _collect_all_deps(target_name, config)]


def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str, object]:
    """Collect extra parameters that affect rebuild decisions.

    Only settings that are actually set on the target are included, so an
    unset option and an absent key are indistinguishable to the caller.
    """
    target_cfg = config.targets[target_name]
    params: dict[str, object] = {}
    if target_cfg.width is not None:
        params["width"] = target_cfg.width
    if target_cfg.height is not None:
        params["height"] = target_cfg.height
    if target_cfg.reference_image is not None:
        params["reference_image"] = target_cfg.reference_image
    if target_cfg.control_images:
        # Stored as a tuple (immutable snapshot of the configured list).
        params["control_images"] = tuple(target_cfg.control_images)
    return params


def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]:
    """Collect all dependency names (inputs + reference_image + control_images).

    The order is: every entry of ``inputs``, then ``reference_image`` when it
    is set, then every entry of ``control_images``.
    """
    target_cfg = config.targets[target_name]
    deps: list[str] = list(target_cfg.inputs)
    if target_cfg.reference_image is not None:
        deps.append(target_cfg.reference_image)
    deps.extend(target_cfg.control_images)
    return deps


def _create_providers() -> dict[TargetType, Provider]:
    """Create provider instances from environment variables.

    A provider is registered only when its API key variable is set to a
    non-empty value: ``BFL_API_KEY`` for image targets and
    ``MISTRAL_API_KEY`` for text targets.
    """
    providers: dict[TargetType, Provider] = {}
    bfl_key = os.environ.get("BFL_API_KEY", "")
    if bfl_key:
        providers[TargetType.IMAGE] = ImageProvider(api_key=bfl_key)
    mistral_key = os.environ.get("MISTRAL_API_KEY", "")
    if mistral_key:
        providers[TargetType.TEXT] = TextProvider(api_key=mistral_key)
    return providers


async def _build_single_target(
    target_name: str,
    config: ProjectConfig,
    project_dir: Path,
    providers: dict[TargetType, Provider],
) -> None:
    """Build a single target by dispatching to the appropriate provider.

    Raises:
        KeyError: If no provider is registered for the target's type.
            ``run_build`` filters such targets out via ``_has_provider``
            before calling this function.
        Exception: Whatever the provider's ``generate`` call raises.
    """
    target_cfg = config.targets[target_name]
    target_type = infer_target_type(target_name)
    model = resolve_model(target_name, target_cfg, config.defaults)
    resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)

    provider = providers[target_type]
    await provider.generate(
        target_name=target_name,
        target_config=target_cfg,
        resolved_prompt=resolved_prompt,
        resolved_model=model,
        project_dir=project_dir,
    )


async def run_build(
    config: ProjectConfig,
    project_dir: Path,
    target: str | None = None,
) -> BuildResult:
    """Execute the build.

    If *target* is specified, only build that target and its transitive
    dependencies. Otherwise build all targets.

    Execution proceeds in topological generations — each generation is a
    set of independent targets that run concurrently via
    :func:`asyncio.gather`.

    Args:
        config: Project configuration holding the target definitions.
        project_dir: Project root used to resolve prompts, dependency files
            and the persisted build state.
        target: Optional name of a single target to build.

    Returns:
        A :class:`BuildResult` listing built, skipped and failed targets.

    Raises:
        ValueError: If *target* is given but is not a configured target.
    """
    result = BuildResult()
    providers = _create_providers()

    graph = build_graph(config, project_dir)

    if target is not None:
        if target not in config.targets:
            msg = f"Unknown target: '{target}'"
            raise ValueError(msg)
        graph = get_subgraph_for_target(graph, target)

    state = load_state(project_dir)
    generations = get_build_order(graph)
    target_names = set(config.targets)

    for generation in generations:
        # Nodes that are not configured targets are never built here
        # (presumably plain input files in the graph — TODO confirm).
        targets_in_gen = [n for n in generation if n in target_names]

        dirty_targets: list[str] = []
        for name in targets_in_gen:
            # A failed dependency poisons its dependents: mark them failed
            # without attempting a build, so the failure keeps propagating
            # to later generations.
            if _should_skip_failed_dep(name, config, result):
                result.failed[name] = "Dependency failed"
                continue

            if _is_dirty(name, config, project_dir, state):
                # _has_provider records the failure itself when the API key
                # is missing, so there is nothing more to do for this target.
                if not _has_provider(name, providers, result):
                    continue
                dirty_targets.append(name)
            else:
                result.skipped.append(name)

        if not dirty_targets:
            continue

        outcomes = await _build_generation(
            dirty_targets, config, project_dir, providers
        )

        _process_outcomes(outcomes, config, project_dir, state, result)
        # Persist after every generation so completed work is kept even if a
        # later generation aborts the run.
        save_state(state, project_dir)

    return result


def _should_skip_failed_dep(
    target_name: str, config: ProjectConfig, result: BuildResult
) -> bool:
    """Check if any dependency of a target has already failed."""
    return any(d in result.failed for d in _collect_all_deps(target_name, config))


def _is_dirty(
    target_name: str,
    config: ProjectConfig,
    project_dir: Path,
    state: BuildState,
) -> bool:
    """Check if a target needs rebuilding.

    Gathers the resolved prompt, resolved model, dependency file paths and
    extra parameters, and hands them to ``is_target_dirty`` together with
    the loaded build state; the decision itself is made there.
    """
    target_cfg = config.targets[target_name]
    model = resolve_model(target_name, target_cfg, config.defaults)
    resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
    dep_files = _collect_dep_files(target_name, config, project_dir)
    extra = _collect_extra_params(target_name, config)

    return is_target_dirty(
        target_name,
        resolved_prompt=resolved_prompt,
        model=model,
        dep_files=dep_files,
        extra_params=extra,
        state=state,
        project_dir=project_dir,
    )


def _has_provider(
    target_name: str,
    providers: dict[TargetType, Provider],
    result: BuildResult,
) -> bool:
    """Check that the required provider is available; record failure if not.

    Returns:
        ``True`` when a provider exists for the target's type. Otherwise
        ``False``, after storing a message naming the missing environment
        variable in ``result.failed``.
    """
    target_type = infer_target_type(target_name)
    if target_type not in providers:
        env_var = (
            "BFL_API_KEY" if target_type is TargetType.IMAGE else "MISTRAL_API_KEY"
        )
        result.failed[target_name] = f"Missing {env_var} environment variable"
        return False
    return True


async def _build_generation(
    dirty_targets: list[str],
    config: ProjectConfig,
    project_dir: Path,
    providers: dict[TargetType, Provider],
) -> list[tuple[str, Exception | None]]:
    """Build all dirty targets in a generation concurrently.

    Returns:
        One ``(target_name, error)`` pair per entry of *dirty_targets*, in
        the same order; ``error`` is ``None`` when the build succeeded.
    """

    async def _build_one(name: str) -> tuple[str, Exception | None]:
        # Capture the error instead of letting it propagate, so one failing
        # target does not discard the results of its siblings.
        try:
            await _build_single_target(name, config, project_dir, providers)
        except Exception as exc:  # noqa: BLE001
            return (name, exc)
        return (name, None)

    return list(await asyncio.gather(*[_build_one(t) for t in dirty_targets]))


def _process_outcomes(
    outcomes: list[tuple[str, Exception | None]],
    config: ProjectConfig,
    project_dir: Path,
    state: BuildState,
    result: BuildResult,
) -> None:
    """Process build outcomes: record state for successes, log failures.

    Failures are stored in ``result.failed`` and echoed to stderr.
    Successes are passed to ``record_target_state``, appended to
    ``result.built`` and echoed to stdout.
    """
    for name, error in outcomes:
        if error is not None:
            # Some exceptions (e.g. a bare TimeoutError) stringify to "";
            # fall back to the class name so the reason is never blank.
            message = str(error) or type(error).__name__
            result.failed[name] = message
            typer.echo(f"FAIL: {name} -- {message}", err=True)
        else:
            target_cfg = config.targets[name]
            model = resolve_model(name, target_cfg, config.defaults)
            resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
            dep_files = _collect_dep_files(name, config, project_dir)
            extra = _collect_extra_params(name, config)

            record_target_state(
                name,
                resolved_prompt=resolved_prompt,
                model=model,
                dep_files=dep_files,
                extra_params=extra,
                state=state,
                project_dir=project_dir,
            )
            result.built.append(name)
            typer.echo(f"  OK: {name}")