feat: replace infer_target_type with capability-based model selection and validation

This commit is contained in:
Konstantin Fickel 2026-02-15 08:34:12 +01:00
parent d15444bdb0
commit e7270a118d
Signed by: kfickel
GPG key ID: A793722F9933C1A5
3 changed files with 179 additions and 37 deletions

View file

@ -10,11 +10,11 @@ import yaml
from bulkgen.config import (
Defaults,
TargetConfig,
TargetType,
infer_target_type,
infer_required_capabilities,
load_config,
resolve_model,
)
from bulkgen.providers.models import Capability
class TestLoadConfig:
@ -88,34 +88,88 @@ class TestLoadConfig:
_ = load_config(config_path)
class TestInferTargetType:
"""Test target type inference from file extensions."""
class TestInferRequiredCapabilities:
"""Test capability inference from file extensions and target config."""
@pytest.mark.parametrize(
"name", ["photo.png", "photo.jpg", "photo.jpeg", "photo.webp"]
)
def test_plain_image(self) -> None:
target = TargetConfig(prompt="x")
assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE}
)
@pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
def test_image_extensions(self, name: str) -> None:
assert infer_target_type(name) is TargetType.IMAGE
target = TargetConfig(prompt="x")
caps = infer_required_capabilities(name, target)
assert Capability.TEXT_TO_IMAGE in caps
@pytest.mark.parametrize("name", ["PHOTO.PNG", "PHOTO.JPG"])
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
def test_case_insensitive(self, name: str) -> None:
assert infer_target_type(name) is TargetType.IMAGE
target = TargetConfig(prompt="x")
caps = infer_required_capabilities(name, target)
assert Capability.TEXT_TO_IMAGE in caps
def test_image_with_reference_images(self) -> None:
target = TargetConfig(prompt="x", reference_images=["ref.png"])
assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
)
def test_image_with_control_images(self) -> None:
target = TargetConfig(prompt="x", control_images=["ctrl.png"])
assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
)
def test_image_with_both(self) -> None:
target = TargetConfig(
prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
)
assert infer_required_capabilities("out.png", target) == frozenset(
{
Capability.TEXT_TO_IMAGE,
Capability.REFERENCE_IMAGES,
Capability.CONTROL_IMAGES,
}
)
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
def test_text_extensions(self, name: str) -> None:
assert infer_target_type(name) is TargetType.TEXT
target = TargetConfig(prompt="x")
caps = infer_required_capabilities(name, target)
assert caps == frozenset({Capability.TEXT_GENERATION})
def test_text_with_text_inputs(self) -> None:
target = TargetConfig(prompt="x", inputs=["data.txt"])
assert infer_required_capabilities("out.md", target) == frozenset(
{Capability.TEXT_GENERATION}
)
def test_text_with_image_input(self) -> None:
target = TargetConfig(prompt="x", inputs=["photo.png"])
assert infer_required_capabilities("out.txt", target) == frozenset(
{Capability.TEXT_GENERATION, Capability.VISION}
)
def test_text_with_image_reference(self) -> None:
target = TargetConfig(prompt="x", reference_images=["ref.jpg"])
assert infer_required_capabilities("out.md", target) == frozenset(
{Capability.TEXT_GENERATION, Capability.VISION}
)
def test_unsupported_extension_raises(self) -> None:
target = TargetConfig(prompt="x")
with pytest.raises(ValueError, match="unsupported extension"):
_ = infer_target_type("data.csv")
_ = infer_required_capabilities("data.csv", target)
def test_no_extension_raises(self) -> None:
target = TargetConfig(prompt="x")
with pytest.raises(ValueError, match="unsupported extension"):
_ = infer_target_type("Makefile")
_ = infer_required_capabilities("Makefile", target)
class TestResolveModel:
"""Test model resolution (explicit vs. default)."""
"""Test model resolution with capability validation."""
def test_explicit_model_wins(self) -> None:
target = TargetConfig(prompt="x", model="mistral-small-latest")
@ -138,3 +192,30 @@ class TestResolveModel:
target = TargetConfig(prompt="x", model="nonexistent-model")
with pytest.raises(ValueError, match="Unknown model"):
_ = resolve_model("out.txt", target, Defaults())
def test_explicit_model_missing_capability_raises(self) -> None:
# flux-dev does not support reference images
target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"])
with pytest.raises(ValueError, match="lacks required capabilities"):
_ = resolve_model("out.png", target, Defaults())
def test_default_fallback_for_reference_images(self) -> None:
# Default flux-dev lacks reference_images, should fall back to a capable model
target = TargetConfig(prompt="x", reference_images=["r.png"])
defaults = Defaults(image_model="flux-dev")
result = resolve_model("out.png", target, defaults)
assert Capability.REFERENCE_IMAGES in result.capabilities
def test_default_fallback_for_vision(self) -> None:
# Default mistral-large-latest lacks vision, should fall back to a pixtral model
target = TargetConfig(prompt="x", inputs=["photo.png"])
defaults = Defaults(text_model="mistral-large-latest")
result = resolve_model("out.txt", target, defaults)
assert Capability.VISION in result.capabilities
def test_default_preferred_when_capable(self) -> None:
# Default flux-2-pro already supports reference_images, should be used directly
target = TargetConfig(prompt="x", reference_images=["r.png"])
defaults = Defaults(image_model="flux-2-pro")
result = resolve_model("out.png", target, defaults)
assert result.name == "flux-2-pro"