diff --git a/bulkgen/builder.py b/bulkgen/builder.py index 9439353..0d108e4 100644 --- a/bulkgen/builder.py +++ b/bulkgen/builder.py @@ -12,14 +12,13 @@ from pathlib import Path from bulkgen.config import ( ProjectConfig, TargetType, - infer_required_capabilities, - resolve_model, target_type_from_capabilities, ) from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target from bulkgen.providers import Provider -from bulkgen.providers.image import ImageProvider -from bulkgen.providers.text import TextProvider +from bulkgen.providers.blackforest import BlackForestProvider +from bulkgen.providers.mistral import MistralProvider +from bulkgen.resolve import infer_required_capabilities, resolve_model from bulkgen.state import ( BuildState, is_target_dirty, @@ -106,10 +105,10 @@ def _create_providers() -> dict[TargetType, Provider]: providers: dict[TargetType, Provider] = {} bfl_key = os.environ.get("BFL_API_KEY", "") if bfl_key: - providers[TargetType.IMAGE] = ImageProvider(api_key=bfl_key) + providers[TargetType.IMAGE] = BlackForestProvider(api_key=bfl_key) mistral_key = os.environ.get("MISTRAL_API_KEY", "") if mistral_key: - providers[TargetType.TEXT] = TextProvider(api_key=mistral_key) + providers[TargetType.TEXT] = MistralProvider(api_key=mistral_key) return providers diff --git a/bulkgen/cli.py b/bulkgen/cli.py index a41933a..5e58d20 100644 --- a/bulkgen/cli.py +++ b/bulkgen/cli.py @@ -13,7 +13,7 @@ import typer from bulkgen.builder import BuildEvent, BuildResult, run_build from bulkgen.config import ProjectConfig, load_config from bulkgen.graph import build_graph, get_build_order -from bulkgen.providers.models import ALL_MODELS +from bulkgen.providers.registry import get_all_models app = typer.Typer(name="bulkgen", help="AI artifact build tool.") @@ -182,9 +182,10 @@ def graph() -> None: @app.command() def models() -> None: """List available models and their capabilities.""" - name_width = max(len(m.name) for m in ALL_MODELS) - provider_width = 
max(len(m.provider) for m in ALL_MODELS) - type_width = max(len(m.type) for m in ALL_MODELS) + all_models = get_all_models() + name_width = max(len(m.name) for m in all_models) + provider_width = max(len(m.provider) for m in all_models) + type_width = max(len(m.type) for m in all_models) header_name = "Model".ljust(name_width) header_provider = "Provider".ljust(provider_width) @@ -210,7 +211,7 @@ def models() -> None: + "─" * len(header_caps) ) - for model in ALL_MODELS: + for model in all_models: name_col = model.name.ljust(name_width) provider_col = model.provider.ljust(provider_width) type_col = model.type.ljust(type_width) diff --git a/bulkgen/config.py b/bulkgen/config.py index 8d0f9d9..20b929d 100644 --- a/bulkgen/config.py +++ b/bulkgen/config.py @@ -4,13 +4,15 @@ from __future__ import annotations import enum from pathlib import Path -from typing import TYPE_CHECKING, Self +from typing import Self import yaml from pydantic import BaseModel, model_validator -if TYPE_CHECKING: - from bulkgen.providers.models import Capability, ModelInfo +from bulkgen.providers.models import Capability + +IMAGE_EXTENSIONS: frozenset[str] = frozenset({".png", ".jpg", ".jpeg", ".webp"}) +TEXT_EXTENSIONS: frozenset[str] = frozenset({".md", ".txt"}) class TargetType(enum.Enum): @@ -20,10 +22,6 @@ class TargetType(enum.Enum): TEXT = "text" -IMAGE_EXTENSIONS: frozenset[str] = frozenset({".png", ".jpg", ".jpeg", ".webp"}) -TEXT_EXTENSIONS: frozenset[str] = frozenset({".md", ".txt"}) - - class Defaults(BaseModel): """Default model names, applied when a target does not specify its own.""" @@ -57,36 +55,6 @@ class ProjectConfig(BaseModel): return self -def infer_required_capabilities( - target_name: str, target: TargetConfig -) -> frozenset[Capability]: - """Infer the capabilities a model must have based on filename and config. - - Raises :class:`ValueError` for unsupported file extensions. 
- """ - from bulkgen.providers.models import Capability - - suffix = Path(target_name).suffix.lower() - caps: set[Capability] = set() - - if suffix in IMAGE_EXTENSIONS: - caps.add(Capability.TEXT_TO_IMAGE) - if target.reference_images: - caps.add(Capability.REFERENCE_IMAGES) - if target.control_images: - caps.add(Capability.CONTROL_IMAGES) - elif suffix in TEXT_EXTENSIONS: - caps.add(Capability.TEXT_GENERATION) - all_input_names = list(target.inputs) + list(target.reference_images) - if any(Path(n).suffix.lower() in IMAGE_EXTENSIONS for n in all_input_names): - caps.add(Capability.VISION) - else: - msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'" - raise ValueError(msg) - - return frozenset(caps) - - def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType: """Derive the target type from a set of required capabilities.""" from bulkgen.providers.models import Capability @@ -96,59 +64,6 @@ def target_type_from_capabilities(capabilities: frozenset[Capability]) -> Target return TargetType.TEXT -def resolve_model( - target_name: str, target: TargetConfig, defaults: Defaults -) -> ModelInfo: - """Return the effective model for a target, validated against required capabilities. - - If the target specifies an explicit model, it is validated to have all - required capabilities. Otherwise the type-appropriate default is tried - first; if it lacks a required capability the first capable model of the - same type is selected. - - Raises :class:`ValueError` if no suitable model can be found. - """ - from bulkgen.providers.models import ALL_MODELS - - required = infer_required_capabilities(target_name, target) - target_type = target_type_from_capabilities(required) - - if target.model is not None: - # Explicit model — look up and validate. 
- for model in ALL_MODELS: - if model.name == target.model: - missing = required - frozenset(model.capabilities) - if missing: - names = ", ".join(sorted(missing)) - msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}" - raise ValueError(msg) - return model - msg = f"Unknown model '{target.model}' for target '{target_name}'" - raise ValueError(msg) - - # No explicit model — try the default first, then fall back. - default_name = ( - defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model - ) - for model in ALL_MODELS: - if model.name == default_name: - if required <= frozenset(model.capabilities): - return model - break - - # Default lacks capabilities — find the first capable model of the same type. - model_type = "image" if target_type is TargetType.IMAGE else "text" - for model in ALL_MODELS: - if model.type == model_type and required <= frozenset(model.capabilities): - return model - - names = ", ".join(sorted(required)) - msg = ( - f"No model found for target '{target_name}' with required capabilities: {names}" - ) - raise ValueError(msg) - - def load_config(config_path: Path) -> ProjectConfig: """Load and validate a ``.bulkgen.yaml`` file.""" with config_path.open() as f: diff --git a/bulkgen/providers/__init__.py b/bulkgen/providers/__init__.py index 0495b11..c4ca3b7 100644 --- a/bulkgen/providers/__init__.py +++ b/bulkgen/providers/__init__.py @@ -12,6 +12,11 @@ from bulkgen.providers.models import ModelInfo class Provider(abc.ABC): """Abstract base for generation providers.""" + @staticmethod + @abc.abstractmethod + def get_provided_models() -> list[ModelInfo]: + """Return the models this provider supports.""" + @abc.abstractmethod async def generate( self, diff --git a/bulkgen/providers/image.py b/bulkgen/providers/blackforest.py similarity index 53% rename from bulkgen/providers/image.py rename to bulkgen/providers/blackforest.py index de92e85..e94e250 100644 --- 
a/bulkgen/providers/image.py +++ b/bulkgen/providers/blackforest.py @@ -11,7 +11,7 @@ import httpx from bulkgen.config import TargetConfig from bulkgen.providers import Provider from bulkgen.providers.bfl import BFLClient -from bulkgen.providers.models import ModelInfo +from bulkgen.providers.models import Capability, ModelInfo def _encode_image_b64(path: Path) -> str: @@ -45,7 +45,7 @@ def _add_reference_images( inputs[key] = _encode_image_b64(project_dir / ref_name) -class ImageProvider(Provider): +class BlackForestProvider(Provider): """Generates images via the BlackForestLabs API.""" _client: BFLClient @@ -53,6 +53,72 @@ class ImageProvider(Provider): def __init__(self, api_key: str) -> None: self._client = BFLClient(api_key=api_key) + @staticmethod + @override + def get_provided_models() -> list[ModelInfo]: + return [ + ModelInfo( + name="flux-dev", + provider="BlackForestLabs", + type="image", + capabilities=[Capability.TEXT_TO_IMAGE], + ), + ModelInfo( + name="flux-pro", + provider="BlackForestLabs", + type="image", + capabilities=[Capability.TEXT_TO_IMAGE], + ), + ModelInfo( + name="flux-pro-1.1", + provider="BlackForestLabs", + type="image", + capabilities=[Capability.TEXT_TO_IMAGE], + ), + ModelInfo( + name="flux-pro-1.1-ultra", + provider="BlackForestLabs", + type="image", + capabilities=[Capability.TEXT_TO_IMAGE], + ), + ModelInfo( + name="flux-2-pro", + provider="BlackForestLabs", + type="image", + capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES], + ), + ModelInfo( + name="flux-kontext-pro", + provider="BlackForestLabs", + type="image", + capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES], + ), + ModelInfo( + name="flux-pro-1.0-canny", + provider="BlackForestLabs", + type="image", + capabilities=[Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES], + ), + ModelInfo( + name="flux-pro-1.0-depth", + provider="BlackForestLabs", + type="image", + capabilities=[Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES], + ), 
+ ModelInfo( + name="flux-pro-1.0-fill", + provider="BlackForestLabs", + type="image", + capabilities=[Capability.TEXT_TO_IMAGE], + ), + ModelInfo( + name="flux-pro-1.0-expand", + provider="BlackForestLabs", + type="image", + capabilities=[Capability.TEXT_TO_IMAGE], + ), + ] + @override async def generate( self, diff --git a/bulkgen/providers/text.py b/bulkgen/providers/mistral.py similarity index 75% rename from bulkgen/providers/text.py rename to bulkgen/providers/mistral.py index 564e2ea..c8be302 100644 --- a/bulkgen/providers/text.py +++ b/bulkgen/providers/mistral.py @@ -11,7 +11,7 @@ from mistralai import Mistral, models from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig from bulkgen.providers import Provider -from bulkgen.providers.models import ModelInfo +from bulkgen.providers.models import Capability, ModelInfo def _image_to_data_url(path: Path) -> str: @@ -21,7 +21,7 @@ def _image_to_data_url(path: Path) -> str: return f"data:{mime};base64,{b64}" -class TextProvider(Provider): +class MistralProvider(Provider): """Generates text via the Mistral API.""" _api_key: str @@ -29,6 +29,36 @@ class TextProvider(Provider): def __init__(self, api_key: str) -> None: self._api_key = api_key + @staticmethod + @override + def get_provided_models() -> list[ModelInfo]: + return [ + ModelInfo( + name="mistral-large-latest", + provider="Mistral", + type="text", + capabilities=[Capability.TEXT_GENERATION], + ), + ModelInfo( + name="mistral-small-latest", + provider="Mistral", + type="text", + capabilities=[Capability.TEXT_GENERATION], + ), + ModelInfo( + name="pixtral-large-latest", + provider="Mistral", + type="text", + capabilities=[Capability.TEXT_GENERATION, Capability.VISION], + ), + ModelInfo( + name="pixtral-12b-latest", + provider="Mistral", + type="text", + capabilities=[Capability.TEXT_GENERATION, Capability.VISION], + ), + ] + @override async def generate( self, @@ -98,6 +128,8 @@ def _build_multimodal_message( models.TextChunk(text=prompt), ] + from 
bulkgen.config import IMAGE_EXTENSIONS + for name in input_names: input_path = project_dir / name suffix = input_path.suffix.lower() diff --git a/bulkgen/providers/models.py b/bulkgen/providers/models.py index 0e65149..a014d0d 100644 --- a/bulkgen/providers/models.py +++ b/bulkgen/providers/models.py @@ -1,4 +1,4 @@ -"""Registry of supported models and their capabilities.""" +"""Model types and capability definitions for AI providers.""" from __future__ import annotations @@ -25,106 +25,3 @@ class ModelInfo: provider: str type: Literal["text", "image"] capabilities: list[Capability] - - -TEXT_MODELS: list[ModelInfo] = [ - ModelInfo( - name="mistral-large-latest", - provider="Mistral", - type="text", - capabilities=[Capability.TEXT_GENERATION], - ), - ModelInfo( - name="mistral-small-latest", - provider="Mistral", - type="text", - capabilities=[Capability.TEXT_GENERATION], - ), - ModelInfo( - name="pixtral-large-latest", - provider="Mistral", - type="text", - capabilities=[Capability.TEXT_GENERATION, Capability.VISION], - ), - ModelInfo( - name="pixtral-12b-latest", - provider="Mistral", - type="text", - capabilities=[Capability.TEXT_GENERATION, Capability.VISION], - ), -] - -IMAGE_MODELS: list[ModelInfo] = [ - ModelInfo( - name="flux-dev", - provider="BlackForestLabs", - type="image", - capabilities=[Capability.TEXT_TO_IMAGE], - ), - ModelInfo( - name="flux-pro", - provider="BlackForestLabs", - type="image", - capabilities=[Capability.TEXT_TO_IMAGE], - ), - ModelInfo( - name="flux-pro-1.1", - provider="BlackForestLabs", - type="image", - capabilities=[Capability.TEXT_TO_IMAGE], - ), - ModelInfo( - name="flux-pro-1.1-ultra", - provider="BlackForestLabs", - type="image", - capabilities=[Capability.TEXT_TO_IMAGE], - ), - ModelInfo( - name="flux-2-pro", - provider="BlackForestLabs", - type="image", - capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES], - ), - ModelInfo( - name="flux-kontext-pro", - provider="BlackForestLabs", - type="image", - 
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES], - ), - ModelInfo( - name="flux-pro-1.0-canny", - provider="BlackForestLabs", - type="image", - capabilities=[ - Capability.TEXT_TO_IMAGE, - Capability.CONTROL_IMAGES, - ], - ), - ModelInfo( - name="flux-pro-1.0-depth", - provider="BlackForestLabs", - type="image", - capabilities=[ - Capability.TEXT_TO_IMAGE, - Capability.CONTROL_IMAGES, - ], - ), - ModelInfo( - name="flux-pro-1.0-fill", - provider="BlackForestLabs", - type="image", - capabilities=[ - Capability.TEXT_TO_IMAGE, - ], - ), - ModelInfo( - name="flux-pro-1.0-expand", - provider="BlackForestLabs", - type="image", - capabilities=[ - Capability.TEXT_TO_IMAGE, - ], - ), -] - -ALL_MODELS: list[ModelInfo] = TEXT_MODELS + IMAGE_MODELS diff --git a/bulkgen/providers/registry.py b/bulkgen/providers/registry.py new file mode 100644 index 0000000..20273dd --- /dev/null +++ b/bulkgen/providers/registry.py @@ -0,0 +1,16 @@ +"""Aggregates models from all registered providers.""" + +from __future__ import annotations + +from bulkgen.providers.models import ModelInfo + + +def get_all_models() -> list[ModelInfo]: + """Return the merged list of models from all providers.""" + from bulkgen.providers.blackforest import BlackForestProvider + from bulkgen.providers.mistral import MistralProvider + + return ( + MistralProvider.get_provided_models() + + BlackForestProvider.get_provided_models() + ) diff --git a/bulkgen/resolve.py b/bulkgen/resolve.py new file mode 100644 index 0000000..1e2b48b --- /dev/null +++ b/bulkgen/resolve.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from pathlib import Path + +from bulkgen.config import ( + IMAGE_EXTENSIONS, + TEXT_EXTENSIONS, + Defaults, + TargetConfig, + TargetType, + target_type_from_capabilities, +) +from bulkgen.providers.models import Capability, ModelInfo + + +def infer_required_capabilities( + target_name: str, target: TargetConfig +) -> frozenset[Capability]: + """Infer the capabilities a model 
must have based on filename and config. + + Raises :class:`ValueError` for unsupported file extensions. + """ + suffix = Path(target_name).suffix.lower() + caps: set[Capability] = set() + + if suffix in IMAGE_EXTENSIONS: + caps.add(Capability.TEXT_TO_IMAGE) + if target.reference_images: + caps.add(Capability.REFERENCE_IMAGES) + if target.control_images: + caps.add(Capability.CONTROL_IMAGES) + elif suffix in TEXT_EXTENSIONS: + caps.add(Capability.TEXT_GENERATION) + all_input_names = list(target.inputs) + list(target.reference_images) + if any(Path(n).suffix.lower() in IMAGE_EXTENSIONS for n in all_input_names): + caps.add(Capability.VISION) + else: + msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'" + raise ValueError(msg) + + return frozenset(caps) + + +def resolve_model( + target_name: str, target: TargetConfig, defaults: Defaults +) -> ModelInfo: + """Return the effective model for a target, validated against required capabilities. + + If the target specifies an explicit model, it is validated to have all + required capabilities. Otherwise the type-appropriate default is tried + first; if it lacks a required capability the first capable model of the + same type is selected. + + Raises :class:`ValueError` if no suitable model can be found. + """ + from bulkgen.providers.registry import get_all_models + + all_models = get_all_models() + required = infer_required_capabilities(target_name, target) + target_type = target_type_from_capabilities(required) + + if target.model is not None: + # Explicit model — look up and validate. 
+ for model in all_models: + if model.name == target.model: + missing = required - frozenset(model.capabilities) + if missing: + names = ", ".join(sorted(missing)) + msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}" + raise ValueError(msg) + return model + msg = f"Unknown model '{target.model}' for target '{target_name}'" + raise ValueError(msg) + + # No explicit model — try the default first, then fall back. + default_name = ( + defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model + ) + for model in all_models: + if model.name == default_name: + if required <= frozenset(model.capabilities): + return model + break + + # Default lacks capabilities — find the first capable model of the same type. + model_type = "image" if target_type is TargetType.IMAGE else "text" + for model in all_models: + if model.type == model_type and required <= frozenset(model.capabilities): + return model + + names = ", ".join(sorted(required)) + msg = ( + f"No model found for target '{target_name}' with required capabilities: {names}" + ) + raise ValueError(msg) diff --git a/tests/test_builder.py b/tests/test_builder.py index ae96868..0968cf9 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -27,6 +27,11 @@ WriteConfig = Callable[[dict[str, object]], ProjectConfig] class FakeProvider(Provider): """A provider that writes a marker file instead of calling an API.""" + @staticmethod + @override + def get_provided_models() -> list[ModelInfo]: + return [] + @override async def generate( self, @@ -43,6 +48,11 @@ class FakeProvider(Provider): class FailingProvider(Provider): """A provider that always raises.""" + @staticmethod + @override + def get_provided_models() -> list[ModelInfo]: + return [] + @override async def generate( self, diff --git a/tests/test_config.py b/tests/test_config.py index b897860..c48d6df 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,14 +7,7 @@ from pathlib import 
Path import pytest import yaml -from bulkgen.config import ( - Defaults, - TargetConfig, - infer_required_capabilities, - load_config, - resolve_model, -) -from bulkgen.providers.models import Capability +from bulkgen.config import load_config class TestLoadConfig: @@ -86,136 +79,3 @@ class TestLoadConfig: with pytest.raises(Exception): _ = load_config(config_path) - - -class TestInferRequiredCapabilities: - """Test capability inference from file extensions and target config.""" - - def test_plain_image(self) -> None: - target = TargetConfig(prompt="x") - assert infer_required_capabilities("out.png", target) == frozenset( - {Capability.TEXT_TO_IMAGE} - ) - - @pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"]) - def test_image_extensions(self, name: str) -> None: - target = TargetConfig(prompt="x") - caps = infer_required_capabilities(name, target) - assert Capability.TEXT_TO_IMAGE in caps - - @pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"]) - def test_case_insensitive(self, name: str) -> None: - target = TargetConfig(prompt="x") - caps = infer_required_capabilities(name, target) - assert Capability.TEXT_TO_IMAGE in caps - - def test_image_with_reference_images(self) -> None: - target = TargetConfig(prompt="x", reference_images=["ref.png"]) - assert infer_required_capabilities("out.png", target) == frozenset( - {Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES} - ) - - def test_image_with_control_images(self) -> None: - target = TargetConfig(prompt="x", control_images=["ctrl.png"]) - assert infer_required_capabilities("out.png", target) == frozenset( - {Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES} - ) - - def test_image_with_both(self) -> None: - target = TargetConfig( - prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"] - ) - assert infer_required_capabilities("out.png", target) == frozenset( - { - Capability.TEXT_TO_IMAGE, - Capability.REFERENCE_IMAGES, - Capability.CONTROL_IMAGES, - } - ) - - 
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"]) - def test_text_extensions(self, name: str) -> None: - target = TargetConfig(prompt="x") - caps = infer_required_capabilities(name, target) - assert caps == frozenset({Capability.TEXT_GENERATION}) - - def test_text_with_text_inputs(self) -> None: - target = TargetConfig(prompt="x", inputs=["data.txt"]) - assert infer_required_capabilities("out.md", target) == frozenset( - {Capability.TEXT_GENERATION} - ) - - def test_text_with_image_input(self) -> None: - target = TargetConfig(prompt="x", inputs=["photo.png"]) - assert infer_required_capabilities("out.txt", target) == frozenset( - {Capability.TEXT_GENERATION, Capability.VISION} - ) - - def test_text_with_image_reference(self) -> None: - target = TargetConfig(prompt="x", reference_images=["ref.jpg"]) - assert infer_required_capabilities("out.md", target) == frozenset( - {Capability.TEXT_GENERATION, Capability.VISION} - ) - - def test_unsupported_extension_raises(self) -> None: - target = TargetConfig(prompt="x") - with pytest.raises(ValueError, match="unsupported extension"): - _ = infer_required_capabilities("data.csv", target) - - def test_no_extension_raises(self) -> None: - target = TargetConfig(prompt="x") - with pytest.raises(ValueError, match="unsupported extension"): - _ = infer_required_capabilities("Makefile", target) - - -class TestResolveModel: - """Test model resolution with capability validation.""" - - def test_explicit_model_wins(self) -> None: - target = TargetConfig(prompt="x", model="mistral-small-latest") - result = resolve_model("out.txt", target, Defaults()) - assert result.name == "mistral-small-latest" - - def test_default_text_model(self) -> None: - target = TargetConfig(prompt="x") - defaults = Defaults(text_model="mistral-large-latest") - result = resolve_model("out.md", target, defaults) - assert result.name == "mistral-large-latest" - - def test_default_image_model(self) -> None: - target = TargetConfig(prompt="x") - defaults = 
Defaults(image_model="flux-dev") - result = resolve_model("out.png", target, defaults) - assert result.name == "flux-dev" - - def test_unknown_model_raises(self) -> None: - target = TargetConfig(prompt="x", model="nonexistent-model") - with pytest.raises(ValueError, match="Unknown model"): - _ = resolve_model("out.txt", target, Defaults()) - - def test_explicit_model_missing_capability_raises(self) -> None: - # flux-dev does not support reference images - target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"]) - with pytest.raises(ValueError, match="lacks required capabilities"): - _ = resolve_model("out.png", target, Defaults()) - - def test_default_fallback_for_reference_images(self) -> None: - # Default flux-dev lacks reference_images, should fall back to a capable model - target = TargetConfig(prompt="x", reference_images=["r.png"]) - defaults = Defaults(image_model="flux-dev") - result = resolve_model("out.png", target, defaults) - assert Capability.REFERENCE_IMAGES in result.capabilities - - def test_default_fallback_for_vision(self) -> None: - # Default mistral-large-latest lacks vision, should fall back to a pixtral model - target = TargetConfig(prompt="x", inputs=["photo.png"]) - defaults = Defaults(text_model="mistral-large-latest") - result = resolve_model("out.txt", target, defaults) - assert Capability.VISION in result.capabilities - - def test_default_preferred_when_capable(self) -> None: - # Default flux-2-pro already supports reference_images, should be used directly - target = TargetConfig(prompt="x", reference_images=["r.png"]) - defaults = Defaults(image_model="flux-2-pro") - result = resolve_model("out.png", target, defaults) - assert result.name == "flux-2-pro" diff --git a/tests/test_providers.py b/tests/test_providers.py index 4f00e72..65849d7 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -14,17 +14,18 @@ import pytest from bulkgen.config import TargetConfig from bulkgen.providers.bfl import 
BFLResult -from bulkgen.providers.image import ImageProvider -from bulkgen.providers.image import ( +from bulkgen.providers.blackforest import BlackForestProvider +from bulkgen.providers.blackforest import ( _encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage] ) -from bulkgen.providers.models import ALL_MODELS, ModelInfo -from bulkgen.providers.text import TextProvider +from bulkgen.providers.mistral import MistralProvider +from bulkgen.providers.models import ModelInfo +from bulkgen.providers.registry import get_all_models def _model(name: str) -> ModelInfo: """Look up a ModelInfo by name.""" - for m in ALL_MODELS: + for m in get_all_models(): if m.name == name: return m msg = f"Unknown test model: {name}" @@ -69,8 +70,8 @@ def _make_text_response(content: str | None) -> MagicMock: return response -class TestImageProvider: - """Test ImageProvider with mocked BFL client and HTTP.""" +class TestBlackForestProvider: + """Test BlackForestProvider with mocked BFL client and HTTP.""" @pytest.fixture def image_bytes(self) -> bytes: @@ -83,13 +84,13 @@ class TestImageProvider: bfl_result, mock_http = _make_bfl_mocks(image_bytes) with ( - patch("bulkgen.providers.image.BFLClient") as mock_cls, - patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, + patch("bulkgen.providers.blackforest.BFLClient") as mock_cls, + patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls, ): mock_cls.return_value.generate = AsyncMock(return_value=bfl_result) mock_http_cls.return_value = mock_http - provider = ImageProvider(api_key="test-key") + provider = BlackForestProvider(api_key="test-key") await provider.generate( target_name="out.png", target_config=target_config, @@ -109,14 +110,14 @@ class TestImageProvider: bfl_result, mock_http = _make_bfl_mocks(image_bytes) with ( - patch("bulkgen.providers.image.BFLClient") as mock_cls, - patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, + 
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls, + patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls, ): mock_generate = AsyncMock(return_value=bfl_result) mock_cls.return_value.generate = mock_generate mock_http_cls.return_value = mock_http - provider = ImageProvider(api_key="test-key") + provider = BlackForestProvider(api_key="test-key") await provider.generate( target_name="banner.png", target_config=target_config, @@ -140,14 +141,14 @@ class TestImageProvider: bfl_result, mock_http = _make_bfl_mocks(image_bytes) with ( - patch("bulkgen.providers.image.BFLClient") as mock_cls, - patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, + patch("bulkgen.providers.blackforest.BFLClient") as mock_cls, + patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls, ): mock_generate = AsyncMock(return_value=bfl_result) mock_cls.return_value.generate = mock_generate mock_http_cls.return_value = mock_http - provider = ImageProvider(api_key="test-key") + provider = BlackForestProvider(api_key="test-key") await provider.generate( target_name="out.png", target_config=target_config, @@ -175,14 +176,14 @@ class TestImageProvider: bfl_result, mock_http = _make_bfl_mocks(image_bytes) with ( - patch("bulkgen.providers.image.BFLClient") as mock_cls, - patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, + patch("bulkgen.providers.blackforest.BFLClient") as mock_cls, + patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls, ): mock_generate = AsyncMock(return_value=bfl_result) mock_cls.return_value.generate = mock_generate mock_http_cls.return_value = mock_http - provider = ImageProvider(api_key="test-key") + provider = BlackForestProvider(api_key="test-key") await provider.generate( target_name="out.png", target_config=target_config, @@ -206,14 +207,14 @@ class TestImageProvider: bfl_result, mock_http = _make_bfl_mocks(image_bytes) with ( - patch("bulkgen.providers.image.BFLClient") 
as mock_cls, - patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, + patch("bulkgen.providers.blackforest.BFLClient") as mock_cls, + patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls, ): mock_generate = AsyncMock(return_value=bfl_result) mock_cls.return_value.generate = mock_generate mock_http_cls.return_value = mock_http - provider = ImageProvider(api_key="test-key") + provider = BlackForestProvider(api_key="test-key") await provider.generate( target_name="out.png", target_config=target_config, @@ -229,14 +230,14 @@ class TestImageProvider: async def test_image_no_sample_url_raises(self, project_dir: Path) -> None: target_config = TargetConfig(prompt="x") - with patch("bulkgen.providers.image.BFLClient") as mock_cls: + with patch("bulkgen.providers.blackforest.BFLClient") as mock_cls: from bulkgen.providers.bfl import BFLError mock_cls.return_value.generate = AsyncMock( side_effect=BFLError("BFL task test ready but no sample URL: {}") ) - provider = ImageProvider(api_key="test-key") + provider = BlackForestProvider(api_key="test-key") with pytest.raises(BFLError, match="no sample URL"): await provider.generate( target_name="fail.png", @@ -255,17 +256,17 @@ class TestImageProvider: assert base64.b64decode(encoded) == data -class TestTextProvider: - """Test TextProvider with mocked Mistral client.""" +class TestMistralProvider: + """Test MistralProvider with mocked Mistral client.""" async def test_basic_text_generation(self, project_dir: Path) -> None: target_config = TargetConfig(prompt="Write a poem") response = _make_text_response("Roses are red...") - with patch("bulkgen.providers.text.Mistral") as mock_cls: + with patch("bulkgen.providers.mistral.Mistral") as mock_cls: mock_cls.return_value = _make_mistral_mock(response) - provider = TextProvider(api_key="test-key") + provider = MistralProvider(api_key="test-key") await provider.generate( target_name="poem.txt", target_config=target_config, @@ -283,11 +284,11 @@ class 
TestTextProvider: target_config = TargetConfig(prompt="Summarize", inputs=["source.txt"]) response = _make_text_response("Summary: ...") - with patch("bulkgen.providers.text.Mistral") as mock_cls: + with patch("bulkgen.providers.mistral.Mistral") as mock_cls: mock_client = _make_mistral_mock(response) mock_cls.return_value = mock_client - provider = TextProvider(api_key="test-key") + provider = MistralProvider(api_key="test-key") await provider.generate( target_name="summary.md", target_config=target_config, @@ -307,11 +308,11 @@ class TestTextProvider: target_config = TargetConfig(prompt="Describe this image", inputs=["photo.png"]) response = _make_text_response("A beautiful photo") - with patch("bulkgen.providers.text.Mistral") as mock_cls: + with patch("bulkgen.providers.mistral.Mistral") as mock_cls: mock_client = _make_mistral_mock(response) mock_cls.return_value = mock_client - provider = TextProvider(api_key="test-key") + provider = MistralProvider(api_key="test-key") await provider.generate( target_name="desc.txt", target_config=target_config, @@ -332,10 +333,10 @@ class TestTextProvider: response = MagicMock() response.choices = [] - with patch("bulkgen.providers.text.Mistral") as mock_cls: + with patch("bulkgen.providers.mistral.Mistral") as mock_cls: mock_cls.return_value = _make_mistral_mock(response) - provider = TextProvider(api_key="test-key") + provider = MistralProvider(api_key="test-key") with pytest.raises(RuntimeError, match="no choices"): await provider.generate( target_name="fail.txt", @@ -349,10 +350,10 @@ class TestTextProvider: target_config = TargetConfig(prompt="x") response = _make_text_response(None) - with patch("bulkgen.providers.text.Mistral") as mock_cls: + with patch("bulkgen.providers.mistral.Mistral") as mock_cls: mock_cls.return_value = _make_mistral_mock(response) - provider = TextProvider(api_key="test-key") + provider = MistralProvider(api_key="test-key") with pytest.raises(RuntimeError, match="empty content"): await 
provider.generate( target_name="fail.txt", @@ -372,11 +373,11 @@ class TestTextProvider: ) response = _make_text_response("Combined") - with patch("bulkgen.providers.text.Mistral") as mock_cls: + with patch("bulkgen.providers.mistral.Mistral") as mock_cls: mock_client = _make_mistral_mock(response) mock_cls.return_value = mock_client - provider = TextProvider(api_key="test-key") + provider = MistralProvider(api_key="test-key") await provider.generate( target_name="out.md", target_config=target_config, @@ -403,11 +404,11 @@ class TestTextProvider: ) response = _make_text_response("A stylized image") - with patch("bulkgen.providers.text.Mistral") as mock_cls: + with patch("bulkgen.providers.mistral.Mistral") as mock_cls: mock_client = _make_mistral_mock(response) mock_cls.return_value = mock_client - provider = TextProvider(api_key="test-key") + provider = MistralProvider(api_key="test-key") await provider.generate( target_name="desc.txt", target_config=target_config, diff --git a/tests/test_resolve.py b/tests/test_resolve.py new file mode 100644 index 0000000..3c33e6e --- /dev/null +++ b/tests/test_resolve.py @@ -0,0 +1,145 @@ +"""Integration tests for bulkgen.resolve.""" + +from __future__ import annotations + +import pytest + +from bulkgen.config import ( + Defaults, + TargetConfig, +) +from bulkgen.providers.models import Capability +from bulkgen.resolve import infer_required_capabilities, resolve_model + + +class TestInferRequiredCapabilities: + """Test capability inference from file extensions and target config.""" + + def test_plain_image(self) -> None: + target = TargetConfig(prompt="x") + assert infer_required_capabilities("out.png", target) == frozenset( + {Capability.TEXT_TO_IMAGE} + ) + + @pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"]) + def test_image_extensions(self, name: str) -> None: + target = TargetConfig(prompt="x") + caps = infer_required_capabilities(name, target) + assert Capability.TEXT_TO_IMAGE in caps + + 
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"]) + def test_case_insensitive(self, name: str) -> None: + target = TargetConfig(prompt="x") + caps = infer_required_capabilities(name, target) + assert Capability.TEXT_TO_IMAGE in caps + + def test_image_with_reference_images(self) -> None: + target = TargetConfig(prompt="x", reference_images=["ref.png"]) + assert infer_required_capabilities("out.png", target) == frozenset( + {Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES} + ) + + def test_image_with_control_images(self) -> None: + target = TargetConfig(prompt="x", control_images=["ctrl.png"]) + assert infer_required_capabilities("out.png", target) == frozenset( + {Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES} + ) + + def test_image_with_both(self) -> None: + target = TargetConfig( + prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"] + ) + assert infer_required_capabilities("out.png", target) == frozenset( + { + Capability.TEXT_TO_IMAGE, + Capability.REFERENCE_IMAGES, + Capability.CONTROL_IMAGES, + } + ) + + @pytest.mark.parametrize("name", ["doc.md", "doc.txt"]) + def test_text_extensions(self, name: str) -> None: + target = TargetConfig(prompt="x") + caps = infer_required_capabilities(name, target) + assert caps == frozenset({Capability.TEXT_GENERATION}) + + def test_text_with_text_inputs(self) -> None: + target = TargetConfig(prompt="x", inputs=["data.txt"]) + assert infer_required_capabilities("out.md", target) == frozenset( + {Capability.TEXT_GENERATION} + ) + + def test_text_with_image_input(self) -> None: + target = TargetConfig(prompt="x", inputs=["photo.png"]) + assert infer_required_capabilities("out.txt", target) == frozenset( + {Capability.TEXT_GENERATION, Capability.VISION} + ) + + def test_text_with_image_reference(self) -> None: + target = TargetConfig(prompt="x", reference_images=["ref.jpg"]) + assert infer_required_capabilities("out.md", target) == frozenset( + {Capability.TEXT_GENERATION, Capability.VISION} + 
) + + def test_unsupported_extension_raises(self) -> None: + target = TargetConfig(prompt="x") + with pytest.raises(ValueError, match="unsupported extension"): + _ = infer_required_capabilities("data.csv", target) + + def test_no_extension_raises(self) -> None: + target = TargetConfig(prompt="x") + with pytest.raises(ValueError, match="unsupported extension"): + _ = infer_required_capabilities("Makefile", target) + + +class TestResolveModel: + """Test model resolution with capability validation.""" + + def test_explicit_model_wins(self) -> None: + target = TargetConfig(prompt="x", model="mistral-small-latest") + result = resolve_model("out.txt", target, Defaults()) + assert result.name == "mistral-small-latest" + + def test_default_text_model(self) -> None: + target = TargetConfig(prompt="x") + defaults = Defaults(text_model="mistral-large-latest") + result = resolve_model("out.md", target, defaults) + assert result.name == "mistral-large-latest" + + def test_default_image_model(self) -> None: + target = TargetConfig(prompt="x") + defaults = Defaults(image_model="flux-dev") + result = resolve_model("out.png", target, defaults) + assert result.name == "flux-dev" + + def test_unknown_model_raises(self) -> None: + target = TargetConfig(prompt="x", model="nonexistent-model") + with pytest.raises(ValueError, match="Unknown model"): + _ = resolve_model("out.txt", target, Defaults()) + + def test_explicit_model_missing_capability_raises(self) -> None: + # flux-dev does not support reference images + target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"]) + with pytest.raises(ValueError, match="lacks required capabilities"): + _ = resolve_model("out.png", target, Defaults()) + + def test_default_fallback_for_reference_images(self) -> None: + # Default flux-dev lacks reference_images, should fall back to a capable model + target = TargetConfig(prompt="x", reference_images=["r.png"]) + defaults = Defaults(image_model="flux-dev") + result = 
resolve_model("out.png", target, defaults) + assert Capability.REFERENCE_IMAGES in result.capabilities + + def test_default_fallback_for_vision(self) -> None: + # Default mistral-large-latest lacks vision, should fall back to a pixtral model + target = TargetConfig(prompt="x", inputs=["photo.png"]) + defaults = Defaults(text_model="mistral-large-latest") + result = resolve_model("out.txt", target, defaults) + assert Capability.VISION in result.capabilities + + def test_default_preferred_when_capable(self) -> None: + # Default flux-2-pro already supports reference_images, should be used directly + target = TargetConfig(prompt="x", reference_images=["r.png"]) + defaults = Defaults(image_model="flux-2-pro") + result = resolve_model("out.png", target, defaults) + assert result.name == "flux-2-pro"