hokusai/bulkgen/resolve.py
Konstantin Fickel d0dac5b1bf
Some checks failed
Continuous Integration / Build Package (push) Successful in 30s
Continuous Integration / Lint, Check & Test (push) Failing after 38s
refactor: move model definitions into providers and extract resolve module
- Rename ImageProvider to BlackForestProvider, TextProvider to MistralProvider
- Add get_provided_models() abstract method to Provider base class
- Move model lists from models.py into each provider's get_provided_models()
- Add providers/registry.py to aggregate models from all providers
- Extract infer_required_capabilities and resolve_model from config.py to resolve.py
- Update tests to use new names and import paths
2026-02-15 11:03:57 +01:00

95 lines
3.4 KiB
Python

from __future__ import annotations
from pathlib import Path
from bulkgen.config import (
IMAGE_EXTENSIONS,
TEXT_EXTENSIONS,
Defaults,
TargetConfig,
TargetType,
target_type_from_capabilities,
)
from bulkgen.providers.models import Capability, ModelInfo
def infer_required_capabilities(
    target_name: str, target: TargetConfig
) -> frozenset[Capability]:
    """Derive the model capabilities a target needs from its filename and config.

    The target's file extension is compared case-insensitively:

    * Image extensions require ``TEXT_TO_IMAGE``. ``REFERENCE_IMAGES`` is added
      when ``target.reference_images`` is non-empty, and ``CONTROL_IMAGES``
      when ``target.control_images`` is non-empty.
    * Text extensions require ``TEXT_GENERATION``. ``VISION`` is added when
      any entry of ``target.inputs`` or ``target.reference_images`` has an
      image extension.

    Image extensions are checked before text extensions, so a suffix present
    in both sets is treated as an image target.

    Raises :class:`ValueError` when the extension is neither an image nor a
    text extension.
    """
    extension = Path(target_name).suffix.lower()

    if extension in IMAGE_EXTENSIONS:
        needed = {Capability.TEXT_TO_IMAGE}
        if target.reference_images:
            needed.add(Capability.REFERENCE_IMAGES)
        if target.control_images:
            needed.add(Capability.CONTROL_IMAGES)
        return frozenset(needed)

    if extension in TEXT_EXTENSIONS:
        needed = {Capability.TEXT_GENERATION}
        # Image files among the inputs mean the text model must accept images.
        candidates = [*target.inputs, *target.reference_images]
        if any(Path(name).suffix.lower() in IMAGE_EXTENSIONS for name in candidates):
            needed.add(Capability.VISION)
        return frozenset(needed)

    msg = (
        f"Cannot infer target type for '{target_name}': "
        f"unsupported extension '{extension}'"
    )
    raise ValueError(msg)
def resolve_model(
    target_name: str, target: TargetConfig, defaults: Defaults
) -> ModelInfo:
    """Pick the model to use for a target and check it can do the job.

    Capabilities required by the target come from
    :func:`infer_required_capabilities`. Resolution order:

    1. If ``target.model`` is set, the first registered model with that name
       is used. It is an error if the name is unknown or the model is missing
       any required capability.
    2. Otherwise the default for the target's type (``defaults.image_model``
       for image targets, ``defaults.text_model`` otherwise) is used if it is
       registered and has every required capability.
    3. Otherwise the first registered model of the matching type (``"image"``
       or ``"text"``) that has every required capability is used.

    Raises :class:`ValueError` if none of the above yields a model.
    """
    # Deferred import; presumably avoids an import cycle with the provider
    # registry. Confirm before hoisting it to module level.
    from bulkgen.providers.registry import get_all_models

    registry = get_all_models()
    needed = infer_required_capabilities(target_name, target)
    kind = target_type_from_capabilities(needed)

    if target.model is not None:
        requested = next((m for m in registry if m.name == target.model), None)
        if requested is None:
            msg = f"Unknown model '{target.model}' for target '{target_name}'"
            raise ValueError(msg)
        absent = needed.difference(requested.capabilities)
        if absent:
            names = ", ".join(sorted(absent))
            msg = (
                f"Model '{target.model}' for target '{target_name}' "
                f"lacks required capabilities: {names}"
            )
            raise ValueError(msg)
        return requested

    is_image = kind is TargetType.IMAGE
    default_name = defaults.image_model if is_image else defaults.text_model
    wanted_type = "image" if is_image else "text"

    # Only the first model carrying the default name is considered.
    preferred = next((m for m in registry if m.name == default_name), None)
    if preferred is not None and needed.issubset(preferred.capabilities):
        return preferred

    # The default is missing or under-powered: use the first capable model
    # of the same type.
    fallback = next(
        (
            m
            for m in registry
            if m.type == wanted_type and needed.issubset(m.capabilities)
        ),
        None,
    )
    if fallback is not None:
        return fallback

    names = ", ".join(sorted(needed))
    msg = (
        f"No model found for target '{target_name}' "
        f"with required capabilities: {names}"
    )
    raise ValueError(msg)