156 lines
5.1 KiB
Python
156 lines
5.1 KiB
Python
"""Pydantic models for bulkgen YAML configuration."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import enum
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Self
|
|
|
|
import yaml
|
|
from pydantic import BaseModel, model_validator
|
|
|
|
if TYPE_CHECKING:
|
|
from bulkgen.providers.models import Capability, ModelInfo
|
|
|
|
|
|
class TargetType(enum.Enum):
    """Category of artifact that a build target generates.

    Each member's value is the lower-case name of the category.
    """

    # Target whose output is an image.
    IMAGE = "image"
    # Target whose output is text.
    TEXT = "text"
|
|
|
|
|
|
# Lower-case file suffixes (including the leading dot) treated as images.
IMAGE_EXTENSIONS: frozenset[str] = frozenset((".png", ".jpg", ".jpeg", ".webp"))
# Lower-case file suffixes (including the leading dot) treated as text.
TEXT_EXTENSIONS: frozenset[str] = frozenset((".md", ".txt"))
|
|
|
|
|
|
class Defaults(BaseModel):
    """Default model names, applied when a target does not specify its own.

    Both fields are optional in the YAML; the values below are used when
    they are omitted.
    """

    # Model name used for text targets that do not set ``model``.
    text_model: str = "pixtral-large-latest"
    # Model name used for image targets that do not set ``model``.
    image_model: str = "flux-2-pro"
|
|
|
|
|
|
class TargetConfig(BaseModel):
    """Configuration for a single build target.

    Only ``prompt`` is required; every other field has a default.
    """

    # Prompt for this target (required).
    prompt: str
    # Explicit model name; ``None`` means the project defaults decide.
    model: str | None = None
    # Names of input files; for text targets, an image-suffixed entry makes
    # the vision capability required.
    inputs: list[str] = []
    # Reference image names; a non-empty list on an image target requires the
    # reference-images capability, and on a text target image-suffixed entries
    # count towards the vision requirement.
    reference_images: list[str] = []
    # Control image names; a non-empty list on an image target requires the
    # control-images capability.
    control_images: list[str] = []
    # Optional integer; presumably the output width in pixels — TODO confirm
    # against the code that consumes it.
    width: int | None = None
    # Optional integer; presumably the output height in pixels — TODO confirm
    # against the code that consumes it.
    height: int | None = None
|
|
|
|
|
|
class ProjectConfig(BaseModel):
    """Root of the configuration read from ``<name>.bulkgen.yaml``.

    Validation fails when the ``targets`` mapping is empty.
    """

    # Project-wide fallback model names; optional in the YAML.
    defaults: Defaults = Defaults()
    # Mapping of target name to that target's configuration (required).
    targets: dict[str, TargetConfig]

    @model_validator(mode="after")
    def _validate_non_empty_targets(self) -> Self:
        """Reject configurations that declare no targets at all."""
        if self.targets:
            return self
        msg = "At least one target must be defined"
        raise ValueError(msg)
|
|
|
|
|
|
def infer_required_capabilities(
    target_name: str, target: TargetConfig
) -> frozenset[Capability]:
    """Work out which capabilities a model needs to build *target*.

    The decision is driven by the (case-insensitive) file suffix of
    ``target_name``:

    * Image suffix: text-to-image is always required; reference-image and
      control-image support are added when the corresponding lists on
      ``target`` are non-empty.
    * Text suffix: text generation is always required; vision is added when
      any entry of ``inputs`` or ``reference_images`` has an image suffix.

    Raises :class:`ValueError` when the suffix is neither a known image nor a
    known text extension.
    """
    # Imported lazily; at module level the name is only available to type checkers.
    from bulkgen.providers.models import Capability

    suffix = Path(target_name).suffix.lower()

    if suffix in IMAGE_EXTENSIONS:
        image_caps: set[Capability] = {Capability.TEXT_TO_IMAGE}
        if target.reference_images:
            image_caps.add(Capability.REFERENCE_IMAGES)
        if target.control_images:
            image_caps.add(Capability.CONTROL_IMAGES)
        return frozenset(image_caps)

    if suffix in TEXT_EXTENSIONS:
        text_caps: set[Capability] = {Capability.TEXT_GENERATION}
        # Control images are deliberately not part of this check, matching
        # the lists the text branch has always inspected.
        candidate_names = [*target.inputs, *target.reference_images]
        has_image_input = any(
            Path(name).suffix.lower() in IMAGE_EXTENSIONS for name in candidate_names
        )
        if has_image_input:
            text_caps.add(Capability.VISION)
        return frozenset(text_caps)

    msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
    raise ValueError(msg)
|
|
|
|
|
|
def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType:
    """Map a set of required capabilities to the target type it implies.

    A set containing the text-to-image capability means an image target;
    every other set (including an empty one) is treated as a text target.
    """
    # Imported lazily; at module level the name is only available to type checkers.
    from bulkgen.providers.models import Capability

    needs_image_generation = Capability.TEXT_TO_IMAGE in capabilities
    return TargetType.IMAGE if needs_image_generation else TargetType.TEXT
|
|
|
|
|
|
def resolve_model(
    target_name: str, target: TargetConfig, defaults: Defaults
) -> ModelInfo:
    """Pick the model to use for a target and check it is capable enough.

    Resolution order:

    1. If ``target.model`` is set, the first entry of ``ALL_MODELS`` with
       that name is used, provided it has every required capability.
    2. Otherwise the default for the target's type (``defaults.image_model``
       or ``defaults.text_model``) is used if it is listed in ``ALL_MODELS``
       and has every required capability.
    3. Otherwise (default unknown or not capable enough) the first model in
       ``ALL_MODELS`` of the matching type with every required capability
       is used.

    Raises :class:`ValueError` when an explicit model is unknown or lacks a
    required capability, when no model satisfies the requirements, or (via
    :func:`infer_required_capabilities`) when the target's file extension is
    unsupported.
    """
    from bulkgen.providers.models import ALL_MODELS

    required = infer_required_capabilities(target_name, target)
    target_type = target_type_from_capabilities(required)

    if target.model is not None:
        # Explicit choice: it must exist and must cover every requirement.
        explicit = next(
            (entry for entry in ALL_MODELS if entry.name == target.model), None
        )
        if explicit is None:
            msg = f"Unknown model '{target.model}' for target '{target_name}'"
            raise ValueError(msg)
        missing = required - frozenset(explicit.capabilities)
        if missing:
            names = ", ".join(sorted(missing))
            msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}"
            raise ValueError(msg)
        return explicit

    # No explicit choice: prefer the configured default for this target type.
    wants_image = target_type is TargetType.IMAGE
    default_name = defaults.image_model if wants_image else defaults.text_model
    preferred = next(
        (entry for entry in ALL_MODELS if entry.name == default_name), None
    )
    if preferred is not None and required <= frozenset(preferred.capabilities):
        return preferred

    # The default is unknown or insufficient: take the first capable model
    # of the same type, in ALL_MODELS order.
    model_type = "image" if wants_image else "text"
    fallback = next(
        (
            entry
            for entry in ALL_MODELS
            if entry.type == model_type and required <= frozenset(entry.capabilities)
        ),
        None,
    )
    if fallback is not None:
        return fallback

    names = ", ".join(sorted(required))
    msg = f"No model found for target '{target_name}' with required capabilities: {names}"
    raise ValueError(msg)
|
|
|
|
|
|
def load_config(config_path: Path) -> ProjectConfig:
    """Load and validate a ``.bulkgen.yaml`` file.

    The file is always decoded as UTF-8 so that non-ASCII content (for
    example in prompts) is read the same way on every platform, instead of
    depending on the locale's default encoding.

    Raises :class:`OSError` if the file cannot be opened, a YAML error if the
    content cannot be parsed, and a Pydantic validation error if the parsed
    data does not match :class:`ProjectConfig` (which includes an empty file,
    where the parsed value is ``None``).
    """
    with config_path.open(encoding="utf-8") as f:
        # safe_load only builds plain Python objects, never arbitrary ones.
        raw = yaml.safe_load(f)  # pyright: ignore[reportAny]
    return ProjectConfig.model_validate(raw)
|