diff --git a/bulkgen/builder.py b/bulkgen/builder.py index 5fcc981..9320bfa 100644 --- a/bulkgen/builder.py +++ b/bulkgen/builder.py @@ -121,7 +121,7 @@ async def _build_single_target( """Build a single target by dispatching to the appropriate provider.""" target_cfg = config.targets[target_name] target_type = infer_target_type(target_name) - model = resolve_model(target_name, target_cfg, config.defaults) + model_info = resolve_model(target_name, target_cfg, config.defaults) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) provider = providers[target_type] @@ -129,7 +129,7 @@ async def _build_single_target( target_name=target_name, target_config=target_cfg, resolved_prompt=resolved_prompt, - resolved_model=model, + resolved_model=model_info, project_dir=project_dir, ) @@ -218,7 +218,7 @@ def _is_dirty( ) -> bool: """Check if a target needs rebuilding.""" target_cfg = config.targets[target_name] - model = resolve_model(target_name, target_cfg, config.defaults) + model_info = resolve_model(target_name, target_cfg, config.defaults) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) dep_files = _collect_dep_files(target_name, config, project_dir) extra = _collect_extra_params(target_name, config) @@ -226,7 +226,7 @@ def _is_dirty( return is_target_dirty( target_name, resolved_prompt=resolved_prompt, - model=model, + model=model_info.name, dep_files=dep_files, extra_params=extra, state=state, @@ -286,7 +286,7 @@ def _process_outcomes( on_progress(BuildEvent.TARGET_FAILED, name, str(error)) else: target_cfg = config.targets[name] - model = resolve_model(name, target_cfg, config.defaults) + model_info = resolve_model(name, target_cfg, config.defaults) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) dep_files = _collect_dep_files(name, config, project_dir) extra = _collect_extra_params(name, config) @@ -294,7 +294,7 @@ def _process_outcomes( record_target_state( name, resolved_prompt=resolved_prompt, - model=model, + model=model_info.name, dep_files=dep_files, extra_params=extra, state=state, diff --git a/bulkgen/config.py b/bulkgen/config.py index 136e43a..48d9a83 100644 --- a/bulkgen/config.py +++ b/bulkgen/config.py @@ -4,11 +4,14 @@ from __future__ import annotations import enum from pathlib import Path -from typing import Self +from typing import TYPE_CHECKING, Self import yaml from pydantic import BaseModel, model_validator +if TYPE_CHECKING: + from bulkgen.providers.models import ModelInfo + class TargetType(enum.Enum): """The kind of artifact a target produces.""" @@ -65,14 +68,29 @@ def infer_target_type(target_name: str) -> TargetType: raise ValueError(msg) -def resolve_model(target_name: str, target: TargetConfig, defaults: Defaults) -> str: - """Return the effective model for a target (explicit or default by type).""" +def resolve_model( + target_name: str, target: TargetConfig, defaults: Defaults +) -> ModelInfo: + """Return the effective model for a target (explicit or default by type). + + Raises :class:`ValueError` if the resolved model name is not in the registry. + """ + from bulkgen.providers.models import ALL_MODELS + if target.model is not None: - return target.model - target_type = infer_target_type(target_name) - if target_type is TargetType.IMAGE: - return defaults.image_model - return defaults.text_model + model_name = target.model + else: + target_type = infer_target_type(target_name) + model_name = ( + defaults.image_model + if target_type is TargetType.IMAGE + else defaults.text_model + ) + for model in ALL_MODELS: + if model.name == model_name: + return model + msg = f"Unknown model '{model_name}' for target '{target_name}'" + raise ValueError(msg) def load_config(config_path: Path) -> ProjectConfig: diff --git a/bulkgen/providers/__init__.py b/bulkgen/providers/__init__.py index 047aed1..0495b11 100644 --- a/bulkgen/providers/__init__.py +++ b/bulkgen/providers/__init__.py @@ -6,6 +6,7 @@ import abc from pathlib import Path from bulkgen.config import TargetConfig +from bulkgen.providers.models import ModelInfo class Provider(abc.ABC): @@ -17,7 +18,7 @@ class Provider(abc.ABC): target_name: str, target_config: TargetConfig, resolved_prompt: str, - resolved_model: str, + resolved_model: ModelInfo, project_dir: Path, ) -> None: """Generate the target artifact and write it to *project_dir / target_name*. @@ -26,6 +27,6 @@ class Provider(abc.ABC): target_name: Output filename (relative to project_dir). target_config: The parsed target configuration. resolved_prompt: The fully resolved prompt text. - resolved_model: The resolved model name. + resolved_model: The resolved model information. project_dir: The project working directory. """ diff --git a/bulkgen/providers/image.py b/bulkgen/providers/image.py index 9231e7c..de92e85 100644 --- a/bulkgen/providers/image.py +++ b/bulkgen/providers/image.py @@ -11,6 +11,7 @@ import httpx from bulkgen.config import TargetConfig from bulkgen.providers import Provider from bulkgen.providers.bfl import BFLClient +from bulkgen.providers.models import ModelInfo def _encode_image_b64(path: Path) -> str: @@ -58,7 +59,7 @@ class ImageProvider(Provider): target_name: str, target_config: TargetConfig, resolved_prompt: str, - resolved_model: str, + resolved_model: ModelInfo, project_dir: Path, ) -> None: output_path = project_dir / target_name @@ -72,14 +73,14 @@ class ImageProvider(Provider): if target_config.reference_images: _add_reference_images( - inputs, target_config.reference_images, resolved_model, project_dir + inputs, target_config.reference_images, resolved_model.name, project_dir ) for control_name in target_config.control_images: ctrl_path = project_dir / control_name inputs["control_image"] = _encode_image_b64(ctrl_path) - result = await self._client.generate(resolved_model, inputs) + result = await self._client.generate(resolved_model.name, inputs) async with httpx.AsyncClient() as http: response = await http.get(result.sample_url) diff --git a/bulkgen/providers/text.py b/bulkgen/providers/text.py index 8f29e5f..564e2ea 100644 --- a/bulkgen/providers/text.py +++ b/bulkgen/providers/text.py @@ -11,6 +11,7 @@ from mistralai import Mistral, models from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig from bulkgen.providers import Provider +from bulkgen.providers.models import ModelInfo def _image_to_data_url(path: Path) -> str: @@ -34,7 +35,7 @@ class TextProvider(Provider): target_name: str, target_config: TargetConfig, resolved_prompt: str, - resolved_model: str, + resolved_model: ModelInfo, project_dir: Path, ) -> None: output_path = project_dir / target_name @@ -57,7 +58,7 @@ class TextProvider(Provider): async with Mistral(api_key=self._api_key) as client: response = await client.chat.complete_async( - model=resolved_model, + model=resolved_model.name, messages=[message], ) diff --git a/tests/test_builder.py b/tests/test_builder.py index c8a8125..ae96868 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -18,6 +18,7 @@ from bulkgen.builder import ( ) from bulkgen.config import ProjectConfig, TargetConfig, TargetType from bulkgen.providers import Provider +from bulkgen.providers.models import ModelInfo from bulkgen.state import load_state WriteConfig = Callable[[dict[str, object]], ProjectConfig] @@ -32,7 +33,7 @@ class FakeProvider(Provider): target_name: str, target_config: TargetConfig, resolved_prompt: str, - resolved_model: str, + resolved_model: ModelInfo, project_dir: Path, ) -> None: output = project_dir / target_name @@ -48,7 +49,7 @@ class FailingProvider(Provider): target_name: str, target_config: TargetConfig, resolved_prompt: str, - resolved_model: str, + resolved_model: ModelInfo, project_dir: Path, ) -> None: msg = f"Simulated failure for {target_name}" @@ -247,7 +248,7 @@ class TestRunBuild: target_name: str, target_config: TargetConfig, resolved_prompt: str, - resolved_model: str, + resolved_model: ModelInfo, project_dir: Path, ) -> None: if target_name == "fail.txt": diff --git a/tests/test_config.py b/tests/test_config.py index b69419a..17fdf7d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -118,15 +118,23 @@ class TestResolveModel: """Test model resolution (explicit vs. default).""" def test_explicit_model_wins(self) -> None: - target = TargetConfig(prompt="x", model="my-model") - assert resolve_model("out.txt", target, Defaults()) == "my-model" + target = TargetConfig(prompt="x", model="mistral-small-latest") + result = resolve_model("out.txt", target, Defaults()) + assert result.name == "mistral-small-latest" def test_default_text_model(self) -> None: target = TargetConfig(prompt="x") - defaults = Defaults(text_model="custom-text") - assert resolve_model("out.md", target, defaults) == "custom-text" + defaults = Defaults(text_model="mistral-large-latest") + result = resolve_model("out.md", target, defaults) + assert result.name == "mistral-large-latest" def test_default_image_model(self) -> None: target = TargetConfig(prompt="x") - defaults = Defaults(image_model="custom-image") - assert resolve_model("out.png", target, defaults) == "custom-image" + defaults = Defaults(image_model="flux-dev") + result = resolve_model("out.png", target, defaults) + assert result.name == "flux-dev" + + def test_unknown_model_raises(self) -> None: + target = TargetConfig(prompt="x", model="nonexistent-model") + with pytest.raises(ValueError, match="Unknown model"): + _ = resolve_model("out.txt", target, Defaults()) diff --git a/tests/test_providers.py b/tests/test_providers.py index 38a3ca0..4f00e72 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -18,9 +18,19 @@ from bulkgen.providers.image import ImageProvider from bulkgen.providers.image import ( _encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage] ) +from bulkgen.providers.models import ALL_MODELS, ModelInfo from bulkgen.providers.text import TextProvider +def _model(name: str) -> ModelInfo: + """Look up a ModelInfo by name.""" + for m in ALL_MODELS: + if m.name == name: + return m + msg = f"Unknown test model: {name}" + raise ValueError(msg) + + def _make_bfl_mocks( image_bytes: bytes, ) -> tuple[BFLResult, MagicMock]: @@ -84,7 +94,7 @@ class TestImageProvider: target_name="out.png", target_config=target_config, resolved_prompt="A red square", - resolved_model="flux-pro-1.1", + resolved_model=_model("flux-pro-1.1"), project_dir=project_dir, ) @@ -111,7 +121,7 @@ class TestImageProvider: target_name="banner.png", target_config=target_config, resolved_prompt="A banner", - resolved_model="flux-pro-1.1", + resolved_model=_model("flux-pro-1.1"), project_dir=project_dir, ) @@ -142,7 +152,7 @@ class TestImageProvider: target_name="out.png", target_config=target_config, resolved_prompt="Like this", - resolved_model="flux-pro-1.1", + resolved_model=_model("flux-pro-1.1"), project_dir=project_dir, ) @@ -177,7 +187,7 @@ class TestImageProvider: target_name="out.png", target_config=target_config, resolved_prompt="Combine", - resolved_model="flux-2-pro", + resolved_model=_model("flux-2-pro"), project_dir=project_dir, ) @@ -208,7 +218,7 @@ class TestImageProvider: target_name="out.png", target_config=target_config, resolved_prompt="Edit", - resolved_model="flux-kontext-pro", + resolved_model=_model("flux-kontext-pro"), project_dir=project_dir, ) @@ -232,7 +242,7 @@ class TestImageProvider: target_name="fail.png", target_config=target_config, resolved_prompt="x", - resolved_model="flux-pro-1.1", + resolved_model=_model("flux-pro-1.1"), project_dir=project_dir, ) @@ -260,7 +270,7 @@ class TestTextProvider: target_name="poem.txt", target_config=target_config, resolved_prompt="Write a poem", - resolved_model="mistral-large-latest", + resolved_model=_model("mistral-large-latest"), project_dir=project_dir, ) @@ -282,7 +292,7 @@ class TestTextProvider: target_name="summary.md", target_config=target_config, resolved_prompt="Summarize", - resolved_model="mistral-large-latest", + resolved_model=_model("mistral-large-latest"), project_dir=project_dir, ) @@ -306,7 +316,7 @@ class TestTextProvider: target_name="desc.txt", target_config=target_config, resolved_prompt="Describe this image", - resolved_model="mistral-large-latest", + resolved_model=_model("mistral-large-latest"), project_dir=project_dir, ) @@ -331,7 +341,7 @@ class TestTextProvider: target_name="fail.txt", target_config=target_config, resolved_prompt="x", - resolved_model="mistral-large-latest", + resolved_model=_model("mistral-large-latest"), project_dir=project_dir, ) @@ -348,7 +358,7 @@ class TestTextProvider: target_name="fail.txt", target_config=target_config, resolved_prompt="x", - resolved_model="mistral-large-latest", + resolved_model=_model("mistral-large-latest"), project_dir=project_dir, ) @@ -371,7 +381,7 @@ class TestTextProvider: target_name="out.md", target_config=target_config, resolved_prompt="Combine all", - resolved_model="mistral-large-latest", + resolved_model=_model("mistral-large-latest"), project_dir=project_dir, ) @@ -402,7 +412,7 @@ class TestTextProvider: target_name="desc.txt", target_config=target_config, resolved_prompt="Describe the style", - resolved_model="mistral-large-latest", + resolved_model=_model("mistral-large-latest"), project_dir=project_dir, )