hokusai/bulkgen/config.py
Konstantin Fickel d0dac5b1bf
Some checks failed
Continuous Integration / Build Package (push) Successful in 30s
Continuous Integration / Lint, Check & Test (push) Failing after 38s
refactor: move model definitions into providers and extract resolve module
- Rename ImageProvider to BlackForestProvider, TextProvider to MistralProvider
- Add get_provided_models() abstract method to Provider base class
- Move model lists from models.py into each provider's get_provided_models()
- Add providers/registry.py to aggregate models from all providers
- Extract infer_required_capabilities and resolve_model from config.py to resolve.py
- Update tests to use new names and import paths
2026-02-15 11:03:57 +01:00

71 lines
1.9 KiB
Python

"""Pydantic models for bulkgen YAML configuration."""
from __future__ import annotations
import enum
from pathlib import Path
from typing import Self
import yaml
from pydantic import BaseModel, model_validator
from bulkgen.providers.models import Capability
# Suffixes (lowercase, leading dot) associated with image artifacts.
# NOTE(review): presumably matched against target names by callers; verify.
IMAGE_EXTENSIONS: frozenset[str] = frozenset((".png", ".jpg", ".jpeg", ".webp"))
# Suffixes (lowercase, leading dot) associated with text artifacts.
TEXT_EXTENSIONS: frozenset[str] = frozenset((".md", ".txt"))
class TargetType(enum.Enum):
    """Category of artifact that a build target yields.

    Each member's value is its lowercase name as a string.
    """

    IMAGE = "image"  # an image artifact
    TEXT = "text"  # a text artifact
class Defaults(BaseModel):
    """Fallback model identifiers for targets that omit ``model``.

    A target's own ``model`` setting takes precedence; these values are
    used only when a target leaves it unset.
    """

    # Default model identifier for text targets.
    text_model: str = "pixtral-large-latest"
    # Default model identifier for image targets.
    image_model: str = "flux-2-pro"
class TargetConfig(BaseModel):
    """Settings describing how one build target is produced.

    Only ``prompt`` is mandatory; every other field has a default.
    """

    # Prompt text sent to the generating model.
    prompt: str
    # Per-target model override; ``None`` means "use the project defaults".
    model: str | None = None
    # NOTE(review): presumably names of files/targets fed in as inputs; verify.
    inputs: list[str] = []
    # NOTE(review): presumably image paths used as style/content references.
    reference_images: list[str] = []
    # NOTE(review): presumably image paths used as control inputs.
    control_images: list[str] = []
    # Optional output dimensions; ``None`` leaves them unspecified.
    width: int | None = None
    height: int | None = None
class ProjectConfig(BaseModel):
    """Root object of a validated ``<name>.bulkgen.yaml`` file."""

    # Project-wide fallback models; a fresh ``Defaults()`` when omitted.
    defaults: Defaults = Defaults()
    # Build targets keyed by target name; must not be empty.
    targets: dict[str, TargetConfig]

    @model_validator(mode="after")
    def _validate_non_empty_targets(self) -> Self:
        """Reject a configuration that declares no targets at all.

        Raises:
            ValueError: If ``targets`` is empty (pydantic surfaces this as a
                validation error).
        """
        if self.targets:
            return self
        error_message = "At least one target must be defined"
        raise ValueError(error_message)
def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType:
    """Derive the target type from a set of required capabilities.

    Uses the module-level ``Capability`` import; the previous function-local
    re-import of the same name was redundant.

    Args:
        capabilities: The capabilities a target requires.

    Returns:
        ``TargetType.IMAGE`` if ``Capability.TEXT_TO_IMAGE`` is present in
        *capabilities*, otherwise ``TargetType.TEXT``.
    """
    if Capability.TEXT_TO_IMAGE in capabilities:
        return TargetType.IMAGE
    return TargetType.TEXT
def load_config(config_path: Path) -> ProjectConfig:
    """Load and validate a ``.bulkgen.yaml`` file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The validated project configuration.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        yaml.YAMLError: If the content is not valid YAML.
        pydantic.ValidationError: If the parsed data does not match the
            ``ProjectConfig`` schema (an empty file parses to ``None`` and
            is rejected here as well).
    """
    # Decode explicitly as UTF-8: without ``encoding`` Python uses the
    # locale's preferred encoding (e.g. cp1252 on Windows), which silently
    # garbles non-ASCII characters in prompts.
    with config_path.open(encoding="utf-8") as f:
        # safe_load only builds plain Python objects (no arbitrary tags).
        raw = yaml.safe_load(f)  # pyright: ignore[reportAny]
    return ProjectConfig.model_validate(raw)