"""Pydantic models for bulkgen YAML configuration.""" from __future__ import annotations import enum from pathlib import Path from typing import TYPE_CHECKING, Self import yaml from pydantic import BaseModel, model_validator if TYPE_CHECKING: from bulkgen.providers.models import Capability, ModelInfo class TargetType(enum.Enum): """The kind of artifact a target produces.""" IMAGE = "image" TEXT = "text" IMAGE_EXTENSIONS: frozenset[str] = frozenset({".png", ".jpg", ".jpeg", ".webp"}) TEXT_EXTENSIONS: frozenset[str] = frozenset({".md", ".txt"}) class Defaults(BaseModel): """Default model names, applied when a target does not specify its own.""" text_model: str = "pixtral-large-latest" image_model: str = "flux-2-pro" class TargetConfig(BaseModel): """Configuration for a single build target.""" prompt: str model: str | None = None inputs: list[str] = [] reference_images: list[str] = [] control_images: list[str] = [] width: int | None = None height: int | None = None class ProjectConfig(BaseModel): """Top-level configuration parsed from ``.bulkgen.yaml``.""" defaults: Defaults = Defaults() targets: dict[str, TargetConfig] @model_validator(mode="after") def _validate_non_empty_targets(self) -> Self: if not self.targets: msg = "At least one target must be defined" raise ValueError(msg) return self def infer_required_capabilities( target_name: str, target: TargetConfig ) -> frozenset[Capability]: """Infer the capabilities a model must have based on filename and config. Raises :class:`ValueError` for unsupported file extensions. """ from bulkgen.providers.models import Capability suffix = Path(target_name).suffix.lower() caps: set[Capability] = set() if suffix in IMAGE_EXTENSIONS: caps.add(Capability.TEXT_TO_IMAGE) if target.reference_images: caps.add(Capability.REFERENCE_IMAGES) if target.control_images: caps.add(Capability.CONTROL_IMAGES) elif suffix in TEXT_EXTENSIONS: caps.add(Capability.TEXT_GENERATION) all_input_names = list(target.inputs) + list(target.reference_images) if any(Path(n).suffix.lower() in IMAGE_EXTENSIONS for n in all_input_names): caps.add(Capability.VISION) else: msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'" raise ValueError(msg) return frozenset(caps) def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType: """Derive the target type from a set of required capabilities.""" from bulkgen.providers.models import Capability if Capability.TEXT_TO_IMAGE in capabilities: return TargetType.IMAGE return TargetType.TEXT def resolve_model( target_name: str, target: TargetConfig, defaults: Defaults ) -> ModelInfo: """Return the effective model for a target, validated against required capabilities. If the target specifies an explicit model, it is validated to have all required capabilities. Otherwise the type-appropriate default is tried first; if it lacks a required capability the first capable model of the same type is selected. Raises :class:`ValueError` if no suitable model can be found. """ from bulkgen.providers.models import ALL_MODELS required = infer_required_capabilities(target_name, target) target_type = target_type_from_capabilities(required) if target.model is not None: # Explicit model — look up and validate. for model in ALL_MODELS: if model.name == target.model: missing = required - frozenset(model.capabilities) if missing: names = ", ".join(sorted(missing)) msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}" raise ValueError(msg) return model msg = f"Unknown model '{target.model}' for target '{target_name}'" raise ValueError(msg) # No explicit model — try the default first, then fall back. default_name = ( defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model ) for model in ALL_MODELS: if model.name == default_name: if required <= frozenset(model.capabilities): return model break # Default lacks capabilities — find the first capable model of the same type. model_type = "image" if target_type is TargetType.IMAGE else "text" for model in ALL_MODELS: if model.type == model_type and required <= frozenset(model.capabilities): return model names = ", ".join(sorted(required)) msg = ( f"No model found for target '{target_name}' with required capabilities: {names}" ) raise ValueError(msg) def load_config(config_path: Path) -> ProjectConfig: """Load and validate a ``.bulkgen.yaml`` file.""" with config_path.open() as f: raw = yaml.safe_load(f) # pyright: ignore[reportAny] return ProjectConfig.model_validate(raw)