refactor: pass ModelInfo instead of model name string through provider interface
This commit is contained in:
parent
8e3ed7010f
commit
d15444bdb0
8 changed files with 83 additions and 43 deletions
|
|
@ -18,6 +18,7 @@ from bulkgen.builder import (
|
|||
)
|
||||
from bulkgen.config import ProjectConfig, TargetConfig, TargetType
|
||||
from bulkgen.providers import Provider
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
from bulkgen.state import load_state
|
||||
|
||||
WriteConfig = Callable[[dict[str, object]], ProjectConfig]
|
||||
|
|
@ -32,7 +33,7 @@ class FakeProvider(Provider):
|
|||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
) -> None:
|
||||
output = project_dir / target_name
|
||||
|
|
@ -48,7 +49,7 @@ class FailingProvider(Provider):
|
|||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
) -> None:
|
||||
msg = f"Simulated failure for {target_name}"
|
||||
|
|
@ -247,7 +248,7 @@ class TestRunBuild:
|
|||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
) -> None:
|
||||
if target_name == "fail.txt":
|
||||
|
|
|
|||
|
|
@ -118,15 +118,23 @@ class TestResolveModel:
|
|||
"""Test model resolution (explicit vs. default)."""
|
||||
|
||||
def test_explicit_model_wins(self) -> None:
|
||||
target = TargetConfig(prompt="x", model="my-model")
|
||||
assert resolve_model("out.txt", target, Defaults()) == "my-model"
|
||||
target = TargetConfig(prompt="x", model="mistral-small-latest")
|
||||
result = resolve_model("out.txt", target, Defaults())
|
||||
assert result.name == "mistral-small-latest"
|
||||
|
||||
def test_default_text_model(self) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
defaults = Defaults(text_model="custom-text")
|
||||
assert resolve_model("out.md", target, defaults) == "custom-text"
|
||||
defaults = Defaults(text_model="mistral-large-latest")
|
||||
result = resolve_model("out.md", target, defaults)
|
||||
assert result.name == "mistral-large-latest"
|
||||
|
||||
def test_default_image_model(self) -> None:
|
||||
target = TargetConfig(prompt="x")
|
||||
defaults = Defaults(image_model="custom-image")
|
||||
assert resolve_model("out.png", target, defaults) == "custom-image"
|
||||
defaults = Defaults(image_model="flux-dev")
|
||||
result = resolve_model("out.png", target, defaults)
|
||||
assert result.name == "flux-dev"
|
||||
|
||||
def test_unknown_model_raises(self) -> None:
|
||||
target = TargetConfig(prompt="x", model="nonexistent-model")
|
||||
with pytest.raises(ValueError, match="Unknown model"):
|
||||
_ = resolve_model("out.txt", target, Defaults())
|
||||
|
|
|
|||
|
|
@ -18,9 +18,19 @@ from bulkgen.providers.image import ImageProvider
|
|||
from bulkgen.providers.image import (
|
||||
_encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
from bulkgen.providers.models import ALL_MODELS, ModelInfo
|
||||
from bulkgen.providers.text import TextProvider
|
||||
|
||||
|
||||
def _model(name: str) -> ModelInfo:
|
||||
"""Look up a ModelInfo by name."""
|
||||
for m in ALL_MODELS:
|
||||
if m.name == name:
|
||||
return m
|
||||
msg = f"Unknown test model: {name}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def _make_bfl_mocks(
|
||||
image_bytes: bytes,
|
||||
) -> tuple[BFLResult, MagicMock]:
|
||||
|
|
@ -84,7 +94,7 @@ class TestImageProvider:
|
|||
target_name="out.png",
|
||||
target_config=target_config,
|
||||
resolved_prompt="A red square",
|
||||
resolved_model="flux-pro-1.1",
|
||||
resolved_model=_model("flux-pro-1.1"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -111,7 +121,7 @@ class TestImageProvider:
|
|||
target_name="banner.png",
|
||||
target_config=target_config,
|
||||
resolved_prompt="A banner",
|
||||
resolved_model="flux-pro-1.1",
|
||||
resolved_model=_model("flux-pro-1.1"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -142,7 +152,7 @@ class TestImageProvider:
|
|||
target_name="out.png",
|
||||
target_config=target_config,
|
||||
resolved_prompt="Like this",
|
||||
resolved_model="flux-pro-1.1",
|
||||
resolved_model=_model("flux-pro-1.1"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -177,7 +187,7 @@ class TestImageProvider:
|
|||
target_name="out.png",
|
||||
target_config=target_config,
|
||||
resolved_prompt="Combine",
|
||||
resolved_model="flux-2-pro",
|
||||
resolved_model=_model("flux-2-pro"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -208,7 +218,7 @@ class TestImageProvider:
|
|||
target_name="out.png",
|
||||
target_config=target_config,
|
||||
resolved_prompt="Edit",
|
||||
resolved_model="flux-kontext-pro",
|
||||
resolved_model=_model("flux-kontext-pro"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -232,7 +242,7 @@ class TestImageProvider:
|
|||
target_name="fail.png",
|
||||
target_config=target_config,
|
||||
resolved_prompt="x",
|
||||
resolved_model="flux-pro-1.1",
|
||||
resolved_model=_model("flux-pro-1.1"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -260,7 +270,7 @@ class TestTextProvider:
|
|||
target_name="poem.txt",
|
||||
target_config=target_config,
|
||||
resolved_prompt="Write a poem",
|
||||
resolved_model="mistral-large-latest",
|
||||
resolved_model=_model("mistral-large-latest"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -282,7 +292,7 @@ class TestTextProvider:
|
|||
target_name="summary.md",
|
||||
target_config=target_config,
|
||||
resolved_prompt="Summarize",
|
||||
resolved_model="mistral-large-latest",
|
||||
resolved_model=_model("mistral-large-latest"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -306,7 +316,7 @@ class TestTextProvider:
|
|||
target_name="desc.txt",
|
||||
target_config=target_config,
|
||||
resolved_prompt="Describe this image",
|
||||
resolved_model="mistral-large-latest",
|
||||
resolved_model=_model("mistral-large-latest"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -331,7 +341,7 @@ class TestTextProvider:
|
|||
target_name="fail.txt",
|
||||
target_config=target_config,
|
||||
resolved_prompt="x",
|
||||
resolved_model="mistral-large-latest",
|
||||
resolved_model=_model("mistral-large-latest"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -348,7 +358,7 @@ class TestTextProvider:
|
|||
target_name="fail.txt",
|
||||
target_config=target_config,
|
||||
resolved_prompt="x",
|
||||
resolved_model="mistral-large-latest",
|
||||
resolved_model=_model("mistral-large-latest"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -371,7 +381,7 @@ class TestTextProvider:
|
|||
target_name="out.md",
|
||||
target_config=target_config,
|
||||
resolved_prompt="Combine all",
|
||||
resolved_model="mistral-large-latest",
|
||||
resolved_model=_model("mistral-large-latest"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -402,7 +412,7 @@ class TestTextProvider:
|
|||
target_name="desc.txt",
|
||||
target_config=target_config,
|
||||
resolved_prompt="Describe the style",
|
||||
resolved_model="mistral-large-latest",
|
||||
resolved_model=_model("mistral-large-latest"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue