feat: replace infer_target_type with capability-based model selection and validation
This commit is contained in:
parent
d15444bdb0
commit
e7270a118d
3 changed files with 179 additions and 37 deletions
|
|
@ -10,7 +10,7 @@ import yaml
|
|||
from pydantic import BaseModel, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
from bulkgen.providers.models import Capability, ModelInfo
|
||||
|
||||
|
||||
class TargetType(enum.Enum):
|
||||
|
|
@ -57,39 +57,95 @@ class ProjectConfig(BaseModel):
|
|||
return self
|
||||
|
||||
|
||||
def infer_target_type(target_name: str) -> TargetType:
|
||||
"""Infer whether a target produces an image or text from its file extension."""
|
||||
def infer_required_capabilities(
    target_name: str, target: TargetConfig
) -> frozenset[Capability]:
    """Work out which capabilities a model needs in order to build a target.

    The kind of output is decided by the file extension of *target_name*:

    * An image extension requires ``TEXT_TO_IMAGE``, plus
      ``REFERENCE_IMAGES`` and/or ``CONTROL_IMAGES`` when the target
      configures those.
    * A text extension requires ``TEXT_GENERATION``, plus ``VISION`` when
      any entry of ``target.inputs`` or ``target.reference_images`` has an
      image extension.

    Raises :class:`ValueError` for unsupported file extensions.
    """
    # The module-level import of ``Capability`` sits under ``TYPE_CHECKING``,
    # so the enum has to be imported here to be usable at runtime.
    from bulkgen.providers.models import Capability

    suffix = Path(target_name).suffix.lower()

    # Image output: base capability plus whatever image conditioning is set.
    if suffix in IMAGE_EXTENSIONS:
        needed: set[Capability] = {Capability.TEXT_TO_IMAGE}
        if target.reference_images:
            needed.add(Capability.REFERENCE_IMAGES)
        if target.control_images:
            needed.add(Capability.CONTROL_IMAGES)
        return frozenset(needed)

    # Text output: vision is only required when some input is an image file.
    if suffix in TEXT_EXTENSIONS:
        needed = {Capability.TEXT_GENERATION}
        candidate_names = [*target.inputs, *target.reference_images]
        for candidate in candidate_names:
            if Path(candidate).suffix.lower() in IMAGE_EXTENSIONS:
                needed.add(Capability.VISION)
                break
        return frozenset(needed)

    msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
    raise ValueError(msg)
|
||||
|
||||
|
||||
def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType:
    """Derive the target type from a set of required capabilities.

    Returns :attr:`TargetType.IMAGE` when ``Capability.TEXT_TO_IMAGE`` is
    among *capabilities*, and :attr:`TargetType.TEXT` for every other set.
    This function never raises: unsupported file extensions are rejected
    earlier, when the capability set is inferred from the target name.
    """
    # The module-level import of ``Capability`` sits under ``TYPE_CHECKING``,
    # so the enum has to be imported here to be usable at runtime.
    from bulkgen.providers.models import Capability

    if Capability.TEXT_TO_IMAGE in capabilities:
        return TargetType.IMAGE
    return TargetType.TEXT
|
||||
|
||||
|
||||
def resolve_model(
    target_name: str, target: TargetConfig, defaults: Defaults
) -> ModelInfo:
    """Return the effective model for a target, validated against required capabilities.

    If the target specifies an explicit model, it is validated to have all
    required capabilities. Otherwise the type-appropriate default is tried
    first; if it lacks a required capability the first capable model of the
    same type is selected.

    Args:
        target_name: Output file name of the target; its extension drives
            the capability inference.
        target: The target's configuration (explicit model, inputs, images).
        defaults: Project defaults providing ``image_model`` and
            ``text_model`` names.

    Returns:
        The registry entry from ``ALL_MODELS`` chosen for this target.

    Raises:
        ValueError: If the target's extension is unsupported, if an explicit
            model is unknown or lacks a required capability, or if no model
            of the right type has every required capability.
    """
    from bulkgen.providers.models import ALL_MODELS

    required = infer_required_capabilities(target_name, target)
    target_type = target_type_from_capabilities(required)

    if target.model is not None:
        # Explicit model — look up and validate; never substitute another one.
        for model in ALL_MODELS:
            if model.name == target.model:
                missing = required - frozenset(model.capabilities)
                if missing:
                    names = ", ".join(sorted(missing))
                    msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}"
                    raise ValueError(msg)
                return model
        msg = f"Unknown model '{target.model}' for target '{target_name}'"
        raise ValueError(msg)

    # No explicit model — try the default first, then fall back.
    default_name = (
        defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model
    )
    for model in ALL_MODELS:
        if model.name == default_name:
            if required <= frozenset(model.capabilities):
                return model
            # The default exists but is not capable enough; stop searching
            # for it and move on to the fallback scan.
            break
    # NOTE: a default name that is absent from ALL_MODELS is not reported as
    # an error here; the loop simply finds nothing and the fallback below runs.

    # Default lacks capabilities — find the first capable model of the same type.
    model_type = "image" if target_type is TargetType.IMAGE else "text"
    for model in ALL_MODELS:
        if model.type == model_type and required <= frozenset(model.capabilities):
            return model

    names = ", ".join(sorted(required))
    msg = (
        f"No model found for target '{target_name}' with required capabilities: {names}"
    )
    raise ValueError(msg)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue