refactor: move model definitions into providers and extract resolve module
- Rename ImageProvider to BlackForestProvider, TextProvider to MistralProvider
- Add get_provided_models() abstract method to Provider base class
- Move model lists from models.py into each provider's get_provided_models()
- Add providers/registry.py to aggregate models from all providers
- Extract infer_required_capabilities and resolve_model from config.py to resolve.py
- Update tests to use new names and import paths
This commit is contained in:
parent
dc6a75f5c4
commit
d0dac5b1bf
13 changed files with 432 additions and 390 deletions
|
|
@ -27,6 +27,11 @@ WriteConfig = Callable[[dict[str, object]], ProjectConfig]
|
|||
class FakeProvider(Provider):
|
||||
"""A provider that writes a marker file instead of calling an API."""
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def get_provided_models() -> list[ModelInfo]:
|
||||
return []
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
self,
|
||||
|
|
@ -43,6 +48,11 @@ class FakeProvider(Provider):
|
|||
class FailingProvider(Provider):
|
||||
"""A provider that always raises."""
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def get_provided_models() -> list[ModelInfo]:
|
||||
return []
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -7,14 +7,7 @@ from pathlib import Path
|
|||
import pytest
|
||||
import yaml
|
||||
|
||||
from bulkgen.config import (
|
||||
Defaults,
|
||||
TargetConfig,
|
||||
infer_required_capabilities,
|
||||
load_config,
|
||||
resolve_model,
|
||||
)
|
||||
from bulkgen.providers.models import Capability
|
||||
from bulkgen.config import load_config
|
||||
|
||||
|
||||
class TestLoadConfig:
|
||||
|
|
@ -86,136 +79,3 @@ class TestLoadConfig:
|
|||
|
||||
with pytest.raises(Exception):
|
||||
_ = load_config(config_path)
|
||||
|
||||
|
||||
class TestInferRequiredCapabilities:
|
||||
"""Test capability inference from file extensions and target config."""
|
||||
|
||||
def test_plain_image(self) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||
{Capability.TEXT_TO_IMAGE}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
|
||||
def test_image_extensions(self, name: str) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
caps = infer_required_capabilities(name, target)
|
||||
assert Capability.TEXT_TO_IMAGE in caps
|
||||
|
||||
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
|
||||
def test_case_insensitive(self, name: str) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
caps = infer_required_capabilities(name, target)
|
||||
assert Capability.TEXT_TO_IMAGE in caps
|
||||
|
||||
def test_image_with_reference_images(self) -> None:
|
||||
target = TargetConfig(prompt="x", reference_images=["ref.png"])
|
||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||
{Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
|
||||
)
|
||||
|
||||
def test_image_with_control_images(self) -> None:
|
||||
target = TargetConfig(prompt="x", control_images=["ctrl.png"])
|
||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||
{Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
|
||||
)
|
||||
|
||||
def test_image_with_both(self) -> None:
|
||||
target = TargetConfig(
|
||||
prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
|
||||
)
|
||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||
{
|
||||
Capability.TEXT_TO_IMAGE,
|
||||
Capability.REFERENCE_IMAGES,
|
||||
Capability.CONTROL_IMAGES,
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
|
||||
def test_text_extensions(self, name: str) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
caps = infer_required_capabilities(name, target)
|
||||
assert caps == frozenset({Capability.TEXT_GENERATION})
|
||||
|
||||
def test_text_with_text_inputs(self) -> None:
|
||||
target = TargetConfig(prompt="x", inputs=["data.txt"])
|
||||
assert infer_required_capabilities("out.md", target) == frozenset(
|
||||
{Capability.TEXT_GENERATION}
|
||||
)
|
||||
|
||||
def test_text_with_image_input(self) -> None:
|
||||
target = TargetConfig(prompt="x", inputs=["photo.png"])
|
||||
assert infer_required_capabilities("out.txt", target) == frozenset(
|
||||
{Capability.TEXT_GENERATION, Capability.VISION}
|
||||
)
|
||||
|
||||
def test_text_with_image_reference(self) -> None:
|
||||
target = TargetConfig(prompt="x", reference_images=["ref.jpg"])
|
||||
assert infer_required_capabilities("out.md", target) == frozenset(
|
||||
{Capability.TEXT_GENERATION, Capability.VISION}
|
||||
)
|
||||
|
||||
def test_unsupported_extension_raises(self) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
with pytest.raises(ValueError, match="unsupported extension"):
|
||||
_ = infer_required_capabilities("data.csv", target)
|
||||
|
||||
def test_no_extension_raises(self) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
with pytest.raises(ValueError, match="unsupported extension"):
|
||||
_ = infer_required_capabilities("Makefile", target)
|
||||
|
||||
|
||||
class TestResolveModel:
|
||||
"""Test model resolution with capability validation."""
|
||||
|
||||
def test_explicit_model_wins(self) -> None:
|
||||
target = TargetConfig(prompt="x", model="mistral-small-latest")
|
||||
result = resolve_model("out.txt", target, Defaults())
|
||||
assert result.name == "mistral-small-latest"
|
||||
|
||||
def test_default_text_model(self) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
defaults = Defaults(text_model="mistral-large-latest")
|
||||
result = resolve_model("out.md", target, defaults)
|
||||
assert result.name == "mistral-large-latest"
|
||||
|
||||
def test_default_image_model(self) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
defaults = Defaults(image_model="flux-dev")
|
||||
result = resolve_model("out.png", target, defaults)
|
||||
assert result.name == "flux-dev"
|
||||
|
||||
def test_unknown_model_raises(self) -> None:
|
||||
target = TargetConfig(prompt="x", model="nonexistent-model")
|
||||
with pytest.raises(ValueError, match="Unknown model"):
|
||||
_ = resolve_model("out.txt", target, Defaults())
|
||||
|
||||
def test_explicit_model_missing_capability_raises(self) -> None:
|
||||
# flux-dev does not support reference images
|
||||
target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"])
|
||||
with pytest.raises(ValueError, match="lacks required capabilities"):
|
||||
_ = resolve_model("out.png", target, Defaults())
|
||||
|
||||
def test_default_fallback_for_reference_images(self) -> None:
|
||||
# Default flux-dev lacks reference_images, should fall back to a capable model
|
||||
target = TargetConfig(prompt="x", reference_images=["r.png"])
|
||||
defaults = Defaults(image_model="flux-dev")
|
||||
result = resolve_model("out.png", target, defaults)
|
||||
assert Capability.REFERENCE_IMAGES in result.capabilities
|
||||
|
||||
def test_default_fallback_for_vision(self) -> None:
|
||||
# Default mistral-large-latest lacks vision, should fall back to a pixtral model
|
||||
target = TargetConfig(prompt="x", inputs=["photo.png"])
|
||||
defaults = Defaults(text_model="mistral-large-latest")
|
||||
result = resolve_model("out.txt", target, defaults)
|
||||
assert Capability.VISION in result.capabilities
|
||||
|
||||
def test_default_preferred_when_capable(self) -> None:
|
||||
# Default flux-2-pro already supports reference_images, should be used directly
|
||||
target = TargetConfig(prompt="x", reference_images=["r.png"])
|
||||
defaults = Defaults(image_model="flux-2-pro")
|
||||
result = resolve_model("out.png", target, defaults)
|
||||
assert result.name == "flux-2-pro"
|
||||
|
|
|
|||
|
|
@ -14,17 +14,18 @@ import pytest
|
|||
|
||||
from bulkgen.config import TargetConfig
|
||||
from bulkgen.providers.bfl import BFLResult
|
||||
from bulkgen.providers.image import ImageProvider
|
||||
from bulkgen.providers.image import (
|
||||
from bulkgen.providers.blackforest import BlackForestProvider
|
||||
from bulkgen.providers.blackforest import (
|
||||
_encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
from bulkgen.providers.models import ALL_MODELS, ModelInfo
|
||||
from bulkgen.providers.text import TextProvider
|
||||
from bulkgen.providers.mistral import MistralProvider
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
from bulkgen.providers.registry import get_all_models
|
||||
|
||||
|
||||
def _model(name: str) -> ModelInfo:
|
||||
"""Look up a ModelInfo by name."""
|
||||
for m in ALL_MODELS:
|
||||
for m in get_all_models():
|
||||
if m.name == name:
|
||||
return m
|
||||
msg = f"Unknown test model: {name}"
|
||||
|
|
@ -69,8 +70,8 @@ def _make_text_response(content: str | None) -> MagicMock:
|
|||
return response
|
||||
|
||||
|
||||
class TestImageProvider:
|
||||
"""Test ImageProvider with mocked BFL client and HTTP."""
|
||||
class TestBlackForestProvider:
|
||||
"""Test BlackForestProvider with mocked BFL client and HTTP."""
|
||||
|
||||
@pytest.fixture
|
||||
def image_bytes(self) -> bytes:
|
||||
|
|
@ -83,13 +84,13 @@ class TestImageProvider:
|
|||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||
|
||||
with (
|
||||
patch("bulkgen.providers.image.BFLClient") as mock_cls,
|
||||
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
|
||||
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
|
||||
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
|
||||
):
|
||||
mock_cls.return_value.generate = AsyncMock(return_value=bfl_result)
|
||||
mock_http_cls.return_value = mock_http
|
||||
|
||||
provider = ImageProvider(api_key="test-key")
|
||||
provider = BlackForestProvider(api_key="test-key")
|
||||
await provider.generate(
|
||||
target_name="out.png",
|
||||
target_config=target_config,
|
||||
|
|
@ -109,14 +110,14 @@ class TestImageProvider:
|
|||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||
|
||||
with (
|
||||
patch("bulkgen.providers.image.BFLClient") as mock_cls,
|
||||
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
|
||||
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
|
||||
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
|
||||
):
|
||||
mock_generate = AsyncMock(return_value=bfl_result)
|
||||
mock_cls.return_value.generate = mock_generate
|
||||
mock_http_cls.return_value = mock_http
|
||||
|
||||
provider = ImageProvider(api_key="test-key")
|
||||
provider = BlackForestProvider(api_key="test-key")
|
||||
await provider.generate(
|
||||
target_name="banner.png",
|
||||
target_config=target_config,
|
||||
|
|
@ -140,14 +141,14 @@ class TestImageProvider:
|
|||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||
|
||||
with (
|
||||
patch("bulkgen.providers.image.BFLClient") as mock_cls,
|
||||
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
|
||||
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
|
||||
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
|
||||
):
|
||||
mock_generate = AsyncMock(return_value=bfl_result)
|
||||
mock_cls.return_value.generate = mock_generate
|
||||
mock_http_cls.return_value = mock_http
|
||||
|
||||
provider = ImageProvider(api_key="test-key")
|
||||
provider = BlackForestProvider(api_key="test-key")
|
||||
await provider.generate(
|
||||
target_name="out.png",
|
||||
target_config=target_config,
|
||||
|
|
@ -175,14 +176,14 @@ class TestImageProvider:
|
|||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||
|
||||
with (
|
||||
patch("bulkgen.providers.image.BFLClient") as mock_cls,
|
||||
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
|
||||
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
|
||||
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
|
||||
):
|
||||
mock_generate = AsyncMock(return_value=bfl_result)
|
||||
mock_cls.return_value.generate = mock_generate
|
||||
mock_http_cls.return_value = mock_http
|
||||
|
||||
provider = ImageProvider(api_key="test-key")
|
||||
provider = BlackForestProvider(api_key="test-key")
|
||||
await provider.generate(
|
||||
target_name="out.png",
|
||||
target_config=target_config,
|
||||
|
|
@ -206,14 +207,14 @@ class TestImageProvider:
|
|||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||
|
||||
with (
|
||||
patch("bulkgen.providers.image.BFLClient") as mock_cls,
|
||||
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
|
||||
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
|
||||
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
|
||||
):
|
||||
mock_generate = AsyncMock(return_value=bfl_result)
|
||||
mock_cls.return_value.generate = mock_generate
|
||||
mock_http_cls.return_value = mock_http
|
||||
|
||||
provider = ImageProvider(api_key="test-key")
|
||||
provider = BlackForestProvider(api_key="test-key")
|
||||
await provider.generate(
|
||||
target_name="out.png",
|
||||
target_config=target_config,
|
||||
|
|
@ -229,14 +230,14 @@ class TestImageProvider:
|
|||
async def test_image_no_sample_url_raises(self, project_dir: Path) -> None:
|
||||
target_config = TargetConfig(prompt="x")
|
||||
|
||||
with patch("bulkgen.providers.image.BFLClient") as mock_cls:
|
||||
with patch("bulkgen.providers.blackforest.BFLClient") as mock_cls:
|
||||
from bulkgen.providers.bfl import BFLError
|
||||
|
||||
mock_cls.return_value.generate = AsyncMock(
|
||||
side_effect=BFLError("BFL task test ready but no sample URL: {}")
|
||||
)
|
||||
|
||||
provider = ImageProvider(api_key="test-key")
|
||||
provider = BlackForestProvider(api_key="test-key")
|
||||
with pytest.raises(BFLError, match="no sample URL"):
|
||||
await provider.generate(
|
||||
target_name="fail.png",
|
||||
|
|
@ -255,17 +256,17 @@ class TestImageProvider:
|
|||
assert base64.b64decode(encoded) == data
|
||||
|
||||
|
||||
class TestTextProvider:
|
||||
"""Test TextProvider with mocked Mistral client."""
|
||||
class TestMistralProvider:
|
||||
"""Test MistralProvider with mocked Mistral client."""
|
||||
|
||||
async def test_basic_text_generation(self, project_dir: Path) -> None:
|
||||
target_config = TargetConfig(prompt="Write a poem")
|
||||
response = _make_text_response("Roses are red...")
|
||||
|
||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
||||
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||
mock_cls.return_value = _make_mistral_mock(response)
|
||||
|
||||
provider = TextProvider(api_key="test-key")
|
||||
provider = MistralProvider(api_key="test-key")
|
||||
await provider.generate(
|
||||
target_name="poem.txt",
|
||||
target_config=target_config,
|
||||
|
|
@ -283,11 +284,11 @@ class TestTextProvider:
|
|||
target_config = TargetConfig(prompt="Summarize", inputs=["source.txt"])
|
||||
response = _make_text_response("Summary: ...")
|
||||
|
||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
||||
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||
mock_client = _make_mistral_mock(response)
|
||||
mock_cls.return_value = mock_client
|
||||
|
||||
provider = TextProvider(api_key="test-key")
|
||||
provider = MistralProvider(api_key="test-key")
|
||||
await provider.generate(
|
||||
target_name="summary.md",
|
||||
target_config=target_config,
|
||||
|
|
@ -307,11 +308,11 @@ class TestTextProvider:
|
|||
target_config = TargetConfig(prompt="Describe this image", inputs=["photo.png"])
|
||||
response = _make_text_response("A beautiful photo")
|
||||
|
||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
||||
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||
mock_client = _make_mistral_mock(response)
|
||||
mock_cls.return_value = mock_client
|
||||
|
||||
provider = TextProvider(api_key="test-key")
|
||||
provider = MistralProvider(api_key="test-key")
|
||||
await provider.generate(
|
||||
target_name="desc.txt",
|
||||
target_config=target_config,
|
||||
|
|
@ -332,10 +333,10 @@ class TestTextProvider:
|
|||
response = MagicMock()
|
||||
response.choices = []
|
||||
|
||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
||||
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||
mock_cls.return_value = _make_mistral_mock(response)
|
||||
|
||||
provider = TextProvider(api_key="test-key")
|
||||
provider = MistralProvider(api_key="test-key")
|
||||
with pytest.raises(RuntimeError, match="no choices"):
|
||||
await provider.generate(
|
||||
target_name="fail.txt",
|
||||
|
|
@ -349,10 +350,10 @@ class TestTextProvider:
|
|||
target_config = TargetConfig(prompt="x")
|
||||
response = _make_text_response(None)
|
||||
|
||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
||||
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||
mock_cls.return_value = _make_mistral_mock(response)
|
||||
|
||||
provider = TextProvider(api_key="test-key")
|
||||
provider = MistralProvider(api_key="test-key")
|
||||
with pytest.raises(RuntimeError, match="empty content"):
|
||||
await provider.generate(
|
||||
target_name="fail.txt",
|
||||
|
|
@ -372,11 +373,11 @@ class TestTextProvider:
|
|||
)
|
||||
response = _make_text_response("Combined")
|
||||
|
||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
||||
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||
mock_client = _make_mistral_mock(response)
|
||||
mock_cls.return_value = mock_client
|
||||
|
||||
provider = TextProvider(api_key="test-key")
|
||||
provider = MistralProvider(api_key="test-key")
|
||||
await provider.generate(
|
||||
target_name="out.md",
|
||||
target_config=target_config,
|
||||
|
|
@ -403,11 +404,11 @@ class TestTextProvider:
|
|||
)
|
||||
response = _make_text_response("A stylized image")
|
||||
|
||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
||||
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||
mock_client = _make_mistral_mock(response)
|
||||
mock_cls.return_value = mock_client
|
||||
|
||||
provider = TextProvider(api_key="test-key")
|
||||
provider = MistralProvider(api_key="test-key")
|
||||
await provider.generate(
|
||||
target_name="desc.txt",
|
||||
target_config=target_config,
|
||||
|
|
|
|||
145
tests/test_resolve.py
Normal file
145
tests/test_resolve.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
"""Tests for bulkgen.resolve: capability inference and model resolution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from bulkgen.config import (
|
||||
Defaults,
|
||||
TargetConfig,
|
||||
)
|
||||
from bulkgen.providers.models import Capability
|
||||
from bulkgen.resolve import infer_required_capabilities, resolve_model
|
||||
|
||||
|
||||
class TestInferRequiredCapabilities:
    """Capability inference from the target filename's extension plus target config."""

    def test_plain_image(self) -> None:
        # A bare image target needs only text-to-image.
        cfg = TargetConfig(prompt="x")
        expected = frozenset({Capability.TEXT_TO_IMAGE})
        assert infer_required_capabilities("out.png", cfg) == expected

    @pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
    def test_image_extensions(self, name: str) -> None:
        cfg = TargetConfig(prompt="x")
        assert Capability.TEXT_TO_IMAGE in infer_required_capabilities(name, cfg)

    @pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
    def test_case_insensitive(self, name: str) -> None:
        # Extension matching must ignore case.
        cfg = TargetConfig(prompt="x")
        assert Capability.TEXT_TO_IMAGE in infer_required_capabilities(name, cfg)

    def test_image_with_reference_images(self) -> None:
        cfg = TargetConfig(prompt="x", reference_images=["ref.png"])
        expected = frozenset({Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES})
        assert infer_required_capabilities("out.png", cfg) == expected

    def test_image_with_control_images(self) -> None:
        cfg = TargetConfig(prompt="x", control_images=["ctrl.png"])
        expected = frozenset({Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES})
        assert infer_required_capabilities("out.png", cfg) == expected

    def test_image_with_both(self) -> None:
        # Reference and control images each add their own capability.
        cfg = TargetConfig(
            prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
        )
        expected = frozenset(
            {
                Capability.TEXT_TO_IMAGE,
                Capability.REFERENCE_IMAGES,
                Capability.CONTROL_IMAGES,
            }
        )
        assert infer_required_capabilities("out.png", cfg) == expected

    @pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
    def test_text_extensions(self, name: str) -> None:
        cfg = TargetConfig(prompt="x")
        expected = frozenset({Capability.TEXT_GENERATION})
        assert infer_required_capabilities(name, cfg) == expected

    def test_text_with_text_inputs(self) -> None:
        # Text inputs to a text target add no extra capability.
        cfg = TargetConfig(prompt="x", inputs=["data.txt"])
        expected = frozenset({Capability.TEXT_GENERATION})
        assert infer_required_capabilities("out.md", cfg) == expected

    def test_text_with_image_input(self) -> None:
        # An image input to a text target requires vision.
        cfg = TargetConfig(prompt="x", inputs=["photo.png"])
        expected = frozenset({Capability.TEXT_GENERATION, Capability.VISION})
        assert infer_required_capabilities("out.txt", cfg) == expected

    def test_text_with_image_reference(self) -> None:
        cfg = TargetConfig(prompt="x", reference_images=["ref.jpg"])
        expected = frozenset({Capability.TEXT_GENERATION, Capability.VISION})
        assert infer_required_capabilities("out.md", cfg) == expected

    def test_unsupported_extension_raises(self) -> None:
        cfg = TargetConfig(prompt="x")
        with pytest.raises(ValueError, match="unsupported extension"):
            _ = infer_required_capabilities("data.csv", cfg)

    def test_no_extension_raises(self) -> None:
        # A name with no extension at all is also rejected.
        cfg = TargetConfig(prompt="x")
        with pytest.raises(ValueError, match="unsupported extension"):
            _ = infer_required_capabilities("Makefile", cfg)
|
||||
|
||||
|
||||
class TestResolveModel:
    """Model resolution: explicit choice, defaults, and capability fallback."""

    def test_explicit_model_wins(self) -> None:
        # An explicitly configured model overrides any default.
        cfg = TargetConfig(prompt="x", model="mistral-small-latest")
        resolved = resolve_model("out.txt", cfg, Defaults())
        assert resolved.name == "mistral-small-latest"

    def test_default_text_model(self) -> None:
        cfg = TargetConfig(prompt="x")
        resolved = resolve_model("out.md", cfg, Defaults(text_model="mistral-large-latest"))
        assert resolved.name == "mistral-large-latest"

    def test_default_image_model(self) -> None:
        cfg = TargetConfig(prompt="x")
        resolved = resolve_model("out.png", cfg, Defaults(image_model="flux-dev"))
        assert resolved.name == "flux-dev"

    def test_unknown_model_raises(self) -> None:
        cfg = TargetConfig(prompt="x", model="nonexistent-model")
        with pytest.raises(ValueError, match="Unknown model"):
            _ = resolve_model("out.txt", cfg, Defaults())

    def test_explicit_model_missing_capability_raises(self) -> None:
        # flux-dev does not support reference images, so an explicit pick must fail.
        cfg = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"])
        with pytest.raises(ValueError, match="lacks required capabilities"):
            _ = resolve_model("out.png", cfg, Defaults())

    def test_default_fallback_for_reference_images(self) -> None:
        # Default flux-dev lacks reference_images; resolution falls back to a capable model.
        cfg = TargetConfig(prompt="x", reference_images=["r.png"])
        resolved = resolve_model("out.png", cfg, Defaults(image_model="flux-dev"))
        assert Capability.REFERENCE_IMAGES in resolved.capabilities

    def test_default_fallback_for_vision(self) -> None:
        # Default mistral-large-latest lacks vision; resolution falls back to a pixtral model.
        cfg = TargetConfig(prompt="x", inputs=["photo.png"])
        resolved = resolve_model("out.txt", cfg, Defaults(text_model="mistral-large-latest"))
        assert Capability.VISION in resolved.capabilities

    def test_default_preferred_when_capable(self) -> None:
        # Default flux-2-pro already supports reference_images, so it is used as-is.
        cfg = TargetConfig(prompt="x", reference_images=["r.png"])
        resolved = resolve_model("out.png", cfg, Defaults(image_model="flux-2-pro"))
        assert resolved.name == "flux-2-pro"
|
||||
Loading…
Add table
Add a link
Reference in a new issue