feat: replace infer_target_type with capability-based model selection and validation

This commit is contained in:
Konstantin Fickel 2026-02-15 08:34:12 +01:00
parent d15444bdb0
commit e7270a118d
Signed by: kfickel
GPG key ID: A793722F9933C1A5
3 changed files with 179 additions and 37 deletions

View file

@ -12,8 +12,9 @@ from pathlib import Path
from bulkgen.config import (
ProjectConfig,
TargetType,
infer_target_type,
infer_required_capabilities,
resolve_model,
target_type_from_capabilities,
)
from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target
from bulkgen.providers import Provider
@ -120,8 +121,9 @@ async def _build_single_target(
) -> None:
"""Build a single target by dispatching to the appropriate provider."""
target_cfg = config.targets[target_name]
target_type = infer_target_type(target_name)
model_info = resolve_model(target_name, target_cfg, config.defaults)
required = infer_required_capabilities(target_name, target_cfg)
target_type = target_type_from_capabilities(required)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
provider = providers[target_type]
@ -180,7 +182,7 @@ async def run_build(
continue
if _is_dirty(name, config, project_dir, state):
if not _has_provider(name, providers, result, on_progress):
if not _has_provider(name, config, providers, result, on_progress):
continue
dirty_targets.append(name)
else:
@ -236,12 +238,15 @@ def _is_dirty(
def _has_provider(
target_name: str,
config: ProjectConfig,
providers: dict[TargetType, Provider],
result: BuildResult,
on_progress: ProgressCallback = _noop_callback,
) -> bool:
"""Check that the required provider is available; record failure if not."""
target_type = infer_target_type(target_name)
target_cfg = config.targets[target_name]
required = infer_required_capabilities(target_name, target_cfg)
target_type = target_type_from_capabilities(required)
if target_type not in providers:
env_var = (
"BFL_API_KEY" if target_type is TargetType.IMAGE else "MISTRAL_API_KEY"

View file

@ -10,7 +10,7 @@ import yaml
from pydantic import BaseModel, model_validator
if TYPE_CHECKING:
from bulkgen.providers.models import ModelInfo
from bulkgen.providers.models import Capability, ModelInfo
class TargetType(enum.Enum):
@ -57,39 +57,95 @@ class ProjectConfig(BaseModel):
return self
def infer_target_type(target_name: str) -> TargetType:
"""Infer whether a target produces an image or text from its file extension."""
def infer_required_capabilities(
    target_name: str, target: TargetConfig
) -> frozenset[Capability]:
    """Work out which model capabilities a target needs.

    The decision is driven by the file extension of *target_name* together
    with the inputs configured on *target*:

    * Image extensions always require ``TEXT_TO_IMAGE``; ``REFERENCE_IMAGES``
      and ``CONTROL_IMAGES`` are added when the target configures reference
      or control images respectively.
    * Text extensions always require ``TEXT_GENERATION``; ``VISION`` is added
      when any entry of ``inputs`` or ``reference_images`` has an image
      extension.

    Raises :class:`ValueError` for unsupported file extensions.
    """
    # Imported inside the function: at module level this name appears to be
    # imported only under TYPE_CHECKING, so it is not available at runtime.
    from bulkgen.providers.models import Capability

    suffix = Path(target_name).suffix.lower()

    # Image target: capabilities depend on which image inputs are configured.
    if suffix in IMAGE_EXTENSIONS:
        needed = {Capability.TEXT_TO_IMAGE}
        if target.reference_images:
            needed.add(Capability.REFERENCE_IMAGES)
        if target.control_images:
            needed.add(Capability.CONTROL_IMAGES)
        return frozenset(needed)

    # Text target: vision is needed as soon as one input is an image file.
    if suffix in TEXT_EXTENSIONS:
        needed = {Capability.TEXT_GENERATION}
        candidates = [*target.inputs, *target.reference_images]
        for candidate in candidates:
            if Path(candidate).suffix.lower() in IMAGE_EXTENSIONS:
                needed.add(Capability.VISION)
                break
        return frozenset(needed)

    msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
    raise ValueError(msg)
def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType:
"""Derive the target type from a set of required capabilities."""
from bulkgen.providers.models import Capability
if Capability.TEXT_TO_IMAGE in capabilities:
return TargetType.IMAGE
if suffix in TEXT_EXTENSIONS:
return TargetType.TEXT
msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
raise ValueError(msg)
return TargetType.TEXT
def resolve_model(
target_name: str, target: TargetConfig, defaults: Defaults
) -> ModelInfo:
"""Return the effective model for a target (explicit or default by type).
"""Return the effective model for a target, validated against required capabilities.
Raises :class:`ValueError` if the resolved model name is not in the registry.
If the target specifies an explicit model, it is validated to have all
required capabilities. Otherwise the type-appropriate default is tried
first; if it lacks a required capability the first capable model of the
same type is selected.
Raises :class:`ValueError` if no suitable model can be found.
"""
from bulkgen.providers.models import ALL_MODELS
required = infer_required_capabilities(target_name, target)
target_type = target_type_from_capabilities(required)
if target.model is not None:
model_name = target.model
else:
target_type = infer_target_type(target_name)
model_name = (
defaults.image_model
if target_type is TargetType.IMAGE
else defaults.text_model
)
# Explicit model — look up and validate.
for model in ALL_MODELS:
if model.name == target.model:
missing = required - frozenset(model.capabilities)
if missing:
names = ", ".join(sorted(missing))
msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}"
raise ValueError(msg)
return model
msg = f"Unknown model '{target.model}' for target '{target_name}'"
raise ValueError(msg)
# No explicit model — try the default first, then fall back.
default_name = (
defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model
)
for model in ALL_MODELS:
if model.name == model_name:
if model.name == default_name:
if required <= frozenset(model.capabilities):
return model
break
# Default lacks capabilities — find the first capable model of the same type.
model_type = "image" if target_type is TargetType.IMAGE else "text"
for model in ALL_MODELS:
if model.type == model_type and required <= frozenset(model.capabilities):
return model
msg = f"Unknown model '{model_name}' for target '{target_name}'"
names = ", ".join(sorted(required))
msg = (
f"No model found for target '{target_name}' with required capabilities: {names}"
)
raise ValueError(msg)