refactor: pass ModelInfo instead of model name string through provider interface

This commit is contained in:
Konstantin Fickel 2026-02-15 08:12:04 +01:00
parent 8e3ed7010f
commit d15444bdb0
Signed by: kfickel
GPG key ID: A793722F9933C1A5
8 changed files with 83 additions and 43 deletions

View file

@ -121,7 +121,7 @@ async def _build_single_target(
"""Build a single target by dispatching to the appropriate provider."""
target_cfg = config.targets[target_name]
target_type = infer_target_type(target_name)
model = resolve_model(target_name, target_cfg, config.defaults)
model_info = resolve_model(target_name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
provider = providers[target_type]
@ -129,7 +129,7 @@ async def _build_single_target(
target_name=target_name,
target_config=target_cfg,
resolved_prompt=resolved_prompt,
resolved_model=model,
resolved_model=model_info,
project_dir=project_dir,
)
@ -218,7 +218,7 @@ def _is_dirty(
) -> bool:
"""Check if a target needs rebuilding."""
target_cfg = config.targets[target_name]
model = resolve_model(target_name, target_cfg, config.defaults)
model_info = resolve_model(target_name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
dep_files = _collect_dep_files(target_name, config, project_dir)
extra = _collect_extra_params(target_name, config)
@ -226,7 +226,7 @@ def _is_dirty(
return is_target_dirty(
target_name,
resolved_prompt=resolved_prompt,
model=model,
model=model_info.name,
dep_files=dep_files,
extra_params=extra,
state=state,
@ -286,7 +286,7 @@ def _process_outcomes(
on_progress(BuildEvent.TARGET_FAILED, name, str(error))
else:
target_cfg = config.targets[name]
model = resolve_model(name, target_cfg, config.defaults)
model_info = resolve_model(name, target_cfg, config.defaults)
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
dep_files = _collect_dep_files(name, config, project_dir)
extra = _collect_extra_params(name, config)
@ -294,7 +294,7 @@ def _process_outcomes(
record_target_state(
name,
resolved_prompt=resolved_prompt,
model=model,
model=model_info.name,
dep_files=dep_files,
extra_params=extra,
state=state,

View file

@ -4,11 +4,14 @@ from __future__ import annotations
import enum
from pathlib import Path
from typing import Self
from typing import TYPE_CHECKING, Self
import yaml
from pydantic import BaseModel, model_validator
if TYPE_CHECKING:
from bulkgen.providers.models import ModelInfo
class TargetType(enum.Enum):
"""The kind of artifact a target produces."""
@ -65,14 +68,29 @@ def infer_target_type(target_name: str) -> TargetType:
raise ValueError(msg)
def resolve_model(target_name: str, target: TargetConfig, defaults: Defaults) -> str:
"""Return the effective model for a target (explicit or default by type)."""
def resolve_model(
target_name: str, target: TargetConfig, defaults: Defaults
) -> ModelInfo:
"""Return the effective model for a target (explicit or default by type).
Raises :class:`ValueError` if the resolved model name is not in the registry.
"""
from bulkgen.providers.models import ALL_MODELS
if target.model is not None:
return target.model
model_name = target.model
else:
target_type = infer_target_type(target_name)
if target_type is TargetType.IMAGE:
return defaults.image_model
return defaults.text_model
model_name = (
defaults.image_model
if target_type is TargetType.IMAGE
else defaults.text_model
)
for model in ALL_MODELS:
if model.name == model_name:
return model
msg = f"Unknown model '{model_name}' for target '{target_name}'"
raise ValueError(msg)
def load_config(config_path: Path) -> ProjectConfig:

View file

@ -6,6 +6,7 @@ import abc
from pathlib import Path
from bulkgen.config import TargetConfig
from bulkgen.providers.models import ModelInfo
class Provider(abc.ABC):
@ -17,7 +18,7 @@ class Provider(abc.ABC):
target_name: str,
target_config: TargetConfig,
resolved_prompt: str,
resolved_model: str,
resolved_model: ModelInfo,
project_dir: Path,
) -> None:
"""Generate the target artifact and write it to *project_dir / target_name*.
@ -26,6 +27,6 @@ class Provider(abc.ABC):
target_name: Output filename (relative to project_dir).
target_config: The parsed target configuration.
resolved_prompt: The fully resolved prompt text.
resolved_model: The resolved model name.
resolved_model: The resolved model information.
project_dir: The project working directory.
"""

View file

@ -11,6 +11,7 @@ import httpx
from bulkgen.config import TargetConfig
from bulkgen.providers import Provider
from bulkgen.providers.bfl import BFLClient
from bulkgen.providers.models import ModelInfo
def _encode_image_b64(path: Path) -> str:
@ -58,7 +59,7 @@ class ImageProvider(Provider):
target_name: str,
target_config: TargetConfig,
resolved_prompt: str,
resolved_model: str,
resolved_model: ModelInfo,
project_dir: Path,
) -> None:
output_path = project_dir / target_name
@ -72,14 +73,14 @@ class ImageProvider(Provider):
if target_config.reference_images:
_add_reference_images(
inputs, target_config.reference_images, resolved_model, project_dir
inputs, target_config.reference_images, resolved_model.name, project_dir
)
for control_name in target_config.control_images:
ctrl_path = project_dir / control_name
inputs["control_image"] = _encode_image_b64(ctrl_path)
result = await self._client.generate(resolved_model, inputs)
result = await self._client.generate(resolved_model.name, inputs)
async with httpx.AsyncClient() as http:
response = await http.get(result.sample_url)

View file

@ -11,6 +11,7 @@ from mistralai import Mistral, models
from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig
from bulkgen.providers import Provider
from bulkgen.providers.models import ModelInfo
def _image_to_data_url(path: Path) -> str:
@ -34,7 +35,7 @@ class TextProvider(Provider):
target_name: str,
target_config: TargetConfig,
resolved_prompt: str,
resolved_model: str,
resolved_model: ModelInfo,
project_dir: Path,
) -> None:
output_path = project_dir / target_name
@ -57,7 +58,7 @@ class TextProvider(Provider):
async with Mistral(api_key=self._api_key) as client:
response = await client.chat.complete_async(
model=resolved_model,
model=resolved_model.name,
messages=[message],
)

View file

@ -18,6 +18,7 @@ from bulkgen.builder import (
)
from bulkgen.config import ProjectConfig, TargetConfig, TargetType
from bulkgen.providers import Provider
from bulkgen.providers.models import ModelInfo
from bulkgen.state import load_state
WriteConfig = Callable[[dict[str, object]], ProjectConfig]
@ -32,7 +33,7 @@ class FakeProvider(Provider):
target_name: str,
target_config: TargetConfig,
resolved_prompt: str,
resolved_model: str,
resolved_model: ModelInfo,
project_dir: Path,
) -> None:
output = project_dir / target_name
@ -48,7 +49,7 @@ class FailingProvider(Provider):
target_name: str,
target_config: TargetConfig,
resolved_prompt: str,
resolved_model: str,
resolved_model: ModelInfo,
project_dir: Path,
) -> None:
msg = f"Simulated failure for {target_name}"
@ -247,7 +248,7 @@ class TestRunBuild:
target_name: str,
target_config: TargetConfig,
resolved_prompt: str,
resolved_model: str,
resolved_model: ModelInfo,
project_dir: Path,
) -> None:
if target_name == "fail.txt":

View file

@ -118,15 +118,23 @@ class TestResolveModel:
"""Test model resolution (explicit vs. default)."""
def test_explicit_model_wins(self) -> None:
target = TargetConfig(prompt="x", model="my-model")
assert resolve_model("out.txt", target, Defaults()) == "my-model"
target = TargetConfig(prompt="x", model="mistral-small-latest")
result = resolve_model("out.txt", target, Defaults())
assert result.name == "mistral-small-latest"
def test_default_text_model(self) -> None:
target = TargetConfig(prompt="x")
defaults = Defaults(text_model="custom-text")
assert resolve_model("out.md", target, defaults) == "custom-text"
defaults = Defaults(text_model="mistral-large-latest")
result = resolve_model("out.md", target, defaults)
assert result.name == "mistral-large-latest"
def test_default_image_model(self) -> None:
target = TargetConfig(prompt="x")
defaults = Defaults(image_model="custom-image")
assert resolve_model("out.png", target, defaults) == "custom-image"
defaults = Defaults(image_model="flux-dev")
result = resolve_model("out.png", target, defaults)
assert result.name == "flux-dev"
def test_unknown_model_raises(self) -> None:
target = TargetConfig(prompt="x", model="nonexistent-model")
with pytest.raises(ValueError, match="Unknown model"):
_ = resolve_model("out.txt", target, Defaults())

View file

@ -18,9 +18,19 @@ from bulkgen.providers.image import ImageProvider
from bulkgen.providers.image import (
_encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage]
)
from bulkgen.providers.models import ALL_MODELS, ModelInfo
from bulkgen.providers.text import TextProvider
def _model(name: str) -> ModelInfo:
    """Return the registered ModelInfo whose name matches *name*.

    Raises ValueError if no registry entry carries that name, so a typo in
    a test fixture fails loudly instead of silently passing a wrong model.
    """
    match = next((entry for entry in ALL_MODELS if entry.name == name), None)
    if match is None:
        msg = f"Unknown test model: {name}"
        raise ValueError(msg)
    return match
def _make_bfl_mocks(
image_bytes: bytes,
) -> tuple[BFLResult, MagicMock]:
@ -84,7 +94,7 @@ class TestImageProvider:
target_name="out.png",
target_config=target_config,
resolved_prompt="A red square",
resolved_model="flux-pro-1.1",
resolved_model=_model("flux-pro-1.1"),
project_dir=project_dir,
)
@ -111,7 +121,7 @@ class TestImageProvider:
target_name="banner.png",
target_config=target_config,
resolved_prompt="A banner",
resolved_model="flux-pro-1.1",
resolved_model=_model("flux-pro-1.1"),
project_dir=project_dir,
)
@ -142,7 +152,7 @@ class TestImageProvider:
target_name="out.png",
target_config=target_config,
resolved_prompt="Like this",
resolved_model="flux-pro-1.1",
resolved_model=_model("flux-pro-1.1"),
project_dir=project_dir,
)
@ -177,7 +187,7 @@ class TestImageProvider:
target_name="out.png",
target_config=target_config,
resolved_prompt="Combine",
resolved_model="flux-2-pro",
resolved_model=_model("flux-2-pro"),
project_dir=project_dir,
)
@ -208,7 +218,7 @@ class TestImageProvider:
target_name="out.png",
target_config=target_config,
resolved_prompt="Edit",
resolved_model="flux-kontext-pro",
resolved_model=_model("flux-kontext-pro"),
project_dir=project_dir,
)
@ -232,7 +242,7 @@ class TestImageProvider:
target_name="fail.png",
target_config=target_config,
resolved_prompt="x",
resolved_model="flux-pro-1.1",
resolved_model=_model("flux-pro-1.1"),
project_dir=project_dir,
)
@ -260,7 +270,7 @@ class TestTextProvider:
target_name="poem.txt",
target_config=target_config,
resolved_prompt="Write a poem",
resolved_model="mistral-large-latest",
resolved_model=_model("mistral-large-latest"),
project_dir=project_dir,
)
@ -282,7 +292,7 @@ class TestTextProvider:
target_name="summary.md",
target_config=target_config,
resolved_prompt="Summarize",
resolved_model="mistral-large-latest",
resolved_model=_model("mistral-large-latest"),
project_dir=project_dir,
)
@ -306,7 +316,7 @@ class TestTextProvider:
target_name="desc.txt",
target_config=target_config,
resolved_prompt="Describe this image",
resolved_model="mistral-large-latest",
resolved_model=_model("mistral-large-latest"),
project_dir=project_dir,
)
@ -331,7 +341,7 @@ class TestTextProvider:
target_name="fail.txt",
target_config=target_config,
resolved_prompt="x",
resolved_model="mistral-large-latest",
resolved_model=_model("mistral-large-latest"),
project_dir=project_dir,
)
@ -348,7 +358,7 @@ class TestTextProvider:
target_name="fail.txt",
target_config=target_config,
resolved_prompt="x",
resolved_model="mistral-large-latest",
resolved_model=_model("mistral-large-latest"),
project_dir=project_dir,
)
@ -371,7 +381,7 @@ class TestTextProvider:
target_name="out.md",
target_config=target_config,
resolved_prompt="Combine all",
resolved_model="mistral-large-latest",
resolved_model=_model("mistral-large-latest"),
project_dir=project_dir,
)
@ -402,7 +412,7 @@ class TestTextProvider:
target_name="desc.txt",
target_config=target_config,
resolved_prompt="Describe the style",
resolved_model="mistral-large-latest",
resolved_model=_model("mistral-large-latest"),
project_dir=project_dir,
)