diff --git a/bulkgen/builder.py b/bulkgen/builder.py
index 9320bfa..9439353 100644
--- a/bulkgen/builder.py
+++ b/bulkgen/builder.py
@@ -12,8 +12,9 @@ from pathlib import Path
 from bulkgen.config import (
     ProjectConfig,
     TargetType,
-    infer_target_type,
+    infer_required_capabilities,
     resolve_model,
+    target_type_from_capabilities,
 )
 from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target
 from bulkgen.providers import Provider
@@ -120,8 +121,9 @@ async def _build_single_target(
 ) -> None:
     """Build a single target by dispatching to the appropriate provider."""
     target_cfg = config.targets[target_name]
-    target_type = infer_target_type(target_name)
     model_info = resolve_model(target_name, target_cfg, config.defaults)
+    required = infer_required_capabilities(target_name, target_cfg)
+    target_type = target_type_from_capabilities(required)
     resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
 
     provider = providers[target_type]
@@ -180,7 +182,7 @@ async def run_build(
                 continue
 
             if _is_dirty(name, config, project_dir, state):
-                if not _has_provider(name, providers, result, on_progress):
+                if not _has_provider(name, config, providers, result, on_progress):
                     continue
                 dirty_targets.append(name)
             else:
@@ -236,12 +238,15 @@ def _is_dirty(
 
 def _has_provider(
     target_name: str,
+    config: ProjectConfig,
     providers: dict[TargetType, Provider],
     result: BuildResult,
     on_progress: ProgressCallback = _noop_callback,
 ) -> bool:
     """Check that the required provider is available; record failure if not."""
-    target_type = infer_target_type(target_name)
+    target_cfg = config.targets[target_name]
+    required = infer_required_capabilities(target_name, target_cfg)
+    target_type = target_type_from_capabilities(required)
     if target_type not in providers:
         env_var = (
             "BFL_API_KEY" if target_type is TargetType.IMAGE else "MISTRAL_API_KEY"
diff --git a/bulkgen/config.py b/bulkgen/config.py
index 48d9a83..8d0f9d9 100644
--- a/bulkgen/config.py
+++ b/bulkgen/config.py
@@ -10,7 +10,7 @@ import yaml
 from pydantic import BaseModel, model_validator
 
 if TYPE_CHECKING:
-    from bulkgen.providers.models import ModelInfo
+    from bulkgen.providers.models import Capability, ModelInfo
 
 
 class TargetType(enum.Enum):
@@ -57,39 +57,101 @@ class ProjectConfig(BaseModel):
         return self
 
 
-def infer_target_type(target_name: str) -> TargetType:
-    """Infer whether a target produces an image or text from its file extension."""
+def infer_required_capabilities(
+    target_name: str, target: TargetConfig
+) -> frozenset[Capability]:
+    """Infer the capabilities a model must have based on filename and config.
+
+    Raises :class:`ValueError` for unsupported file extensions.
+    """
+    from bulkgen.providers.models import Capability
+
     suffix = Path(target_name).suffix.lower()
+    caps: set[Capability] = set()
+
     if suffix in IMAGE_EXTENSIONS:
+        caps.add(Capability.TEXT_TO_IMAGE)
+        if target.reference_images:
+            caps.add(Capability.REFERENCE_IMAGES)
+        if target.control_images:
+            caps.add(Capability.CONTROL_IMAGES)
+    elif suffix in TEXT_EXTENSIONS:
+        caps.add(Capability.TEXT_GENERATION)
+        all_input_names = list(target.inputs) + list(target.reference_images)
+        if any(Path(n).suffix.lower() in IMAGE_EXTENSIONS for n in all_input_names):
+            caps.add(Capability.VISION)
+    else:
+        msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
+        raise ValueError(msg)
+
+    return frozenset(caps)
+
+
+def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType:
+    """Derive the target type from a set of required capabilities."""
+    from bulkgen.providers.models import Capability
+
+    if Capability.TEXT_TO_IMAGE in capabilities:
         return TargetType.IMAGE
-    if suffix in TEXT_EXTENSIONS:
-        return TargetType.TEXT
-    msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
-    raise ValueError(msg)
+    return TargetType.TEXT
 
 
 def resolve_model(
     target_name: str, target: TargetConfig, defaults: Defaults
 ) -> ModelInfo:
-    """Return the effective model for a target (explicit or default by type).
+    """Return the effective model for a target, validated against required capabilities.
 
-    Raises :class:`ValueError` if the resolved model name is not in the registry.
+    If the target specifies an explicit model, it is validated to have all
+    required capabilities. Otherwise the type-appropriate default is tried
+    first; if it lacks a required capability the first capable model of the
+    same type is selected.
+
+    Raises :class:`ValueError` if the explicit or default model name is not in
+    the registry, or if no suitable model can be found.
     """
     from bulkgen.providers.models import ALL_MODELS
 
+    required = infer_required_capabilities(target_name, target)
+    target_type = target_type_from_capabilities(required)
+
     if target.model is not None:
-        model_name = target.model
-    else:
-        target_type = infer_target_type(target_name)
-        model_name = (
-            defaults.image_model
-            if target_type is TargetType.IMAGE
-            else defaults.text_model
-        )
+        # Explicit model — look up and validate.
+        for model in ALL_MODELS:
+            if model.name == target.model:
+                missing = required - frozenset(model.capabilities)
+                if missing:
+                    names = ", ".join(sorted(missing))
+                    msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}"
+                    raise ValueError(msg)
+                return model
+        msg = f"Unknown model '{target.model}' for target '{target_name}'"
+        raise ValueError(msg)
+
+    # No explicit model — try the default first, then fall back.
+    default_name = (
+        defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model
+    )
     for model in ALL_MODELS:
-        if model.name == model_name:
+        if model.name == default_name:
+            if required <= frozenset(model.capabilities):
+                return model
+            break
+    else:
+        # Loop finished without ``break``: the default name is not in the registry.
+        # Fail loudly (as before this change) instead of silently substituting a model.
+        msg = f"Unknown model '{default_name}' for target '{target_name}'"
+        raise ValueError(msg)
+
+    # Default lacks capabilities — find the first capable model of the same type.
+    model_type = "image" if target_type is TargetType.IMAGE else "text"
+    for model in ALL_MODELS:
+        if model.type == model_type and required <= frozenset(model.capabilities):
             return model
-    msg = f"Unknown model '{model_name}' for target '{target_name}'"
+
+    names = ", ".join(sorted(required))
+    msg = (
+        f"No model found for target '{target_name}' with required capabilities: {names}"
+    )
     raise ValueError(msg)
 
 
diff --git a/tests/test_config.py b/tests/test_config.py
index 17fdf7d..b897860 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -10,11 +10,11 @@ import yaml
 from bulkgen.config import (
     Defaults,
     TargetConfig,
-    TargetType,
-    infer_target_type,
+    infer_required_capabilities,
     load_config,
     resolve_model,
 )
+from bulkgen.providers.models import Capability
 
 
 class TestLoadConfig:
@@ -88,34 +88,88 @@ class TestLoadConfig:
             _ = load_config(config_path)
 
 
-class TestInferTargetType:
-    """Test target type inference from file extensions."""
+class TestInferRequiredCapabilities:
+    """Test capability inference from file extensions and target config."""
 
-    @pytest.mark.parametrize(
-        "name", ["photo.png", "photo.jpg", "photo.jpeg", "photo.webp"]
-    )
+    def test_plain_image(self) -> None:
+        target = TargetConfig(prompt="x")
+        assert infer_required_capabilities("out.png", target) == frozenset(
+            {Capability.TEXT_TO_IMAGE}
+        )
+
+    @pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
     def test_image_extensions(self, name: str) -> None:
-        assert infer_target_type(name) is TargetType.IMAGE
+        target = TargetConfig(prompt="x")
+        caps = infer_required_capabilities(name, target)
+        assert Capability.TEXT_TO_IMAGE in caps
 
-    @pytest.mark.parametrize("name", ["PHOTO.PNG", "PHOTO.JPG"])
+    @pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
     def test_case_insensitive(self, name: str) -> None:
-        assert infer_target_type(name) is TargetType.IMAGE
+        target = TargetConfig(prompt="x")
+        caps = infer_required_capabilities(name, target)
+        assert Capability.TEXT_TO_IMAGE in caps
+
+    def test_image_with_reference_images(self) -> None:
+        target = TargetConfig(prompt="x", reference_images=["ref.png"])
+        assert infer_required_capabilities("out.png", target) == frozenset(
+            {Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
+        )
+
+    def test_image_with_control_images(self) -> None:
+        target = TargetConfig(prompt="x", control_images=["ctrl.png"])
+        assert infer_required_capabilities("out.png", target) == frozenset(
+            {Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
+        )
+
+    def test_image_with_both(self) -> None:
+        target = TargetConfig(
+            prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
+        )
+        assert infer_required_capabilities("out.png", target) == frozenset(
+            {
+                Capability.TEXT_TO_IMAGE,
+                Capability.REFERENCE_IMAGES,
+                Capability.CONTROL_IMAGES,
+            }
+        )
 
     @pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
     def test_text_extensions(self, name: str) -> None:
-        assert infer_target_type(name) is TargetType.TEXT
+        target = TargetConfig(prompt="x")
+        caps = infer_required_capabilities(name, target)
+        assert caps == frozenset({Capability.TEXT_GENERATION})
+
+    def test_text_with_text_inputs(self) -> None:
+        target = TargetConfig(prompt="x", inputs=["data.txt"])
+        assert infer_required_capabilities("out.md", target) == frozenset(
+            {Capability.TEXT_GENERATION}
+        )
+
+    def test_text_with_image_input(self) -> None:
+        target = TargetConfig(prompt="x", inputs=["photo.png"])
+        assert infer_required_capabilities("out.txt", target) == frozenset(
+            {Capability.TEXT_GENERATION, Capability.VISION}
+        )
+
+    def test_text_with_image_reference(self) -> None:
+        target = TargetConfig(prompt="x", reference_images=["ref.jpg"])
+        assert infer_required_capabilities("out.md", target) == frozenset(
+            {Capability.TEXT_GENERATION, Capability.VISION}
+        )
 
     def test_unsupported_extension_raises(self) -> None:
+        target = TargetConfig(prompt="x")
         with pytest.raises(ValueError, match="unsupported extension"):
-            _ = infer_target_type("data.csv")
+            _ = infer_required_capabilities("data.csv", target)
 
     def test_no_extension_raises(self) -> None:
+        target = TargetConfig(prompt="x")
         with pytest.raises(ValueError, match="unsupported extension"):
-            _ = infer_target_type("Makefile")
+            _ = infer_required_capabilities("Makefile", target)
 
 
 class TestResolveModel:
-    """Test model resolution (explicit vs. default)."""
+    """Test model resolution with capability validation."""
 
     def test_explicit_model_wins(self) -> None:
         target = TargetConfig(prompt="x", model="mistral-small-latest")
@@ -138,3 +192,37 @@ class TestResolveModel:
         target = TargetConfig(prompt="x", model="nonexistent-model")
         with pytest.raises(ValueError, match="Unknown model"):
             _ = resolve_model("out.txt", target, Defaults())
+
+    def test_explicit_model_missing_capability_raises(self) -> None:
+        # flux-dev does not support reference images
+        target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"])
+        with pytest.raises(ValueError, match="lacks required capabilities"):
+            _ = resolve_model("out.png", target, Defaults())
+
+    def test_default_fallback_for_reference_images(self) -> None:
+        # Default flux-dev lacks reference_images, should fall back to a capable model
+        target = TargetConfig(prompt="x", reference_images=["r.png"])
+        defaults = Defaults(image_model="flux-dev")
+        result = resolve_model("out.png", target, defaults)
+        assert Capability.REFERENCE_IMAGES in result.capabilities
+
+    def test_default_fallback_for_vision(self) -> None:
+        # Default mistral-large-latest lacks vision, should fall back to a pixtral model
+        target = TargetConfig(prompt="x", inputs=["photo.png"])
+        defaults = Defaults(text_model="mistral-large-latest")
+        result = resolve_model("out.txt", target, defaults)
+        assert Capability.VISION in result.capabilities
+
+    def test_default_preferred_when_capable(self) -> None:
+        # Default flux-2-pro already supports reference_images, should be used directly
+        target = TargetConfig(prompt="x", reference_images=["r.png"])
+        defaults = Defaults(image_model="flux-2-pro")
+        result = resolve_model("out.png", target, defaults)
+        assert result.name == "flux-2-pro"
+
+    def test_unknown_default_model_raises(self) -> None:
+        # A default that is not in the registry must fail loudly, not silently fall back
+        target = TargetConfig(prompt="x")
+        defaults = Defaults(text_model="nonexistent-model")
+        with pytest.raises(ValueError, match="Unknown model"):
+            _ = resolve_model("out.txt", target, defaults)