diff --git a/bulkgen/config.py b/bulkgen/config.py new file mode 100644 index 0000000..cf735dc --- /dev/null +++ b/bulkgen/config.py @@ -0,0 +1,82 @@ +"""Pydantic models for bulkgen YAML configuration.""" + +from __future__ import annotations + +import enum +from pathlib import Path +from typing import Self + +import yaml +from pydantic import BaseModel, model_validator + + +class TargetType(enum.Enum): + """The kind of artifact a target produces.""" + + IMAGE = "image" + TEXT = "text" + + +IMAGE_EXTENSIONS: frozenset[str] = frozenset({".png", ".jpg", ".jpeg", ".webp"}) +TEXT_EXTENSIONS: frozenset[str] = frozenset({".md", ".txt"}) + + +class Defaults(BaseModel): + """Default model names, applied when a target does not specify its own.""" + + text_model: str = "mistral-large-latest" + image_model: str = "flux-pro" + + +class TargetConfig(BaseModel): + """Configuration for a single build target.""" + + prompt: str + model: str | None = None + inputs: list[str] = [] + reference_image: str | None = None + control_images: list[str] = [] + width: int | None = None + height: int | None = None + + +class ProjectConfig(BaseModel): + """Top-level configuration parsed from ``.bulkgen.yaml``.""" + + defaults: Defaults = Defaults() + targets: dict[str, TargetConfig] + + @model_validator(mode="after") + def _validate_non_empty_targets(self) -> Self: + if not self.targets: + msg = "At least one target must be defined" + raise ValueError(msg) + return self + + +def infer_target_type(target_name: str) -> TargetType: + """Infer whether a target produces an image or text from its file extension.""" + suffix = Path(target_name).suffix.lower() + if suffix in IMAGE_EXTENSIONS: + return TargetType.IMAGE + if suffix in TEXT_EXTENSIONS: + return TargetType.TEXT + msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'" + raise ValueError(msg) + + +def resolve_model(target_name: str, target: TargetConfig, defaults: Defaults) -> str: + """Return the effective model for a target (explicit or default by type).""" + if target.model is not None: + return target.model + target_type = infer_target_type(target_name) + if target_type is TargetType.IMAGE: + return defaults.image_model + return defaults.text_model + + +def load_config(config_path: Path) -> ProjectConfig: + """Load and validate a ``.bulkgen.yaml`` file.""" + with config_path.open() as f: + raw = yaml.safe_load(f) + return ProjectConfig.model_validate(raw)