"""Pydantic models for bulkgen YAML configuration.""" from __future__ import annotations import enum from pathlib import Path from typing import TYPE_CHECKING, Self import yaml from pydantic import BaseModel, model_validator if TYPE_CHECKING: from bulkgen.providers.models import ModelInfo class TargetType(enum.Enum): """The kind of artifact a target produces.""" IMAGE = "image" TEXT = "text" IMAGE_EXTENSIONS: frozenset[str] = frozenset({".png", ".jpg", ".jpeg", ".webp"}) TEXT_EXTENSIONS: frozenset[str] = frozenset({".md", ".txt"}) class Defaults(BaseModel): """Default model names, applied when a target does not specify its own.""" text_model: str = "pixtral-large-latest" image_model: str = "flux-2-pro" class TargetConfig(BaseModel): """Configuration for a single build target.""" prompt: str model: str | None = None inputs: list[str] = [] reference_images: list[str] = [] control_images: list[str] = [] width: int | None = None height: int | None = None class ProjectConfig(BaseModel): """Top-level configuration parsed from ``.bulkgen.yaml``.""" defaults: Defaults = Defaults() targets: dict[str, TargetConfig] @model_validator(mode="after") def _validate_non_empty_targets(self) -> Self: if not self.targets: msg = "At least one target must be defined" raise ValueError(msg) return self def infer_target_type(target_name: str) -> TargetType: """Infer whether a target produces an image or text from its file extension.""" suffix = Path(target_name).suffix.lower() if suffix in IMAGE_EXTENSIONS: return TargetType.IMAGE if suffix in TEXT_EXTENSIONS: return TargetType.TEXT msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'" raise ValueError(msg) def resolve_model( target_name: str, target: TargetConfig, defaults: Defaults ) -> ModelInfo: """Return the effective model for a target (explicit or default by type). Raises :class:`ValueError` if the resolved model name is not in the registry. """ from bulkgen.providers.models import ALL_MODELS if target.model is not None: model_name = target.model else: target_type = infer_target_type(target_name) model_name = ( defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model ) for model in ALL_MODELS: if model.name == model_name: return model msg = f"Unknown model '{model_name}' for target '{target_name}'" raise ValueError(msg) def load_config(config_path: Path) -> ProjectConfig: """Load and validate a ``.bulkgen.yaml`` file.""" with config_path.open() as f: raw = yaml.safe_load(f) # pyright: ignore[reportAny] return ProjectConfig.model_validate(raw)