feat: replace infer_target_type with capability-based model selection and validation
This commit is contained in:
parent
d15444bdb0
commit
e7270a118d
3 changed files with 179 additions and 37 deletions
|
|
@ -12,8 +12,9 @@ from pathlib import Path
|
||||||
from bulkgen.config import (
|
from bulkgen.config import (
|
||||||
ProjectConfig,
|
ProjectConfig,
|
||||||
TargetType,
|
TargetType,
|
||||||
infer_target_type,
|
infer_required_capabilities,
|
||||||
resolve_model,
|
resolve_model,
|
||||||
|
target_type_from_capabilities,
|
||||||
)
|
)
|
||||||
from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target
|
from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target
|
||||||
from bulkgen.providers import Provider
|
from bulkgen.providers import Provider
|
||||||
|
|
@ -120,8 +121,9 @@ async def _build_single_target(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Build a single target by dispatching to the appropriate provider."""
|
"""Build a single target by dispatching to the appropriate provider."""
|
||||||
target_cfg = config.targets[target_name]
|
target_cfg = config.targets[target_name]
|
||||||
target_type = infer_target_type(target_name)
|
|
||||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||||
|
required = infer_required_capabilities(target_name, target_cfg)
|
||||||
|
target_type = target_type_from_capabilities(required)
|
||||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||||
|
|
||||||
provider = providers[target_type]
|
provider = providers[target_type]
|
||||||
|
|
@ -180,7 +182,7 @@ async def run_build(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if _is_dirty(name, config, project_dir, state):
|
if _is_dirty(name, config, project_dir, state):
|
||||||
if not _has_provider(name, providers, result, on_progress):
|
if not _has_provider(name, config, providers, result, on_progress):
|
||||||
continue
|
continue
|
||||||
dirty_targets.append(name)
|
dirty_targets.append(name)
|
||||||
else:
|
else:
|
||||||
|
|
@ -236,12 +238,15 @@ def _is_dirty(
|
||||||
|
|
||||||
def _has_provider(
|
def _has_provider(
|
||||||
target_name: str,
|
target_name: str,
|
||||||
|
config: ProjectConfig,
|
||||||
providers: dict[TargetType, Provider],
|
providers: dict[TargetType, Provider],
|
||||||
result: BuildResult,
|
result: BuildResult,
|
||||||
on_progress: ProgressCallback = _noop_callback,
|
on_progress: ProgressCallback = _noop_callback,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check that the required provider is available; record failure if not."""
|
"""Check that the required provider is available; record failure if not."""
|
||||||
target_type = infer_target_type(target_name)
|
target_cfg = config.targets[target_name]
|
||||||
|
required = infer_required_capabilities(target_name, target_cfg)
|
||||||
|
target_type = target_type_from_capabilities(required)
|
||||||
if target_type not in providers:
|
if target_type not in providers:
|
||||||
env_var = (
|
env_var = (
|
||||||
"BFL_API_KEY" if target_type is TargetType.IMAGE else "MISTRAL_API_KEY"
|
"BFL_API_KEY" if target_type is TargetType.IMAGE else "MISTRAL_API_KEY"
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import yaml
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from bulkgen.providers.models import ModelInfo
|
from bulkgen.providers.models import Capability, ModelInfo
|
||||||
|
|
||||||
|
|
||||||
class TargetType(enum.Enum):
|
class TargetType(enum.Enum):
|
||||||
|
|
@ -57,39 +57,95 @@ class ProjectConfig(BaseModel):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def infer_target_type(target_name: str) -> TargetType:
|
def infer_required_capabilities(
|
||||||
"""Infer whether a target produces an image or text from its file extension."""
|
target_name: str, target: TargetConfig
|
||||||
|
) -> frozenset[Capability]:
|
||||||
|
"""Infer the capabilities a model must have based on filename and config.
|
||||||
|
|
||||||
|
Raises :class:`ValueError` for unsupported file extensions.
|
||||||
|
"""
|
||||||
|
from bulkgen.providers.models import Capability
|
||||||
|
|
||||||
suffix = Path(target_name).suffix.lower()
|
suffix = Path(target_name).suffix.lower()
|
||||||
|
caps: set[Capability] = set()
|
||||||
|
|
||||||
if suffix in IMAGE_EXTENSIONS:
|
if suffix in IMAGE_EXTENSIONS:
|
||||||
return TargetType.IMAGE
|
caps.add(Capability.TEXT_TO_IMAGE)
|
||||||
if suffix in TEXT_EXTENSIONS:
|
if target.reference_images:
|
||||||
return TargetType.TEXT
|
caps.add(Capability.REFERENCE_IMAGES)
|
||||||
|
if target.control_images:
|
||||||
|
caps.add(Capability.CONTROL_IMAGES)
|
||||||
|
elif suffix in TEXT_EXTENSIONS:
|
||||||
|
caps.add(Capability.TEXT_GENERATION)
|
||||||
|
all_input_names = list(target.inputs) + list(target.reference_images)
|
||||||
|
if any(Path(n).suffix.lower() in IMAGE_EXTENSIONS for n in all_input_names):
|
||||||
|
caps.add(Capability.VISION)
|
||||||
|
else:
|
||||||
msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
|
msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
return frozenset(caps)
|
||||||
|
|
||||||
|
|
||||||
|
def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType:
|
||||||
|
"""Derive the target type from a set of required capabilities."""
|
||||||
|
from bulkgen.providers.models import Capability
|
||||||
|
|
||||||
|
if Capability.TEXT_TO_IMAGE in capabilities:
|
||||||
|
return TargetType.IMAGE
|
||||||
|
return TargetType.TEXT
|
||||||
|
|
||||||
|
|
||||||
def resolve_model(
|
def resolve_model(
|
||||||
target_name: str, target: TargetConfig, defaults: Defaults
|
target_name: str, target: TargetConfig, defaults: Defaults
|
||||||
) -> ModelInfo:
|
) -> ModelInfo:
|
||||||
"""Return the effective model for a target (explicit or default by type).
|
"""Return the effective model for a target, validated against required capabilities.
|
||||||
|
|
||||||
Raises :class:`ValueError` if the resolved model name is not in the registry.
|
If the target specifies an explicit model, it is validated to have all
|
||||||
|
required capabilities. Otherwise the type-appropriate default is tried
|
||||||
|
first; if it lacks a required capability the first capable model of the
|
||||||
|
same type is selected.
|
||||||
|
|
||||||
|
Raises :class:`ValueError` if no suitable model can be found.
|
||||||
"""
|
"""
|
||||||
from bulkgen.providers.models import ALL_MODELS
|
from bulkgen.providers.models import ALL_MODELS
|
||||||
|
|
||||||
|
required = infer_required_capabilities(target_name, target)
|
||||||
|
target_type = target_type_from_capabilities(required)
|
||||||
|
|
||||||
if target.model is not None:
|
if target.model is not None:
|
||||||
model_name = target.model
|
# Explicit model — look up and validate.
|
||||||
else:
|
for model in ALL_MODELS:
|
||||||
target_type = infer_target_type(target_name)
|
if model.name == target.model:
|
||||||
model_name = (
|
missing = required - frozenset(model.capabilities)
|
||||||
defaults.image_model
|
if missing:
|
||||||
if target_type is TargetType.IMAGE
|
names = ", ".join(sorted(missing))
|
||||||
else defaults.text_model
|
msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
return model
|
||||||
|
msg = f"Unknown model '{target.model}' for target '{target_name}'"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# No explicit model — try the default first, then fall back.
|
||||||
|
default_name = (
|
||||||
|
defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model
|
||||||
)
|
)
|
||||||
for model in ALL_MODELS:
|
for model in ALL_MODELS:
|
||||||
if model.name == model_name:
|
if model.name == default_name:
|
||||||
|
if required <= frozenset(model.capabilities):
|
||||||
return model
|
return model
|
||||||
msg = f"Unknown model '{model_name}' for target '{target_name}'"
|
break
|
||||||
|
|
||||||
|
# Default lacks capabilities — find the first capable model of the same type.
|
||||||
|
model_type = "image" if target_type is TargetType.IMAGE else "text"
|
||||||
|
for model in ALL_MODELS:
|
||||||
|
if model.type == model_type and required <= frozenset(model.capabilities):
|
||||||
|
return model
|
||||||
|
|
||||||
|
names = ", ".join(sorted(required))
|
||||||
|
msg = (
|
||||||
|
f"No model found for target '{target_name}' with required capabilities: {names}"
|
||||||
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,11 @@ import yaml
|
||||||
from bulkgen.config import (
|
from bulkgen.config import (
|
||||||
Defaults,
|
Defaults,
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
TargetType,
|
infer_required_capabilities,
|
||||||
infer_target_type,
|
|
||||||
load_config,
|
load_config,
|
||||||
resolve_model,
|
resolve_model,
|
||||||
)
|
)
|
||||||
|
from bulkgen.providers.models import Capability
|
||||||
|
|
||||||
|
|
||||||
class TestLoadConfig:
|
class TestLoadConfig:
|
||||||
|
|
@ -88,34 +88,88 @@ class TestLoadConfig:
|
||||||
_ = load_config(config_path)
|
_ = load_config(config_path)
|
||||||
|
|
||||||
|
|
||||||
class TestInferTargetType:
|
class TestInferRequiredCapabilities:
|
||||||
"""Test target type inference from file extensions."""
|
"""Test capability inference from file extensions and target config."""
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
def test_plain_image(self) -> None:
|
||||||
"name", ["photo.png", "photo.jpg", "photo.jpeg", "photo.webp"]
|
target = TargetConfig(prompt="x")
|
||||||
|
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||||
|
{Capability.TEXT_TO_IMAGE}
|
||||||
)
|
)
|
||||||
def test_image_extensions(self, name: str) -> None:
|
|
||||||
assert infer_target_type(name) is TargetType.IMAGE
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["PHOTO.PNG", "PHOTO.JPG"])
|
@pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
|
||||||
|
def test_image_extensions(self, name: str) -> None:
|
||||||
|
target = TargetConfig(prompt="x")
|
||||||
|
caps = infer_required_capabilities(name, target)
|
||||||
|
assert Capability.TEXT_TO_IMAGE in caps
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
|
||||||
def test_case_insensitive(self, name: str) -> None:
|
def test_case_insensitive(self, name: str) -> None:
|
||||||
assert infer_target_type(name) is TargetType.IMAGE
|
target = TargetConfig(prompt="x")
|
||||||
|
caps = infer_required_capabilities(name, target)
|
||||||
|
assert Capability.TEXT_TO_IMAGE in caps
|
||||||
|
|
||||||
|
def test_image_with_reference_images(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x", reference_images=["ref.png"])
|
||||||
|
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||||
|
{Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_image_with_control_images(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x", control_images=["ctrl.png"])
|
||||||
|
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||||
|
{Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_image_with_both(self) -> None:
|
||||||
|
target = TargetConfig(
|
||||||
|
prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
|
||||||
|
)
|
||||||
|
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||||
|
{
|
||||||
|
Capability.TEXT_TO_IMAGE,
|
||||||
|
Capability.REFERENCE_IMAGES,
|
||||||
|
Capability.CONTROL_IMAGES,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
|
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
|
||||||
def test_text_extensions(self, name: str) -> None:
|
def test_text_extensions(self, name: str) -> None:
|
||||||
assert infer_target_type(name) is TargetType.TEXT
|
target = TargetConfig(prompt="x")
|
||||||
|
caps = infer_required_capabilities(name, target)
|
||||||
|
assert caps == frozenset({Capability.TEXT_GENERATION})
|
||||||
|
|
||||||
|
def test_text_with_text_inputs(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x", inputs=["data.txt"])
|
||||||
|
assert infer_required_capabilities("out.md", target) == frozenset(
|
||||||
|
{Capability.TEXT_GENERATION}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_text_with_image_input(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x", inputs=["photo.png"])
|
||||||
|
assert infer_required_capabilities("out.txt", target) == frozenset(
|
||||||
|
{Capability.TEXT_GENERATION, Capability.VISION}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_text_with_image_reference(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x", reference_images=["ref.jpg"])
|
||||||
|
assert infer_required_capabilities("out.md", target) == frozenset(
|
||||||
|
{Capability.TEXT_GENERATION, Capability.VISION}
|
||||||
|
)
|
||||||
|
|
||||||
def test_unsupported_extension_raises(self) -> None:
|
def test_unsupported_extension_raises(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x")
|
||||||
with pytest.raises(ValueError, match="unsupported extension"):
|
with pytest.raises(ValueError, match="unsupported extension"):
|
||||||
_ = infer_target_type("data.csv")
|
_ = infer_required_capabilities("data.csv", target)
|
||||||
|
|
||||||
def test_no_extension_raises(self) -> None:
|
def test_no_extension_raises(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x")
|
||||||
with pytest.raises(ValueError, match="unsupported extension"):
|
with pytest.raises(ValueError, match="unsupported extension"):
|
||||||
_ = infer_target_type("Makefile")
|
_ = infer_required_capabilities("Makefile", target)
|
||||||
|
|
||||||
|
|
||||||
class TestResolveModel:
|
class TestResolveModel:
|
||||||
"""Test model resolution (explicit vs. default)."""
|
"""Test model resolution with capability validation."""
|
||||||
|
|
||||||
def test_explicit_model_wins(self) -> None:
|
def test_explicit_model_wins(self) -> None:
|
||||||
target = TargetConfig(prompt="x", model="mistral-small-latest")
|
target = TargetConfig(prompt="x", model="mistral-small-latest")
|
||||||
|
|
@ -138,3 +192,30 @@ class TestResolveModel:
|
||||||
target = TargetConfig(prompt="x", model="nonexistent-model")
|
target = TargetConfig(prompt="x", model="nonexistent-model")
|
||||||
with pytest.raises(ValueError, match="Unknown model"):
|
with pytest.raises(ValueError, match="Unknown model"):
|
||||||
_ = resolve_model("out.txt", target, Defaults())
|
_ = resolve_model("out.txt", target, Defaults())
|
||||||
|
|
||||||
|
def test_explicit_model_missing_capability_raises(self) -> None:
|
||||||
|
# flux-dev does not support reference images
|
||||||
|
target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"])
|
||||||
|
with pytest.raises(ValueError, match="lacks required capabilities"):
|
||||||
|
_ = resolve_model("out.png", target, Defaults())
|
||||||
|
|
||||||
|
def test_default_fallback_for_reference_images(self) -> None:
|
||||||
|
# Default flux-dev lacks reference_images, should fall back to a capable model
|
||||||
|
target = TargetConfig(prompt="x", reference_images=["r.png"])
|
||||||
|
defaults = Defaults(image_model="flux-dev")
|
||||||
|
result = resolve_model("out.png", target, defaults)
|
||||||
|
assert Capability.REFERENCE_IMAGES in result.capabilities
|
||||||
|
|
||||||
|
def test_default_fallback_for_vision(self) -> None:
|
||||||
|
# Default mistral-large-latest lacks vision, should fall back to a pixtral model
|
||||||
|
target = TargetConfig(prompt="x", inputs=["photo.png"])
|
||||||
|
defaults = Defaults(text_model="mistral-large-latest")
|
||||||
|
result = resolve_model("out.txt", target, defaults)
|
||||||
|
assert Capability.VISION in result.capabilities
|
||||||
|
|
||||||
|
def test_default_preferred_when_capable(self) -> None:
|
||||||
|
# Default flux-2-pro already supports reference_images, should be used directly
|
||||||
|
target = TargetConfig(prompt="x", reference_images=["r.png"])
|
||||||
|
defaults = Defaults(image_model="flux-2-pro")
|
||||||
|
result = resolve_model("out.png", target, defaults)
|
||||||
|
assert result.name == "flux-2-pro"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue