refactor: move model definitions into providers and extract resolve module
Some checks failed
Continuous Integration / Build Package (push) Successful in 30s
Continuous Integration / Lint, Check & Test (push) Failing after 38s

- Rename ImageProvider to BlackForestProvider, TextProvider to MistralProvider
- Add get_provided_models() abstract method to Provider base class
- Move model lists from models.py into each provider's get_provided_models()
- Add providers/registry.py to aggregate models from all providers
- Extract infer_required_capabilities and resolve_model from config.py to resolve.py
- Update tests to use new names and import paths
This commit is contained in:
Konstantin Fickel 2026-02-15 11:03:57 +01:00
parent dc6a75f5c4
commit d0dac5b1bf
Signed by: kfickel
GPG key ID: A793722F9933C1A5
13 changed files with 432 additions and 390 deletions

View file

@ -12,14 +12,13 @@ from pathlib import Path
from bulkgen.config import ( from bulkgen.config import (
ProjectConfig, ProjectConfig,
TargetType, TargetType,
infer_required_capabilities,
resolve_model,
target_type_from_capabilities, target_type_from_capabilities,
) )
from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target
from bulkgen.providers import Provider from bulkgen.providers import Provider
from bulkgen.providers.image import ImageProvider from bulkgen.providers.blackforest import BlackForestProvider
from bulkgen.providers.text import TextProvider from bulkgen.providers.mistral import MistralProvider
from bulkgen.resolve import infer_required_capabilities, resolve_model
from bulkgen.state import ( from bulkgen.state import (
BuildState, BuildState,
is_target_dirty, is_target_dirty,
@ -106,10 +105,10 @@ def _create_providers() -> dict[TargetType, Provider]:
providers: dict[TargetType, Provider] = {} providers: dict[TargetType, Provider] = {}
bfl_key = os.environ.get("BFL_API_KEY", "") bfl_key = os.environ.get("BFL_API_KEY", "")
if bfl_key: if bfl_key:
providers[TargetType.IMAGE] = ImageProvider(api_key=bfl_key) providers[TargetType.IMAGE] = BlackForestProvider(api_key=bfl_key)
mistral_key = os.environ.get("MISTRAL_API_KEY", "") mistral_key = os.environ.get("MISTRAL_API_KEY", "")
if mistral_key: if mistral_key:
providers[TargetType.TEXT] = TextProvider(api_key=mistral_key) providers[TargetType.TEXT] = MistralProvider(api_key=mistral_key)
return providers return providers

View file

@ -13,7 +13,7 @@ import typer
from bulkgen.builder import BuildEvent, BuildResult, run_build from bulkgen.builder import BuildEvent, BuildResult, run_build
from bulkgen.config import ProjectConfig, load_config from bulkgen.config import ProjectConfig, load_config
from bulkgen.graph import build_graph, get_build_order from bulkgen.graph import build_graph, get_build_order
from bulkgen.providers.models import ALL_MODELS from bulkgen.providers.registry import get_all_models
app = typer.Typer(name="bulkgen", help="AI artifact build tool.") app = typer.Typer(name="bulkgen", help="AI artifact build tool.")
@ -182,9 +182,10 @@ def graph() -> None:
@app.command() @app.command()
def models() -> None: def models() -> None:
"""List available models and their capabilities.""" """List available models and their capabilities."""
name_width = max(len(m.name) for m in ALL_MODELS) all_models = get_all_models()
provider_width = max(len(m.provider) for m in ALL_MODELS) name_width = max(len(m.name) for m in all_models)
type_width = max(len(m.type) for m in ALL_MODELS) provider_width = max(len(m.provider) for m in all_models)
type_width = max(len(m.type) for m in all_models)
header_name = "Model".ljust(name_width) header_name = "Model".ljust(name_width)
header_provider = "Provider".ljust(provider_width) header_provider = "Provider".ljust(provider_width)
@ -210,7 +211,7 @@ def models() -> None:
+ "" * len(header_caps) + "" * len(header_caps)
) )
for model in ALL_MODELS: for model in all_models:
name_col = model.name.ljust(name_width) name_col = model.name.ljust(name_width)
provider_col = model.provider.ljust(provider_width) provider_col = model.provider.ljust(provider_width)
type_col = model.type.ljust(type_width) type_col = model.type.ljust(type_width)

View file

@ -4,13 +4,15 @@ from __future__ import annotations
import enum import enum
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Self from typing import Self
import yaml import yaml
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
if TYPE_CHECKING: from bulkgen.providers.models import Capability
from bulkgen.providers.models import Capability, ModelInfo
IMAGE_EXTENSIONS: frozenset[str] = frozenset({".png", ".jpg", ".jpeg", ".webp"})
TEXT_EXTENSIONS: frozenset[str] = frozenset({".md", ".txt"})
class TargetType(enum.Enum): class TargetType(enum.Enum):
@ -20,10 +22,6 @@ class TargetType(enum.Enum):
TEXT = "text" TEXT = "text"
IMAGE_EXTENSIONS: frozenset[str] = frozenset({".png", ".jpg", ".jpeg", ".webp"})
TEXT_EXTENSIONS: frozenset[str] = frozenset({".md", ".txt"})
class Defaults(BaseModel): class Defaults(BaseModel):
"""Default model names, applied when a target does not specify its own.""" """Default model names, applied when a target does not specify its own."""
@ -57,36 +55,6 @@ class ProjectConfig(BaseModel):
return self return self
def infer_required_capabilities(
target_name: str, target: TargetConfig
) -> frozenset[Capability]:
"""Infer the capabilities a model must have based on filename and config.
Raises :class:`ValueError` for unsupported file extensions.
"""
from bulkgen.providers.models import Capability
suffix = Path(target_name).suffix.lower()
caps: set[Capability] = set()
if suffix in IMAGE_EXTENSIONS:
caps.add(Capability.TEXT_TO_IMAGE)
if target.reference_images:
caps.add(Capability.REFERENCE_IMAGES)
if target.control_images:
caps.add(Capability.CONTROL_IMAGES)
elif suffix in TEXT_EXTENSIONS:
caps.add(Capability.TEXT_GENERATION)
all_input_names = list(target.inputs) + list(target.reference_images)
if any(Path(n).suffix.lower() in IMAGE_EXTENSIONS for n in all_input_names):
caps.add(Capability.VISION)
else:
msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
raise ValueError(msg)
return frozenset(caps)
def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType: def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType:
"""Derive the target type from a set of required capabilities.""" """Derive the target type from a set of required capabilities."""
from bulkgen.providers.models import Capability from bulkgen.providers.models import Capability
@ -96,59 +64,6 @@ def target_type_from_capabilities(capabilities: frozenset[Capability]) -> Target
return TargetType.TEXT return TargetType.TEXT
def resolve_model(
target_name: str, target: TargetConfig, defaults: Defaults
) -> ModelInfo:
"""Return the effective model for a target, validated against required capabilities.
If the target specifies an explicit model, it is validated to have all
required capabilities. Otherwise the type-appropriate default is tried
first; if it lacks a required capability the first capable model of the
same type is selected.
Raises :class:`ValueError` if no suitable model can be found.
"""
from bulkgen.providers.models import ALL_MODELS
required = infer_required_capabilities(target_name, target)
target_type = target_type_from_capabilities(required)
if target.model is not None:
# Explicit model — look up and validate.
for model in ALL_MODELS:
if model.name == target.model:
missing = required - frozenset(model.capabilities)
if missing:
names = ", ".join(sorted(missing))
msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}"
raise ValueError(msg)
return model
msg = f"Unknown model '{target.model}' for target '{target_name}'"
raise ValueError(msg)
# No explicit model — try the default first, then fall back.
default_name = (
defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model
)
for model in ALL_MODELS:
if model.name == default_name:
if required <= frozenset(model.capabilities):
return model
break
# Default lacks capabilities — find the first capable model of the same type.
model_type = "image" if target_type is TargetType.IMAGE else "text"
for model in ALL_MODELS:
if model.type == model_type and required <= frozenset(model.capabilities):
return model
names = ", ".join(sorted(required))
msg = (
f"No model found for target '{target_name}' with required capabilities: {names}"
)
raise ValueError(msg)
def load_config(config_path: Path) -> ProjectConfig: def load_config(config_path: Path) -> ProjectConfig:
"""Load and validate a ``.bulkgen.yaml`` file.""" """Load and validate a ``.bulkgen.yaml`` file."""
with config_path.open() as f: with config_path.open() as f:

View file

@ -12,6 +12,11 @@ from bulkgen.providers.models import ModelInfo
class Provider(abc.ABC): class Provider(abc.ABC):
"""Abstract base for generation providers.""" """Abstract base for generation providers."""
@staticmethod
@abc.abstractmethod
def get_provided_models() -> list[ModelInfo]:
"""Return the models this provider supports."""
@abc.abstractmethod @abc.abstractmethod
async def generate( async def generate(
self, self,

View file

@ -11,7 +11,7 @@ import httpx
from bulkgen.config import TargetConfig from bulkgen.config import TargetConfig
from bulkgen.providers import Provider from bulkgen.providers import Provider
from bulkgen.providers.bfl import BFLClient from bulkgen.providers.bfl import BFLClient
from bulkgen.providers.models import ModelInfo from bulkgen.providers.models import Capability, ModelInfo
def _encode_image_b64(path: Path) -> str: def _encode_image_b64(path: Path) -> str:
@ -45,7 +45,7 @@ def _add_reference_images(
inputs[key] = _encode_image_b64(project_dir / ref_name) inputs[key] = _encode_image_b64(project_dir / ref_name)
class ImageProvider(Provider): class BlackForestProvider(Provider):
"""Generates images via the BlackForestLabs API.""" """Generates images via the BlackForestLabs API."""
_client: BFLClient _client: BFLClient
@ -53,6 +53,72 @@ class ImageProvider(Provider):
def __init__(self, api_key: str) -> None: def __init__(self, api_key: str) -> None:
self._client = BFLClient(api_key=api_key) self._client = BFLClient(api_key=api_key)
@staticmethod
@override
def get_provided_models() -> list[ModelInfo]:
return [
ModelInfo(
name="flux-dev",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro-1.1",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro-1.1-ultra",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-2-pro",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
),
ModelInfo(
name="flux-kontext-pro",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
),
ModelInfo(
name="flux-pro-1.0-canny",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES],
),
ModelInfo(
name="flux-pro-1.0-depth",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES],
),
ModelInfo(
name="flux-pro-1.0-fill",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro-1.0-expand",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
]
@override @override
async def generate( async def generate(
self, self,

View file

@ -11,7 +11,7 @@ from mistralai import Mistral, models
from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig
from bulkgen.providers import Provider from bulkgen.providers import Provider
from bulkgen.providers.models import ModelInfo from bulkgen.providers.models import Capability, ModelInfo
def _image_to_data_url(path: Path) -> str: def _image_to_data_url(path: Path) -> str:
@ -21,7 +21,7 @@ def _image_to_data_url(path: Path) -> str:
return f"data:{mime};base64,{b64}" return f"data:{mime};base64,{b64}"
class TextProvider(Provider): class MistralProvider(Provider):
"""Generates text via the Mistral API.""" """Generates text via the Mistral API."""
_api_key: str _api_key: str
@ -29,6 +29,36 @@ class TextProvider(Provider):
def __init__(self, api_key: str) -> None: def __init__(self, api_key: str) -> None:
self._api_key = api_key self._api_key = api_key
@staticmethod
@override
def get_provided_models() -> list[ModelInfo]:
return [
ModelInfo(
name="mistral-large-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION],
),
ModelInfo(
name="mistral-small-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION],
),
ModelInfo(
name="pixtral-large-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
),
ModelInfo(
name="pixtral-12b-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
),
]
@override @override
async def generate( async def generate(
self, self,
@ -98,6 +128,8 @@ def _build_multimodal_message(
models.TextChunk(text=prompt), models.TextChunk(text=prompt),
] ]
from bulkgen.config import IMAGE_EXTENSIONS
for name in input_names: for name in input_names:
input_path = project_dir / name input_path = project_dir / name
suffix = input_path.suffix.lower() suffix = input_path.suffix.lower()

View file

@ -1,4 +1,4 @@
"""Registry of supported models and their capabilities.""" """Model types and capability definitions for AI providers."""
from __future__ import annotations from __future__ import annotations
@ -25,106 +25,3 @@ class ModelInfo:
provider: str provider: str
type: Literal["text", "image"] type: Literal["text", "image"]
capabilities: list[Capability] capabilities: list[Capability]
TEXT_MODELS: list[ModelInfo] = [
ModelInfo(
name="mistral-large-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION],
),
ModelInfo(
name="mistral-small-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION],
),
ModelInfo(
name="pixtral-large-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
),
ModelInfo(
name="pixtral-12b-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
),
]
IMAGE_MODELS: list[ModelInfo] = [
ModelInfo(
name="flux-dev",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro-1.1",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro-1.1-ultra",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-2-pro",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
),
ModelInfo(
name="flux-kontext-pro",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
),
ModelInfo(
name="flux-pro-1.0-canny",
provider="BlackForestLabs",
type="image",
capabilities=[
Capability.TEXT_TO_IMAGE,
Capability.CONTROL_IMAGES,
],
),
ModelInfo(
name="flux-pro-1.0-depth",
provider="BlackForestLabs",
type="image",
capabilities=[
Capability.TEXT_TO_IMAGE,
Capability.CONTROL_IMAGES,
],
),
ModelInfo(
name="flux-pro-1.0-fill",
provider="BlackForestLabs",
type="image",
capabilities=[
Capability.TEXT_TO_IMAGE,
],
),
ModelInfo(
name="flux-pro-1.0-expand",
provider="BlackForestLabs",
type="image",
capabilities=[
Capability.TEXT_TO_IMAGE,
],
),
]
ALL_MODELS: list[ModelInfo] = TEXT_MODELS + IMAGE_MODELS

View file

@ -0,0 +1,16 @@
"""Aggregates models from all registered providers."""
from __future__ import annotations
from bulkgen.providers.models import ModelInfo
def get_all_models() -> list[ModelInfo]:
"""Return the merged list of models from all providers."""
from bulkgen.providers.blackforest import BlackForestProvider
from bulkgen.providers.mistral import MistralProvider
return (
MistralProvider.get_provided_models()
+ BlackForestProvider.get_provided_models()
)

95
bulkgen/resolve.py Normal file
View file

@ -0,0 +1,95 @@
from __future__ import annotations
from pathlib import Path
from bulkgen.config import (
IMAGE_EXTENSIONS,
TEXT_EXTENSIONS,
Defaults,
TargetConfig,
TargetType,
target_type_from_capabilities,
)
from bulkgen.providers.models import Capability, ModelInfo
def infer_required_capabilities(
target_name: str, target: TargetConfig
) -> frozenset[Capability]:
"""Infer the capabilities a model must have based on filename and config.
Raises :class:`ValueError` for unsupported file extensions.
"""
suffix = Path(target_name).suffix.lower()
caps: set[Capability] = set()
if suffix in IMAGE_EXTENSIONS:
caps.add(Capability.TEXT_TO_IMAGE)
if target.reference_images:
caps.add(Capability.REFERENCE_IMAGES)
if target.control_images:
caps.add(Capability.CONTROL_IMAGES)
elif suffix in TEXT_EXTENSIONS:
caps.add(Capability.TEXT_GENERATION)
all_input_names = list(target.inputs) + list(target.reference_images)
if any(Path(n).suffix.lower() in IMAGE_EXTENSIONS for n in all_input_names):
caps.add(Capability.VISION)
else:
msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
raise ValueError(msg)
return frozenset(caps)
def resolve_model(
target_name: str, target: TargetConfig, defaults: Defaults
) -> ModelInfo:
"""Return the effective model for a target, validated against required capabilities.
If the target specifies an explicit model, it is validated to have all
required capabilities. Otherwise the type-appropriate default is tried
first; if it lacks a required capability the first capable model of the
same type is selected.
Raises :class:`ValueError` if no suitable model can be found.
"""
from bulkgen.providers.registry import get_all_models
all_models = get_all_models()
required = infer_required_capabilities(target_name, target)
target_type = target_type_from_capabilities(required)
if target.model is not None:
# Explicit model — look up and validate.
for model in all_models:
if model.name == target.model:
missing = required - frozenset(model.capabilities)
if missing:
names = ", ".join(sorted(missing))
msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}"
raise ValueError(msg)
return model
msg = f"Unknown model '{target.model}' for target '{target_name}'"
raise ValueError(msg)
# No explicit model — try the default first, then fall back.
default_name = (
defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model
)
for model in all_models:
if model.name == default_name:
if required <= frozenset(model.capabilities):
return model
break
# Default lacks capabilities — find the first capable model of the same type.
model_type = "image" if target_type is TargetType.IMAGE else "text"
for model in all_models:
if model.type == model_type and required <= frozenset(model.capabilities):
return model
names = ", ".join(sorted(required))
msg = (
f"No model found for target '{target_name}' with required capabilities: {names}"
)
raise ValueError(msg)

View file

@ -27,6 +27,11 @@ WriteConfig = Callable[[dict[str, object]], ProjectConfig]
class FakeProvider(Provider): class FakeProvider(Provider):
"""A provider that writes a marker file instead of calling an API.""" """A provider that writes a marker file instead of calling an API."""
@staticmethod
@override
def get_provided_models() -> list[ModelInfo]:
return []
@override @override
async def generate( async def generate(
self, self,
@ -43,6 +48,11 @@ class FakeProvider(Provider):
class FailingProvider(Provider): class FailingProvider(Provider):
"""A provider that always raises.""" """A provider that always raises."""
@staticmethod
@override
def get_provided_models() -> list[ModelInfo]:
return []
@override @override
async def generate( async def generate(
self, self,

View file

@ -7,14 +7,7 @@ from pathlib import Path
import pytest import pytest
import yaml import yaml
from bulkgen.config import ( from bulkgen.config import load_config
Defaults,
TargetConfig,
infer_required_capabilities,
load_config,
resolve_model,
)
from bulkgen.providers.models import Capability
class TestLoadConfig: class TestLoadConfig:
@ -86,136 +79,3 @@ class TestLoadConfig:
with pytest.raises(Exception): with pytest.raises(Exception):
_ = load_config(config_path) _ = load_config(config_path)
class TestInferRequiredCapabilities:
"""Test capability inference from file extensions and target config."""
def test_plain_image(self) -> None:
target = TargetConfig(prompt="x")
assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE}
)
@pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
def test_image_extensions(self, name: str) -> None:
target = TargetConfig(prompt="x")
caps = infer_required_capabilities(name, target)
assert Capability.TEXT_TO_IMAGE in caps
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
def test_case_insensitive(self, name: str) -> None:
target = TargetConfig(prompt="x")
caps = infer_required_capabilities(name, target)
assert Capability.TEXT_TO_IMAGE in caps
def test_image_with_reference_images(self) -> None:
target = TargetConfig(prompt="x", reference_images=["ref.png"])
assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
)
def test_image_with_control_images(self) -> None:
target = TargetConfig(prompt="x", control_images=["ctrl.png"])
assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
)
def test_image_with_both(self) -> None:
target = TargetConfig(
prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
)
assert infer_required_capabilities("out.png", target) == frozenset(
{
Capability.TEXT_TO_IMAGE,
Capability.REFERENCE_IMAGES,
Capability.CONTROL_IMAGES,
}
)
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
def test_text_extensions(self, name: str) -> None:
target = TargetConfig(prompt="x")
caps = infer_required_capabilities(name, target)
assert caps == frozenset({Capability.TEXT_GENERATION})
def test_text_with_text_inputs(self) -> None:
target = TargetConfig(prompt="x", inputs=["data.txt"])
assert infer_required_capabilities("out.md", target) == frozenset(
{Capability.TEXT_GENERATION}
)
def test_text_with_image_input(self) -> None:
target = TargetConfig(prompt="x", inputs=["photo.png"])
assert infer_required_capabilities("out.txt", target) == frozenset(
{Capability.TEXT_GENERATION, Capability.VISION}
)
def test_text_with_image_reference(self) -> None:
target = TargetConfig(prompt="x", reference_images=["ref.jpg"])
assert infer_required_capabilities("out.md", target) == frozenset(
{Capability.TEXT_GENERATION, Capability.VISION}
)
def test_unsupported_extension_raises(self) -> None:
target = TargetConfig(prompt="x")
with pytest.raises(ValueError, match="unsupported extension"):
_ = infer_required_capabilities("data.csv", target)
def test_no_extension_raises(self) -> None:
target = TargetConfig(prompt="x")
with pytest.raises(ValueError, match="unsupported extension"):
_ = infer_required_capabilities("Makefile", target)
class TestResolveModel:
"""Test model resolution with capability validation."""
def test_explicit_model_wins(self) -> None:
target = TargetConfig(prompt="x", model="mistral-small-latest")
result = resolve_model("out.txt", target, Defaults())
assert result.name == "mistral-small-latest"
def test_default_text_model(self) -> None:
target = TargetConfig(prompt="x")
defaults = Defaults(text_model="mistral-large-latest")
result = resolve_model("out.md", target, defaults)
assert result.name == "mistral-large-latest"
def test_default_image_model(self) -> None:
target = TargetConfig(prompt="x")
defaults = Defaults(image_model="flux-dev")
result = resolve_model("out.png", target, defaults)
assert result.name == "flux-dev"
def test_unknown_model_raises(self) -> None:
target = TargetConfig(prompt="x", model="nonexistent-model")
with pytest.raises(ValueError, match="Unknown model"):
_ = resolve_model("out.txt", target, Defaults())
def test_explicit_model_missing_capability_raises(self) -> None:
# flux-dev does not support reference images
target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"])
with pytest.raises(ValueError, match="lacks required capabilities"):
_ = resolve_model("out.png", target, Defaults())
def test_default_fallback_for_reference_images(self) -> None:
# Default flux-dev lacks reference_images, should fall back to a capable model
target = TargetConfig(prompt="x", reference_images=["r.png"])
defaults = Defaults(image_model="flux-dev")
result = resolve_model("out.png", target, defaults)
assert Capability.REFERENCE_IMAGES in result.capabilities
def test_default_fallback_for_vision(self) -> None:
# Default mistral-large-latest lacks vision, should fall back to a pixtral model
target = TargetConfig(prompt="x", inputs=["photo.png"])
defaults = Defaults(text_model="mistral-large-latest")
result = resolve_model("out.txt", target, defaults)
assert Capability.VISION in result.capabilities
def test_default_preferred_when_capable(self) -> None:
# Default flux-2-pro already supports reference_images, should be used directly
target = TargetConfig(prompt="x", reference_images=["r.png"])
defaults = Defaults(image_model="flux-2-pro")
result = resolve_model("out.png", target, defaults)
assert result.name == "flux-2-pro"

View file

@ -14,17 +14,18 @@ import pytest
from bulkgen.config import TargetConfig from bulkgen.config import TargetConfig
from bulkgen.providers.bfl import BFLResult from bulkgen.providers.bfl import BFLResult
from bulkgen.providers.image import ImageProvider from bulkgen.providers.blackforest import BlackForestProvider
from bulkgen.providers.image import ( from bulkgen.providers.blackforest import (
_encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage] _encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage]
) )
from bulkgen.providers.models import ALL_MODELS, ModelInfo from bulkgen.providers.mistral import MistralProvider
from bulkgen.providers.text import TextProvider from bulkgen.providers.models import ModelInfo
from bulkgen.providers.registry import get_all_models
def _model(name: str) -> ModelInfo: def _model(name: str) -> ModelInfo:
"""Look up a ModelInfo by name.""" """Look up a ModelInfo by name."""
for m in ALL_MODELS: for m in get_all_models():
if m.name == name: if m.name == name:
return m return m
msg = f"Unknown test model: {name}" msg = f"Unknown test model: {name}"
@ -69,8 +70,8 @@ def _make_text_response(content: str | None) -> MagicMock:
return response return response
class TestImageProvider: class TestBlackForestProvider:
"""Test ImageProvider with mocked BFL client and HTTP.""" """Test BlackForestProvider with mocked BFL client and HTTP."""
@pytest.fixture @pytest.fixture
def image_bytes(self) -> bytes: def image_bytes(self) -> bytes:
@ -83,13 +84,13 @@ class TestImageProvider:
bfl_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with ( with (
patch("bulkgen.providers.image.BFLClient") as mock_cls, patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
): ):
mock_cls.return_value.generate = AsyncMock(return_value=bfl_result) mock_cls.return_value.generate = AsyncMock(return_value=bfl_result)
mock_http_cls.return_value = mock_http mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key") provider = BlackForestProvider(api_key="test-key")
await provider.generate( await provider.generate(
target_name="out.png", target_name="out.png",
target_config=target_config, target_config=target_config,
@ -109,14 +110,14 @@ class TestImageProvider:
bfl_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with ( with (
patch("bulkgen.providers.image.BFLClient") as mock_cls, patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
): ):
mock_generate = AsyncMock(return_value=bfl_result) mock_generate = AsyncMock(return_value=bfl_result)
mock_cls.return_value.generate = mock_generate mock_cls.return_value.generate = mock_generate
mock_http_cls.return_value = mock_http mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key") provider = BlackForestProvider(api_key="test-key")
await provider.generate( await provider.generate(
target_name="banner.png", target_name="banner.png",
target_config=target_config, target_config=target_config,
@ -140,14 +141,14 @@ class TestImageProvider:
bfl_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with ( with (
patch("bulkgen.providers.image.BFLClient") as mock_cls, patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
): ):
mock_generate = AsyncMock(return_value=bfl_result) mock_generate = AsyncMock(return_value=bfl_result)
mock_cls.return_value.generate = mock_generate mock_cls.return_value.generate = mock_generate
mock_http_cls.return_value = mock_http mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key") provider = BlackForestProvider(api_key="test-key")
await provider.generate( await provider.generate(
target_name="out.png", target_name="out.png",
target_config=target_config, target_config=target_config,
@ -175,14 +176,14 @@ class TestImageProvider:
bfl_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with ( with (
patch("bulkgen.providers.image.BFLClient") as mock_cls, patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
): ):
mock_generate = AsyncMock(return_value=bfl_result) mock_generate = AsyncMock(return_value=bfl_result)
mock_cls.return_value.generate = mock_generate mock_cls.return_value.generate = mock_generate
mock_http_cls.return_value = mock_http mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key") provider = BlackForestProvider(api_key="test-key")
await provider.generate( await provider.generate(
target_name="out.png", target_name="out.png",
target_config=target_config, target_config=target_config,
@ -206,14 +207,14 @@ class TestImageProvider:
bfl_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with ( with (
patch("bulkgen.providers.image.BFLClient") as mock_cls, patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
): ):
mock_generate = AsyncMock(return_value=bfl_result) mock_generate = AsyncMock(return_value=bfl_result)
mock_cls.return_value.generate = mock_generate mock_cls.return_value.generate = mock_generate
mock_http_cls.return_value = mock_http mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key") provider = BlackForestProvider(api_key="test-key")
await provider.generate( await provider.generate(
target_name="out.png", target_name="out.png",
target_config=target_config, target_config=target_config,
@ -229,14 +230,14 @@ class TestImageProvider:
async def test_image_no_sample_url_raises(self, project_dir: Path) -> None: async def test_image_no_sample_url_raises(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="x") target_config = TargetConfig(prompt="x")
with patch("bulkgen.providers.image.BFLClient") as mock_cls: with patch("bulkgen.providers.blackforest.BFLClient") as mock_cls:
from bulkgen.providers.bfl import BFLError from bulkgen.providers.bfl import BFLError
mock_cls.return_value.generate = AsyncMock( mock_cls.return_value.generate = AsyncMock(
side_effect=BFLError("BFL task test ready but no sample URL: {}") side_effect=BFLError("BFL task test ready but no sample URL: {}")
) )
provider = ImageProvider(api_key="test-key") provider = BlackForestProvider(api_key="test-key")
with pytest.raises(BFLError, match="no sample URL"): with pytest.raises(BFLError, match="no sample URL"):
await provider.generate( await provider.generate(
target_name="fail.png", target_name="fail.png",
@ -255,17 +256,17 @@ class TestImageProvider:
assert base64.b64decode(encoded) == data assert base64.b64decode(encoded) == data
class TestTextProvider: class TestMistralProvider:
"""Test TextProvider with mocked Mistral client.""" """Test MistralProvider with mocked Mistral client."""
async def test_basic_text_generation(self, project_dir: Path) -> None: async def test_basic_text_generation(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="Write a poem") target_config = TargetConfig(prompt="Write a poem")
response = _make_text_response("Roses are red...") response = _make_text_response("Roses are red...")
with patch("bulkgen.providers.text.Mistral") as mock_cls: with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_cls.return_value = _make_mistral_mock(response) mock_cls.return_value = _make_mistral_mock(response)
provider = TextProvider(api_key="test-key") provider = MistralProvider(api_key="test-key")
await provider.generate( await provider.generate(
target_name="poem.txt", target_name="poem.txt",
target_config=target_config, target_config=target_config,
@ -283,11 +284,11 @@ class TestTextProvider:
target_config = TargetConfig(prompt="Summarize", inputs=["source.txt"]) target_config = TargetConfig(prompt="Summarize", inputs=["source.txt"])
response = _make_text_response("Summary: ...") response = _make_text_response("Summary: ...")
with patch("bulkgen.providers.text.Mistral") as mock_cls: with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_client = _make_mistral_mock(response) mock_client = _make_mistral_mock(response)
mock_cls.return_value = mock_client mock_cls.return_value = mock_client
provider = TextProvider(api_key="test-key") provider = MistralProvider(api_key="test-key")
await provider.generate( await provider.generate(
target_name="summary.md", target_name="summary.md",
target_config=target_config, target_config=target_config,
@ -307,11 +308,11 @@ class TestTextProvider:
target_config = TargetConfig(prompt="Describe this image", inputs=["photo.png"]) target_config = TargetConfig(prompt="Describe this image", inputs=["photo.png"])
response = _make_text_response("A beautiful photo") response = _make_text_response("A beautiful photo")
with patch("bulkgen.providers.text.Mistral") as mock_cls: with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_client = _make_mistral_mock(response) mock_client = _make_mistral_mock(response)
mock_cls.return_value = mock_client mock_cls.return_value = mock_client
provider = TextProvider(api_key="test-key") provider = MistralProvider(api_key="test-key")
await provider.generate( await provider.generate(
target_name="desc.txt", target_name="desc.txt",
target_config=target_config, target_config=target_config,
@ -332,10 +333,10 @@ class TestTextProvider:
response = MagicMock() response = MagicMock()
response.choices = [] response.choices = []
with patch("bulkgen.providers.text.Mistral") as mock_cls: with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_cls.return_value = _make_mistral_mock(response) mock_cls.return_value = _make_mistral_mock(response)
provider = TextProvider(api_key="test-key") provider = MistralProvider(api_key="test-key")
with pytest.raises(RuntimeError, match="no choices"): with pytest.raises(RuntimeError, match="no choices"):
await provider.generate( await provider.generate(
target_name="fail.txt", target_name="fail.txt",
@ -349,10 +350,10 @@ class TestTextProvider:
target_config = TargetConfig(prompt="x") target_config = TargetConfig(prompt="x")
response = _make_text_response(None) response = _make_text_response(None)
with patch("bulkgen.providers.text.Mistral") as mock_cls: with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_cls.return_value = _make_mistral_mock(response) mock_cls.return_value = _make_mistral_mock(response)
provider = TextProvider(api_key="test-key") provider = MistralProvider(api_key="test-key")
with pytest.raises(RuntimeError, match="empty content"): with pytest.raises(RuntimeError, match="empty content"):
await provider.generate( await provider.generate(
target_name="fail.txt", target_name="fail.txt",
@ -372,11 +373,11 @@ class TestTextProvider:
) )
response = _make_text_response("Combined") response = _make_text_response("Combined")
with patch("bulkgen.providers.text.Mistral") as mock_cls: with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_client = _make_mistral_mock(response) mock_client = _make_mistral_mock(response)
mock_cls.return_value = mock_client mock_cls.return_value = mock_client
provider = TextProvider(api_key="test-key") provider = MistralProvider(api_key="test-key")
await provider.generate( await provider.generate(
target_name="out.md", target_name="out.md",
target_config=target_config, target_config=target_config,
@ -403,11 +404,11 @@ class TestTextProvider:
) )
response = _make_text_response("A stylized image") response = _make_text_response("A stylized image")
with patch("bulkgen.providers.text.Mistral") as mock_cls: with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
mock_client = _make_mistral_mock(response) mock_client = _make_mistral_mock(response)
mock_cls.return_value = mock_client mock_cls.return_value = mock_client
provider = TextProvider(api_key="test-key") provider = MistralProvider(api_key="test-key")
await provider.generate( await provider.generate(
target_name="desc.txt", target_name="desc.txt",
target_config=target_config, target_config=target_config,

145
tests/test_resolve.py Normal file
View file

@ -0,0 +1,145 @@
"""Integration tests for bulkgen.config."""
from __future__ import annotations
import pytest
from bulkgen.config import (
Defaults,
TargetConfig,
)
from bulkgen.providers.models import Capability
from bulkgen.resolve import infer_required_capabilities, resolve_model
class TestInferRequiredCapabilities:
"""Test capability inference from file extensions and target config."""
def test_plain_image(self) -> None:
target = TargetConfig(prompt="x")
assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE}
)
@pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
def test_image_extensions(self, name: str) -> None:
target = TargetConfig(prompt="x")
caps = infer_required_capabilities(name, target)
assert Capability.TEXT_TO_IMAGE in caps
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
def test_case_insensitive(self, name: str) -> None:
target = TargetConfig(prompt="x")
caps = infer_required_capabilities(name, target)
assert Capability.TEXT_TO_IMAGE in caps
def test_image_with_reference_images(self) -> None:
target = TargetConfig(prompt="x", reference_images=["ref.png"])
assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
)
def test_image_with_control_images(self) -> None:
target = TargetConfig(prompt="x", control_images=["ctrl.png"])
assert infer_required_capabilities("out.png", target) == frozenset(
{Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
)
def test_image_with_both(self) -> None:
target = TargetConfig(
prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
)
assert infer_required_capabilities("out.png", target) == frozenset(
{
Capability.TEXT_TO_IMAGE,
Capability.REFERENCE_IMAGES,
Capability.CONTROL_IMAGES,
}
)
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
def test_text_extensions(self, name: str) -> None:
target = TargetConfig(prompt="x")
caps = infer_required_capabilities(name, target)
assert caps == frozenset({Capability.TEXT_GENERATION})
def test_text_with_text_inputs(self) -> None:
target = TargetConfig(prompt="x", inputs=["data.txt"])
assert infer_required_capabilities("out.md", target) == frozenset(
{Capability.TEXT_GENERATION}
)
def test_text_with_image_input(self) -> None:
target = TargetConfig(prompt="x", inputs=["photo.png"])
assert infer_required_capabilities("out.txt", target) == frozenset(
{Capability.TEXT_GENERATION, Capability.VISION}
)
def test_text_with_image_reference(self) -> None:
target = TargetConfig(prompt="x", reference_images=["ref.jpg"])
assert infer_required_capabilities("out.md", target) == frozenset(
{Capability.TEXT_GENERATION, Capability.VISION}
)
def test_unsupported_extension_raises(self) -> None:
target = TargetConfig(prompt="x")
with pytest.raises(ValueError, match="unsupported extension"):
_ = infer_required_capabilities("data.csv", target)
def test_no_extension_raises(self) -> None:
target = TargetConfig(prompt="x")
with pytest.raises(ValueError, match="unsupported extension"):
_ = infer_required_capabilities("Makefile", target)
class TestResolveModel:
"""Test model resolution with capability validation."""
def test_explicit_model_wins(self) -> None:
target = TargetConfig(prompt="x", model="mistral-small-latest")
result = resolve_model("out.txt", target, Defaults())
assert result.name == "mistral-small-latest"
def test_default_text_model(self) -> None:
target = TargetConfig(prompt="x")
defaults = Defaults(text_model="mistral-large-latest")
result = resolve_model("out.md", target, defaults)
assert result.name == "mistral-large-latest"
def test_default_image_model(self) -> None:
target = TargetConfig(prompt="x")
defaults = Defaults(image_model="flux-dev")
result = resolve_model("out.png", target, defaults)
assert result.name == "flux-dev"
def test_unknown_model_raises(self) -> None:
target = TargetConfig(prompt="x", model="nonexistent-model")
with pytest.raises(ValueError, match="Unknown model"):
_ = resolve_model("out.txt", target, Defaults())
def test_explicit_model_missing_capability_raises(self) -> None:
# flux-dev does not support reference images
target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"])
with pytest.raises(ValueError, match="lacks required capabilities"):
_ = resolve_model("out.png", target, Defaults())
def test_default_fallback_for_reference_images(self) -> None:
# Default flux-dev lacks reference_images, should fall back to a capable model
target = TargetConfig(prompt="x", reference_images=["r.png"])
defaults = Defaults(image_model="flux-dev")
result = resolve_model("out.png", target, defaults)
assert Capability.REFERENCE_IMAGES in result.capabilities
def test_default_fallback_for_vision(self) -> None:
# Default mistral-large-latest lacks vision, should fall back to a pixtral model
target = TargetConfig(prompt="x", inputs=["photo.png"])
defaults = Defaults(text_model="mistral-large-latest")
result = resolve_model("out.txt", target, defaults)
assert Capability.VISION in result.capabilities
def test_default_preferred_when_capable(self) -> None:
# Default flux-2-pro already supports reference_images, should be used directly
target = TargetConfig(prompt="x", reference_images=["r.png"])
defaults = Defaults(image_model="flux-2-pro")
result = resolve_model("out.png", target, defaults)
assert result.name == "flux-2-pro"