refactor: pass ModelInfo instead of model name string through provider interface

This commit is contained in:
Konstantin Fickel 2026-02-15 08:12:04 +01:00
parent 8e3ed7010f
commit d15444bdb0
Signed by: kfickel
GPG key ID: A793722F9933C1A5
8 changed files with 83 additions and 43 deletions

View file

@ -121,7 +121,7 @@ async def _build_single_target(
"""Build a single target by dispatching to the appropriate provider.""" """Build a single target by dispatching to the appropriate provider."""
target_cfg = config.targets[target_name] target_cfg = config.targets[target_name]
target_type = infer_target_type(target_name) target_type = infer_target_type(target_name)
model = resolve_model(target_name, target_cfg, config.defaults) model_info = resolve_model(target_name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
provider = providers[target_type] provider = providers[target_type]
@ -129,7 +129,7 @@ async def _build_single_target(
target_name=target_name, target_name=target_name,
target_config=target_cfg, target_config=target_cfg,
resolved_prompt=resolved_prompt, resolved_prompt=resolved_prompt,
resolved_model=model, resolved_model=model_info,
project_dir=project_dir, project_dir=project_dir,
) )
@ -218,7 +218,7 @@ def _is_dirty(
) -> bool: ) -> bool:
"""Check if a target needs rebuilding.""" """Check if a target needs rebuilding."""
target_cfg = config.targets[target_name] target_cfg = config.targets[target_name]
model = resolve_model(target_name, target_cfg, config.defaults) model_info = resolve_model(target_name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
dep_files = _collect_dep_files(target_name, config, project_dir) dep_files = _collect_dep_files(target_name, config, project_dir)
extra = _collect_extra_params(target_name, config) extra = _collect_extra_params(target_name, config)
@ -226,7 +226,7 @@ def _is_dirty(
return is_target_dirty( return is_target_dirty(
target_name, target_name,
resolved_prompt=resolved_prompt, resolved_prompt=resolved_prompt,
model=model, model=model_info.name,
dep_files=dep_files, dep_files=dep_files,
extra_params=extra, extra_params=extra,
state=state, state=state,
@ -286,7 +286,7 @@ def _process_outcomes(
on_progress(BuildEvent.TARGET_FAILED, name, str(error)) on_progress(BuildEvent.TARGET_FAILED, name, str(error))
else: else:
target_cfg = config.targets[name] target_cfg = config.targets[name]
model = resolve_model(name, target_cfg, config.defaults) model_info = resolve_model(name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir) resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
dep_files = _collect_dep_files(name, config, project_dir) dep_files = _collect_dep_files(name, config, project_dir)
extra = _collect_extra_params(name, config) extra = _collect_extra_params(name, config)
@ -294,7 +294,7 @@ def _process_outcomes(
record_target_state( record_target_state(
name, name,
resolved_prompt=resolved_prompt, resolved_prompt=resolved_prompt,
model=model, model=model_info.name,
dep_files=dep_files, dep_files=dep_files,
extra_params=extra, extra_params=extra,
state=state, state=state,

View file

@ -4,11 +4,14 @@ from __future__ import annotations
import enum import enum
from pathlib import Path from pathlib import Path
from typing import Self from typing import TYPE_CHECKING, Self
import yaml import yaml
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
if TYPE_CHECKING:
from bulkgen.providers.models import ModelInfo
class TargetType(enum.Enum): class TargetType(enum.Enum):
"""The kind of artifact a target produces.""" """The kind of artifact a target produces."""
@ -65,14 +68,29 @@ def infer_target_type(target_name: str) -> TargetType:
raise ValueError(msg) raise ValueError(msg)
def resolve_model(target_name: str, target: TargetConfig, defaults: Defaults) -> str: def resolve_model(
"""Return the effective model for a target (explicit or default by type).""" target_name: str, target: TargetConfig, defaults: Defaults
) -> ModelInfo:
"""Return the effective model for a target (explicit or default by type).
Raises :class:`ValueError` if the resolved model name is not in the registry.
"""
from bulkgen.providers.models import ALL_MODELS
if target.model is not None: if target.model is not None:
return target.model model_name = target.model
else:
target_type = infer_target_type(target_name) target_type = infer_target_type(target_name)
if target_type is TargetType.IMAGE: model_name = (
return defaults.image_model defaults.image_model
return defaults.text_model if target_type is TargetType.IMAGE
else defaults.text_model
)
for model in ALL_MODELS:
if model.name == model_name:
return model
msg = f"Unknown model '{model_name}' for target '{target_name}'"
raise ValueError(msg)
def load_config(config_path: Path) -> ProjectConfig: def load_config(config_path: Path) -> ProjectConfig:

View file

@ -6,6 +6,7 @@ import abc
from pathlib import Path from pathlib import Path
from bulkgen.config import TargetConfig from bulkgen.config import TargetConfig
from bulkgen.providers.models import ModelInfo
class Provider(abc.ABC): class Provider(abc.ABC):
@ -17,7 +18,7 @@ class Provider(abc.ABC):
target_name: str, target_name: str,
target_config: TargetConfig, target_config: TargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: str, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,
) -> None: ) -> None:
"""Generate the target artifact and write it to *project_dir / target_name*. """Generate the target artifact and write it to *project_dir / target_name*.
@ -26,6 +27,6 @@ class Provider(abc.ABC):
target_name: Output filename (relative to project_dir). target_name: Output filename (relative to project_dir).
target_config: The parsed target configuration. target_config: The parsed target configuration.
resolved_prompt: The fully resolved prompt text. resolved_prompt: The fully resolved prompt text.
resolved_model: The resolved model name. resolved_model: The resolved model information.
project_dir: The project working directory. project_dir: The project working directory.
""" """

View file

@ -11,6 +11,7 @@ import httpx
from bulkgen.config import TargetConfig from bulkgen.config import TargetConfig
from bulkgen.providers import Provider from bulkgen.providers import Provider
from bulkgen.providers.bfl import BFLClient from bulkgen.providers.bfl import BFLClient
from bulkgen.providers.models import ModelInfo
def _encode_image_b64(path: Path) -> str: def _encode_image_b64(path: Path) -> str:
@ -58,7 +59,7 @@ class ImageProvider(Provider):
target_name: str, target_name: str,
target_config: TargetConfig, target_config: TargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: str, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,
) -> None: ) -> None:
output_path = project_dir / target_name output_path = project_dir / target_name
@ -72,14 +73,14 @@ class ImageProvider(Provider):
if target_config.reference_images: if target_config.reference_images:
_add_reference_images( _add_reference_images(
inputs, target_config.reference_images, resolved_model, project_dir inputs, target_config.reference_images, resolved_model.name, project_dir
) )
for control_name in target_config.control_images: for control_name in target_config.control_images:
ctrl_path = project_dir / control_name ctrl_path = project_dir / control_name
inputs["control_image"] = _encode_image_b64(ctrl_path) inputs["control_image"] = _encode_image_b64(ctrl_path)
result = await self._client.generate(resolved_model, inputs) result = await self._client.generate(resolved_model.name, inputs)
async with httpx.AsyncClient() as http: async with httpx.AsyncClient() as http:
response = await http.get(result.sample_url) response = await http.get(result.sample_url)

View file

@ -11,6 +11,7 @@ from mistralai import Mistral, models
from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig
from bulkgen.providers import Provider from bulkgen.providers import Provider
from bulkgen.providers.models import ModelInfo
def _image_to_data_url(path: Path) -> str: def _image_to_data_url(path: Path) -> str:
@ -34,7 +35,7 @@ class TextProvider(Provider):
target_name: str, target_name: str,
target_config: TargetConfig, target_config: TargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: str, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,
) -> None: ) -> None:
output_path = project_dir / target_name output_path = project_dir / target_name
@ -57,7 +58,7 @@ class TextProvider(Provider):
async with Mistral(api_key=self._api_key) as client: async with Mistral(api_key=self._api_key) as client:
response = await client.chat.complete_async( response = await client.chat.complete_async(
model=resolved_model, model=resolved_model.name,
messages=[message], messages=[message],
) )

View file

@ -18,6 +18,7 @@ from bulkgen.builder import (
) )
from bulkgen.config import ProjectConfig, TargetConfig, TargetType from bulkgen.config import ProjectConfig, TargetConfig, TargetType
from bulkgen.providers import Provider from bulkgen.providers import Provider
from bulkgen.providers.models import ModelInfo
from bulkgen.state import load_state from bulkgen.state import load_state
WriteConfig = Callable[[dict[str, object]], ProjectConfig] WriteConfig = Callable[[dict[str, object]], ProjectConfig]
@ -32,7 +33,7 @@ class FakeProvider(Provider):
target_name: str, target_name: str,
target_config: TargetConfig, target_config: TargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: str, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,
) -> None: ) -> None:
output = project_dir / target_name output = project_dir / target_name
@ -48,7 +49,7 @@ class FailingProvider(Provider):
target_name: str, target_name: str,
target_config: TargetConfig, target_config: TargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: str, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,
) -> None: ) -> None:
msg = f"Simulated failure for {target_name}" msg = f"Simulated failure for {target_name}"
@ -247,7 +248,7 @@ class TestRunBuild:
target_name: str, target_name: str,
target_config: TargetConfig, target_config: TargetConfig,
resolved_prompt: str, resolved_prompt: str,
resolved_model: str, resolved_model: ModelInfo,
project_dir: Path, project_dir: Path,
) -> None: ) -> None:
if target_name == "fail.txt": if target_name == "fail.txt":

View file

@ -118,15 +118,23 @@ class TestResolveModel:
"""Test model resolution (explicit vs. default).""" """Test model resolution (explicit vs. default)."""
def test_explicit_model_wins(self) -> None: def test_explicit_model_wins(self) -> None:
target = TargetConfig(prompt="x", model="my-model") target = TargetConfig(prompt="x", model="mistral-small-latest")
assert resolve_model("out.txt", target, Defaults()) == "my-model" result = resolve_model("out.txt", target, Defaults())
assert result.name == "mistral-small-latest"
def test_default_text_model(self) -> None: def test_default_text_model(self) -> None:
target = TargetConfig(prompt="x") target = TargetConfig(prompt="x")
defaults = Defaults(text_model="custom-text") defaults = Defaults(text_model="mistral-large-latest")
assert resolve_model("out.md", target, defaults) == "custom-text" result = resolve_model("out.md", target, defaults)
assert result.name == "mistral-large-latest"
def test_default_image_model(self) -> None: def test_default_image_model(self) -> None:
target = TargetConfig(prompt="x") target = TargetConfig(prompt="x")
defaults = Defaults(image_model="custom-image") defaults = Defaults(image_model="flux-dev")
assert resolve_model("out.png", target, defaults) == "custom-image" result = resolve_model("out.png", target, defaults)
assert result.name == "flux-dev"
def test_unknown_model_raises(self) -> None:
target = TargetConfig(prompt="x", model="nonexistent-model")
with pytest.raises(ValueError, match="Unknown model"):
_ = resolve_model("out.txt", target, Defaults())

View file

@ -18,9 +18,19 @@ from bulkgen.providers.image import ImageProvider
from bulkgen.providers.image import ( from bulkgen.providers.image import (
_encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage] _encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage]
) )
from bulkgen.providers.models import ALL_MODELS, ModelInfo
from bulkgen.providers.text import TextProvider from bulkgen.providers.text import TextProvider
def _model(name: str) -> ModelInfo:
"""Look up a ModelInfo by name."""
for m in ALL_MODELS:
if m.name == name:
return m
msg = f"Unknown test model: {name}"
raise ValueError(msg)
def _make_bfl_mocks( def _make_bfl_mocks(
image_bytes: bytes, image_bytes: bytes,
) -> tuple[BFLResult, MagicMock]: ) -> tuple[BFLResult, MagicMock]:
@ -84,7 +94,7 @@ class TestImageProvider:
target_name="out.png", target_name="out.png",
target_config=target_config, target_config=target_config,
resolved_prompt="A red square", resolved_prompt="A red square",
resolved_model="flux-pro-1.1", resolved_model=_model("flux-pro-1.1"),
project_dir=project_dir, project_dir=project_dir,
) )
@ -111,7 +121,7 @@ class TestImageProvider:
target_name="banner.png", target_name="banner.png",
target_config=target_config, target_config=target_config,
resolved_prompt="A banner", resolved_prompt="A banner",
resolved_model="flux-pro-1.1", resolved_model=_model("flux-pro-1.1"),
project_dir=project_dir, project_dir=project_dir,
) )
@ -142,7 +152,7 @@ class TestImageProvider:
target_name="out.png", target_name="out.png",
target_config=target_config, target_config=target_config,
resolved_prompt="Like this", resolved_prompt="Like this",
resolved_model="flux-pro-1.1", resolved_model=_model("flux-pro-1.1"),
project_dir=project_dir, project_dir=project_dir,
) )
@ -177,7 +187,7 @@ class TestImageProvider:
target_name="out.png", target_name="out.png",
target_config=target_config, target_config=target_config,
resolved_prompt="Combine", resolved_prompt="Combine",
resolved_model="flux-2-pro", resolved_model=_model("flux-2-pro"),
project_dir=project_dir, project_dir=project_dir,
) )
@ -208,7 +218,7 @@ class TestImageProvider:
target_name="out.png", target_name="out.png",
target_config=target_config, target_config=target_config,
resolved_prompt="Edit", resolved_prompt="Edit",
resolved_model="flux-kontext-pro", resolved_model=_model("flux-kontext-pro"),
project_dir=project_dir, project_dir=project_dir,
) )
@ -232,7 +242,7 @@ class TestImageProvider:
target_name="fail.png", target_name="fail.png",
target_config=target_config, target_config=target_config,
resolved_prompt="x", resolved_prompt="x",
resolved_model="flux-pro-1.1", resolved_model=_model("flux-pro-1.1"),
project_dir=project_dir, project_dir=project_dir,
) )
@ -260,7 +270,7 @@ class TestTextProvider:
target_name="poem.txt", target_name="poem.txt",
target_config=target_config, target_config=target_config,
resolved_prompt="Write a poem", resolved_prompt="Write a poem",
resolved_model="mistral-large-latest", resolved_model=_model("mistral-large-latest"),
project_dir=project_dir, project_dir=project_dir,
) )
@ -282,7 +292,7 @@ class TestTextProvider:
target_name="summary.md", target_name="summary.md",
target_config=target_config, target_config=target_config,
resolved_prompt="Summarize", resolved_prompt="Summarize",
resolved_model="mistral-large-latest", resolved_model=_model("mistral-large-latest"),
project_dir=project_dir, project_dir=project_dir,
) )
@ -306,7 +316,7 @@ class TestTextProvider:
target_name="desc.txt", target_name="desc.txt",
target_config=target_config, target_config=target_config,
resolved_prompt="Describe this image", resolved_prompt="Describe this image",
resolved_model="mistral-large-latest", resolved_model=_model("mistral-large-latest"),
project_dir=project_dir, project_dir=project_dir,
) )
@ -331,7 +341,7 @@ class TestTextProvider:
target_name="fail.txt", target_name="fail.txt",
target_config=target_config, target_config=target_config,
resolved_prompt="x", resolved_prompt="x",
resolved_model="mistral-large-latest", resolved_model=_model("mistral-large-latest"),
project_dir=project_dir, project_dir=project_dir,
) )
@ -348,7 +358,7 @@ class TestTextProvider:
target_name="fail.txt", target_name="fail.txt",
target_config=target_config, target_config=target_config,
resolved_prompt="x", resolved_prompt="x",
resolved_model="mistral-large-latest", resolved_model=_model("mistral-large-latest"),
project_dir=project_dir, project_dir=project_dir,
) )
@ -371,7 +381,7 @@ class TestTextProvider:
target_name="out.md", target_name="out.md",
target_config=target_config, target_config=target_config,
resolved_prompt="Combine all", resolved_prompt="Combine all",
resolved_model="mistral-large-latest", resolved_model=_model("mistral-large-latest"),
project_dir=project_dir, project_dir=project_dir,
) )
@ -402,7 +412,7 @@ class TestTextProvider:
target_name="desc.txt", target_name="desc.txt",
target_config=target_config, target_config=target_config,
resolved_prompt="Describe the style", resolved_prompt="Describe the style",
resolved_model="mistral-large-latest", resolved_model=_model("mistral-large-latest"),
project_dir=project_dir, project_dir=project_dir,
) )