refactor: move model definitions into providers and extract resolve module
Some checks failed
Continuous Integration / Build Package (push) Successful in 30s
Continuous Integration / Lint, Check & Test (push) Failing after 38s

- Rename ImageProvider to BlackForestProvider, TextProvider to MistralProvider
- Add get_provided_models() abstract method to Provider base class
- Move model lists from models.py into each provider's get_provided_models()
- Add providers/registry.py to aggregate models from all providers
- Extract infer_required_capabilities and resolve_model from config.py to resolve.py
- Update tests to use new names and import paths
This commit is contained in:
Konstantin Fickel 2026-02-15 11:03:57 +01:00
parent dc6a75f5c4
commit d0dac5b1bf
Signed by: kfickel
GPG key ID: A793722F9933C1A5
13 changed files with 432 additions and 390 deletions

View file

@ -14,17 +14,18 @@ import pytest
from bulkgen.config import TargetConfig
from bulkgen.providers.bfl import BFLResult
from bulkgen.providers.image import ImageProvider
from bulkgen.providers.image import (
from bulkgen.providers.blackforest import BlackForestProvider
from bulkgen.providers.blackforest import (
_encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage]
)
from bulkgen.providers.models import ALL_MODELS, ModelInfo
from bulkgen.providers.text import TextProvider
from bulkgen.providers.mistral import MistralProvider
from bulkgen.providers.models import ModelInfo
from bulkgen.providers.registry import get_all_models
def _model(name: str) -> ModelInfo:
"""Look up a ModelInfo by name."""
for m in ALL_MODELS:
for m in get_all_models():
if m.name == name:
return m
msg = f"Unknown test model: {name}"
@ -69,8 +70,8 @@ def _make_text_response(content: str | None) -> MagicMock:
return response
class TestImageProvider:
"""Test ImageProvider with mocked BFL client and HTTP."""
class TestBlackForestProvider:
"""Test BlackForestProvider with mocked BFL client and HTTP."""
@pytest.fixture
def image_bytes(self) -> bytes:
@ -83,13 +84,13 @@ class TestImageProvider:
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with (
patch("bulkgen.providers.image.BFLClient") as mock_cls,
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
):
mock_cls.return_value.generate = AsyncMock(return_value=bfl_result)
mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key")
provider = BlackForestProvider(api_key="test-key")
await provider.generate(
target_name="out.png",
target_config=target_config,
@ -109,14 +110,14 @@ class TestImageProvider:
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with (
patch("bulkgen.providers.image.BFLClient") as mock_cls,
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
):
mock_generate = AsyncMock(return_value=bfl_result)
mock_cls.return_value.generate = mock_generate
mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key")
provider = BlackForestProvider(api_key="test-key")
await provider.generate(
target_name="banner.png",
target_config=target_config,
@ -140,14 +141,14 @@ class TestImageProvider:
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with (
patch("bulkgen.providers.image.BFLClient") as mock_cls,
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
):
mock_generate = AsyncMock(return_value=bfl_result)
mock_cls.return_value.generate = mock_generate
mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key")
provider = BlackForestProvider(api_key="test-key")
await provider.generate(
target_name="out.png",
target_config=target_config,
@ -175,14 +176,14 @@ class TestImageProvider:
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with (
patch("bulkgen.providers.image.BFLClient") as mock_cls,
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
):
mock_generate = AsyncMock(return_value=bfl_result)
mock_cls.return_value.generate = mock_generate
mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key")
provider = BlackForestProvider(api_key="test-key")
await provider.generate(
target_name="out.png",
target_config=target_config,
@ -206,14 +207,14 @@ class TestImageProvider:
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with (
patch("bulkgen.providers.image.BFLClient") as mock_cls,
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
):
mock_generate = AsyncMock(return_value=bfl_result)
mock_cls.return_value.generate = mock_generate
mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key")
provider = BlackForestProvider(api_key="test-key")
await provider.generate(
target_name="out.png",
target_config=target_config,
@ -229,14 +230,14 @@ class TestImageProvider:
async def test_image_no_sample_url_raises(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="x")
with patch("bulkgen.providers.image.BFLClient") as mock_cls:
with patch("bulkgen.providers.blackforest.BFLClient") as mock_cls:
from bulkgen.providers.bfl import BFLError
mock_cls.return_value.generate = AsyncMock(
side_effect=BFLError("BFL task test ready but no sample URL: {}")
)
provider = ImageProvider(api_key="test-key")
provider = BlackForestProvider(api_key="test-key")
with pytest.raises(BFLError, match="no sample URL"):
await provider.generate(
target_name="fail.png",
@ -255,17 +256,17 @@ class TestImageProvider:
assert base64.b64decode(encoded) == data
class TestTextProvider:
"""Test TextProvider with mocked Mistral client."""
class TestMistralProvider:
"""Test MistralProvider with mocked Mistral client."""
async def test_basic_text_generation(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="Write a poem")
response = _make_text_response("Roses are red...")
with patch("bulkgen.providers.text.Mistral") as mock_cls:
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_cls.return_value = _make_mistral_mock(response)
provider = TextProvider(api_key="test-key")
provider = MistralProvider(api_key="test-key")
await provider.generate(
target_name="poem.txt",
target_config=target_config,
@ -283,11 +284,11 @@ class TestTextProvider:
target_config = TargetConfig(prompt="Summarize", inputs=["source.txt"])
response = _make_text_response("Summary: ...")
with patch("bulkgen.providers.text.Mistral") as mock_cls:
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_client = _make_mistral_mock(response)
mock_cls.return_value = mock_client
provider = TextProvider(api_key="test-key")
provider = MistralProvider(api_key="test-key")
await provider.generate(
target_name="summary.md",
target_config=target_config,
@ -307,11 +308,11 @@ class TestTextProvider:
target_config = TargetConfig(prompt="Describe this image", inputs=["photo.png"])
response = _make_text_response("A beautiful photo")
with patch("bulkgen.providers.text.Mistral") as mock_cls:
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_client = _make_mistral_mock(response)
mock_cls.return_value = mock_client
provider = TextProvider(api_key="test-key")
provider = MistralProvider(api_key="test-key")
await provider.generate(
target_name="desc.txt",
target_config=target_config,
@ -332,10 +333,10 @@ class TestTextProvider:
response = MagicMock()
response.choices = []
with patch("bulkgen.providers.text.Mistral") as mock_cls:
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_cls.return_value = _make_mistral_mock(response)
provider = TextProvider(api_key="test-key")
provider = MistralProvider(api_key="test-key")
with pytest.raises(RuntimeError, match="no choices"):
await provider.generate(
target_name="fail.txt",
@ -349,10 +350,10 @@ class TestTextProvider:
target_config = TargetConfig(prompt="x")
response = _make_text_response(None)
with patch("bulkgen.providers.text.Mistral") as mock_cls:
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_cls.return_value = _make_mistral_mock(response)
provider = TextProvider(api_key="test-key")
provider = MistralProvider(api_key="test-key")
with pytest.raises(RuntimeError, match="empty content"):
await provider.generate(
target_name="fail.txt",
@ -372,11 +373,11 @@ class TestTextProvider:
)
response = _make_text_response("Combined")
with patch("bulkgen.providers.text.Mistral") as mock_cls:
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_client = _make_mistral_mock(response)
mock_cls.return_value = mock_client
provider = TextProvider(api_key="test-key")
provider = MistralProvider(api_key="test-key")
await provider.generate(
target_name="out.md",
target_config=target_config,
@ -403,11 +404,11 @@ class TestTextProvider:
)
response = _make_text_response("A stylized image")
with patch("bulkgen.providers.text.Mistral") as mock_cls:
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_client = _make_mistral_mock(response)
mock_cls.return_value = mock_client
provider = TextProvider(api_key="test-key")
provider = MistralProvider(api_key="test-key")
await provider.generate(
target_name="desc.txt",
target_config=target_config,