refactor: move model definitions into providers and extract resolve module
- Rename ImageProvider to BlackForestProvider, TextProvider to MistralProvider - Add get_provided_models() abstract method to Provider base class - Move model lists from models.py into each provider's get_provided_models() - Add providers/registry.py to aggregate models from all providers - Extract infer_required_capabilities and resolve_model from config.py to resolve.py - Update tests to use new names and import paths
This commit is contained in:
parent
dc6a75f5c4
commit
d0dac5b1bf
13 changed files with 432 additions and 390 deletions
|
|
@ -12,14 +12,13 @@ from pathlib import Path
|
||||||
from bulkgen.config import (
|
from bulkgen.config import (
|
||||||
ProjectConfig,
|
ProjectConfig,
|
||||||
TargetType,
|
TargetType,
|
||||||
infer_required_capabilities,
|
|
||||||
resolve_model,
|
|
||||||
target_type_from_capabilities,
|
target_type_from_capabilities,
|
||||||
)
|
)
|
||||||
from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target
|
from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target
|
||||||
from bulkgen.providers import Provider
|
from bulkgen.providers import Provider
|
||||||
from bulkgen.providers.image import ImageProvider
|
from bulkgen.providers.blackforest import BlackForestProvider
|
||||||
from bulkgen.providers.text import TextProvider
|
from bulkgen.providers.mistral import MistralProvider
|
||||||
|
from bulkgen.resolve import infer_required_capabilities, resolve_model
|
||||||
from bulkgen.state import (
|
from bulkgen.state import (
|
||||||
BuildState,
|
BuildState,
|
||||||
is_target_dirty,
|
is_target_dirty,
|
||||||
|
|
@ -106,10 +105,10 @@ def _create_providers() -> dict[TargetType, Provider]:
|
||||||
providers: dict[TargetType, Provider] = {}
|
providers: dict[TargetType, Provider] = {}
|
||||||
bfl_key = os.environ.get("BFL_API_KEY", "")
|
bfl_key = os.environ.get("BFL_API_KEY", "")
|
||||||
if bfl_key:
|
if bfl_key:
|
||||||
providers[TargetType.IMAGE] = ImageProvider(api_key=bfl_key)
|
providers[TargetType.IMAGE] = BlackForestProvider(api_key=bfl_key)
|
||||||
mistral_key = os.environ.get("MISTRAL_API_KEY", "")
|
mistral_key = os.environ.get("MISTRAL_API_KEY", "")
|
||||||
if mistral_key:
|
if mistral_key:
|
||||||
providers[TargetType.TEXT] = TextProvider(api_key=mistral_key)
|
providers[TargetType.TEXT] = MistralProvider(api_key=mistral_key)
|
||||||
return providers
|
return providers
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ import typer
|
||||||
from bulkgen.builder import BuildEvent, BuildResult, run_build
|
from bulkgen.builder import BuildEvent, BuildResult, run_build
|
||||||
from bulkgen.config import ProjectConfig, load_config
|
from bulkgen.config import ProjectConfig, load_config
|
||||||
from bulkgen.graph import build_graph, get_build_order
|
from bulkgen.graph import build_graph, get_build_order
|
||||||
from bulkgen.providers.models import ALL_MODELS
|
from bulkgen.providers.registry import get_all_models
|
||||||
|
|
||||||
app = typer.Typer(name="bulkgen", help="AI artifact build tool.")
|
app = typer.Typer(name="bulkgen", help="AI artifact build tool.")
|
||||||
|
|
||||||
|
|
@ -182,9 +182,10 @@ def graph() -> None:
|
||||||
@app.command()
|
@app.command()
|
||||||
def models() -> None:
|
def models() -> None:
|
||||||
"""List available models and their capabilities."""
|
"""List available models and their capabilities."""
|
||||||
name_width = max(len(m.name) for m in ALL_MODELS)
|
all_models = get_all_models()
|
||||||
provider_width = max(len(m.provider) for m in ALL_MODELS)
|
name_width = max(len(m.name) for m in all_models)
|
||||||
type_width = max(len(m.type) for m in ALL_MODELS)
|
provider_width = max(len(m.provider) for m in all_models)
|
||||||
|
type_width = max(len(m.type) for m in all_models)
|
||||||
|
|
||||||
header_name = "Model".ljust(name_width)
|
header_name = "Model".ljust(name_width)
|
||||||
header_provider = "Provider".ljust(provider_width)
|
header_provider = "Provider".ljust(provider_width)
|
||||||
|
|
@ -210,7 +211,7 @@ def models() -> None:
|
||||||
+ "─" * len(header_caps)
|
+ "─" * len(header_caps)
|
||||||
)
|
)
|
||||||
|
|
||||||
for model in ALL_MODELS:
|
for model in all_models:
|
||||||
name_col = model.name.ljust(name_width)
|
name_col = model.name.ljust(name_width)
|
||||||
provider_col = model.provider.ljust(provider_width)
|
provider_col = model.provider.ljust(provider_width)
|
||||||
type_col = model.type.ljust(type_width)
|
type_col = model.type.ljust(type_width)
|
||||||
|
|
|
||||||
|
|
@ -4,13 +4,15 @@ from __future__ import annotations
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Self
|
from typing import Self
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
from bulkgen.providers.models import Capability
|
||||||
from bulkgen.providers.models import Capability, ModelInfo
|
|
||||||
|
IMAGE_EXTENSIONS: frozenset[str] = frozenset({".png", ".jpg", ".jpeg", ".webp"})
|
||||||
|
TEXT_EXTENSIONS: frozenset[str] = frozenset({".md", ".txt"})
|
||||||
|
|
||||||
|
|
||||||
class TargetType(enum.Enum):
|
class TargetType(enum.Enum):
|
||||||
|
|
@ -20,10 +22,6 @@ class TargetType(enum.Enum):
|
||||||
TEXT = "text"
|
TEXT = "text"
|
||||||
|
|
||||||
|
|
||||||
IMAGE_EXTENSIONS: frozenset[str] = frozenset({".png", ".jpg", ".jpeg", ".webp"})
|
|
||||||
TEXT_EXTENSIONS: frozenset[str] = frozenset({".md", ".txt"})
|
|
||||||
|
|
||||||
|
|
||||||
class Defaults(BaseModel):
|
class Defaults(BaseModel):
|
||||||
"""Default model names, applied when a target does not specify its own."""
|
"""Default model names, applied when a target does not specify its own."""
|
||||||
|
|
||||||
|
|
@ -57,36 +55,6 @@ class ProjectConfig(BaseModel):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def infer_required_capabilities(
|
|
||||||
target_name: str, target: TargetConfig
|
|
||||||
) -> frozenset[Capability]:
|
|
||||||
"""Infer the capabilities a model must have based on filename and config.
|
|
||||||
|
|
||||||
Raises :class:`ValueError` for unsupported file extensions.
|
|
||||||
"""
|
|
||||||
from bulkgen.providers.models import Capability
|
|
||||||
|
|
||||||
suffix = Path(target_name).suffix.lower()
|
|
||||||
caps: set[Capability] = set()
|
|
||||||
|
|
||||||
if suffix in IMAGE_EXTENSIONS:
|
|
||||||
caps.add(Capability.TEXT_TO_IMAGE)
|
|
||||||
if target.reference_images:
|
|
||||||
caps.add(Capability.REFERENCE_IMAGES)
|
|
||||||
if target.control_images:
|
|
||||||
caps.add(Capability.CONTROL_IMAGES)
|
|
||||||
elif suffix in TEXT_EXTENSIONS:
|
|
||||||
caps.add(Capability.TEXT_GENERATION)
|
|
||||||
all_input_names = list(target.inputs) + list(target.reference_images)
|
|
||||||
if any(Path(n).suffix.lower() in IMAGE_EXTENSIONS for n in all_input_names):
|
|
||||||
caps.add(Capability.VISION)
|
|
||||||
else:
|
|
||||||
msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
return frozenset(caps)
|
|
||||||
|
|
||||||
|
|
||||||
def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType:
|
def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType:
|
||||||
"""Derive the target type from a set of required capabilities."""
|
"""Derive the target type from a set of required capabilities."""
|
||||||
from bulkgen.providers.models import Capability
|
from bulkgen.providers.models import Capability
|
||||||
|
|
@ -96,59 +64,6 @@ def target_type_from_capabilities(capabilities: frozenset[Capability]) -> Target
|
||||||
return TargetType.TEXT
|
return TargetType.TEXT
|
||||||
|
|
||||||
|
|
||||||
def resolve_model(
|
|
||||||
target_name: str, target: TargetConfig, defaults: Defaults
|
|
||||||
) -> ModelInfo:
|
|
||||||
"""Return the effective model for a target, validated against required capabilities.
|
|
||||||
|
|
||||||
If the target specifies an explicit model, it is validated to have all
|
|
||||||
required capabilities. Otherwise the type-appropriate default is tried
|
|
||||||
first; if it lacks a required capability the first capable model of the
|
|
||||||
same type is selected.
|
|
||||||
|
|
||||||
Raises :class:`ValueError` if no suitable model can be found.
|
|
||||||
"""
|
|
||||||
from bulkgen.providers.models import ALL_MODELS
|
|
||||||
|
|
||||||
required = infer_required_capabilities(target_name, target)
|
|
||||||
target_type = target_type_from_capabilities(required)
|
|
||||||
|
|
||||||
if target.model is not None:
|
|
||||||
# Explicit model — look up and validate.
|
|
||||||
for model in ALL_MODELS:
|
|
||||||
if model.name == target.model:
|
|
||||||
missing = required - frozenset(model.capabilities)
|
|
||||||
if missing:
|
|
||||||
names = ", ".join(sorted(missing))
|
|
||||||
msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
return model
|
|
||||||
msg = f"Unknown model '{target.model}' for target '{target_name}'"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
# No explicit model — try the default first, then fall back.
|
|
||||||
default_name = (
|
|
||||||
defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model
|
|
||||||
)
|
|
||||||
for model in ALL_MODELS:
|
|
||||||
if model.name == default_name:
|
|
||||||
if required <= frozenset(model.capabilities):
|
|
||||||
return model
|
|
||||||
break
|
|
||||||
|
|
||||||
# Default lacks capabilities — find the first capable model of the same type.
|
|
||||||
model_type = "image" if target_type is TargetType.IMAGE else "text"
|
|
||||||
for model in ALL_MODELS:
|
|
||||||
if model.type == model_type and required <= frozenset(model.capabilities):
|
|
||||||
return model
|
|
||||||
|
|
||||||
names = ", ".join(sorted(required))
|
|
||||||
msg = (
|
|
||||||
f"No model found for target '{target_name}' with required capabilities: {names}"
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(config_path: Path) -> ProjectConfig:
|
def load_config(config_path: Path) -> ProjectConfig:
|
||||||
"""Load and validate a ``.bulkgen.yaml`` file."""
|
"""Load and validate a ``.bulkgen.yaml`` file."""
|
||||||
with config_path.open() as f:
|
with config_path.open() as f:
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,11 @@ from bulkgen.providers.models import ModelInfo
|
||||||
class Provider(abc.ABC):
|
class Provider(abc.ABC):
|
||||||
"""Abstract base for generation providers."""
|
"""Abstract base for generation providers."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_provided_models() -> list[ModelInfo]:
|
||||||
|
"""Return the models this provider supports."""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import httpx
|
||||||
from bulkgen.config import TargetConfig
|
from bulkgen.config import TargetConfig
|
||||||
from bulkgen.providers import Provider
|
from bulkgen.providers import Provider
|
||||||
from bulkgen.providers.bfl import BFLClient
|
from bulkgen.providers.bfl import BFLClient
|
||||||
from bulkgen.providers.models import ModelInfo
|
from bulkgen.providers.models import Capability, ModelInfo
|
||||||
|
|
||||||
|
|
||||||
def _encode_image_b64(path: Path) -> str:
|
def _encode_image_b64(path: Path) -> str:
|
||||||
|
|
@ -45,7 +45,7 @@ def _add_reference_images(
|
||||||
inputs[key] = _encode_image_b64(project_dir / ref_name)
|
inputs[key] = _encode_image_b64(project_dir / ref_name)
|
||||||
|
|
||||||
|
|
||||||
class ImageProvider(Provider):
|
class BlackForestProvider(Provider):
|
||||||
"""Generates images via the BlackForestLabs API."""
|
"""Generates images via the BlackForestLabs API."""
|
||||||
|
|
||||||
_client: BFLClient
|
_client: BFLClient
|
||||||
|
|
@ -53,6 +53,72 @@ class ImageProvider(Provider):
|
||||||
def __init__(self, api_key: str) -> None:
|
def __init__(self, api_key: str) -> None:
|
||||||
self._client = BFLClient(api_key=api_key)
|
self._client = BFLClient(api_key=api_key)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@override
|
||||||
|
def get_provided_models() -> list[ModelInfo]:
|
||||||
|
return [
|
||||||
|
ModelInfo(
|
||||||
|
name="flux-dev",
|
||||||
|
provider="BlackForestLabs",
|
||||||
|
type="image",
|
||||||
|
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
name="flux-pro",
|
||||||
|
provider="BlackForestLabs",
|
||||||
|
type="image",
|
||||||
|
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
name="flux-pro-1.1",
|
||||||
|
provider="BlackForestLabs",
|
||||||
|
type="image",
|
||||||
|
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
name="flux-pro-1.1-ultra",
|
||||||
|
provider="BlackForestLabs",
|
||||||
|
type="image",
|
||||||
|
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
name="flux-2-pro",
|
||||||
|
provider="BlackForestLabs",
|
||||||
|
type="image",
|
||||||
|
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
name="flux-kontext-pro",
|
||||||
|
provider="BlackForestLabs",
|
||||||
|
type="image",
|
||||||
|
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
name="flux-pro-1.0-canny",
|
||||||
|
provider="BlackForestLabs",
|
||||||
|
type="image",
|
||||||
|
capabilities=[Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
name="flux-pro-1.0-depth",
|
||||||
|
provider="BlackForestLabs",
|
||||||
|
type="image",
|
||||||
|
capabilities=[Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
name="flux-pro-1.0-fill",
|
||||||
|
provider="BlackForestLabs",
|
||||||
|
type="image",
|
||||||
|
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
name="flux-pro-1.0-expand",
|
||||||
|
provider="BlackForestLabs",
|
||||||
|
type="image",
|
||||||
|
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
|
|
@ -11,7 +11,7 @@ from mistralai import Mistral, models
|
||||||
|
|
||||||
from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig
|
from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig
|
||||||
from bulkgen.providers import Provider
|
from bulkgen.providers import Provider
|
||||||
from bulkgen.providers.models import ModelInfo
|
from bulkgen.providers.models import Capability, ModelInfo
|
||||||
|
|
||||||
|
|
||||||
def _image_to_data_url(path: Path) -> str:
|
def _image_to_data_url(path: Path) -> str:
|
||||||
|
|
@ -21,7 +21,7 @@ def _image_to_data_url(path: Path) -> str:
|
||||||
return f"data:{mime};base64,{b64}"
|
return f"data:{mime};base64,{b64}"
|
||||||
|
|
||||||
|
|
||||||
class TextProvider(Provider):
|
class MistralProvider(Provider):
|
||||||
"""Generates text via the Mistral API."""
|
"""Generates text via the Mistral API."""
|
||||||
|
|
||||||
_api_key: str
|
_api_key: str
|
||||||
|
|
@ -29,6 +29,36 @@ class TextProvider(Provider):
|
||||||
def __init__(self, api_key: str) -> None:
|
def __init__(self, api_key: str) -> None:
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@override
|
||||||
|
def get_provided_models() -> list[ModelInfo]:
|
||||||
|
return [
|
||||||
|
ModelInfo(
|
||||||
|
name="mistral-large-latest",
|
||||||
|
provider="Mistral",
|
||||||
|
type="text",
|
||||||
|
capabilities=[Capability.TEXT_GENERATION],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
name="mistral-small-latest",
|
||||||
|
provider="Mistral",
|
||||||
|
type="text",
|
||||||
|
capabilities=[Capability.TEXT_GENERATION],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
name="pixtral-large-latest",
|
||||||
|
provider="Mistral",
|
||||||
|
type="text",
|
||||||
|
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
name="pixtral-12b-latest",
|
||||||
|
provider="Mistral",
|
||||||
|
type="text",
|
||||||
|
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
|
|
@ -98,6 +128,8 @@ def _build_multimodal_message(
|
||||||
models.TextChunk(text=prompt),
|
models.TextChunk(text=prompt),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
from bulkgen.config import IMAGE_EXTENSIONS
|
||||||
|
|
||||||
for name in input_names:
|
for name in input_names:
|
||||||
input_path = project_dir / name
|
input_path = project_dir / name
|
||||||
suffix = input_path.suffix.lower()
|
suffix = input_path.suffix.lower()
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Registry of supported models and their capabilities."""
|
"""Model types and capability definitions for AI providers."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
@ -25,106 +25,3 @@ class ModelInfo:
|
||||||
provider: str
|
provider: str
|
||||||
type: Literal["text", "image"]
|
type: Literal["text", "image"]
|
||||||
capabilities: list[Capability]
|
capabilities: list[Capability]
|
||||||
|
|
||||||
|
|
||||||
TEXT_MODELS: list[ModelInfo] = [
|
|
||||||
ModelInfo(
|
|
||||||
name="mistral-large-latest",
|
|
||||||
provider="Mistral",
|
|
||||||
type="text",
|
|
||||||
capabilities=[Capability.TEXT_GENERATION],
|
|
||||||
),
|
|
||||||
ModelInfo(
|
|
||||||
name="mistral-small-latest",
|
|
||||||
provider="Mistral",
|
|
||||||
type="text",
|
|
||||||
capabilities=[Capability.TEXT_GENERATION],
|
|
||||||
),
|
|
||||||
ModelInfo(
|
|
||||||
name="pixtral-large-latest",
|
|
||||||
provider="Mistral",
|
|
||||||
type="text",
|
|
||||||
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
|
|
||||||
),
|
|
||||||
ModelInfo(
|
|
||||||
name="pixtral-12b-latest",
|
|
||||||
provider="Mistral",
|
|
||||||
type="text",
|
|
||||||
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
IMAGE_MODELS: list[ModelInfo] = [
|
|
||||||
ModelInfo(
|
|
||||||
name="flux-dev",
|
|
||||||
provider="BlackForestLabs",
|
|
||||||
type="image",
|
|
||||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
|
||||||
),
|
|
||||||
ModelInfo(
|
|
||||||
name="flux-pro",
|
|
||||||
provider="BlackForestLabs",
|
|
||||||
type="image",
|
|
||||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
|
||||||
),
|
|
||||||
ModelInfo(
|
|
||||||
name="flux-pro-1.1",
|
|
||||||
provider="BlackForestLabs",
|
|
||||||
type="image",
|
|
||||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
|
||||||
),
|
|
||||||
ModelInfo(
|
|
||||||
name="flux-pro-1.1-ultra",
|
|
||||||
provider="BlackForestLabs",
|
|
||||||
type="image",
|
|
||||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
|
||||||
),
|
|
||||||
ModelInfo(
|
|
||||||
name="flux-2-pro",
|
|
||||||
provider="BlackForestLabs",
|
|
||||||
type="image",
|
|
||||||
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
|
|
||||||
),
|
|
||||||
ModelInfo(
|
|
||||||
name="flux-kontext-pro",
|
|
||||||
provider="BlackForestLabs",
|
|
||||||
type="image",
|
|
||||||
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
|
|
||||||
),
|
|
||||||
ModelInfo(
|
|
||||||
name="flux-pro-1.0-canny",
|
|
||||||
provider="BlackForestLabs",
|
|
||||||
type="image",
|
|
||||||
capabilities=[
|
|
||||||
Capability.TEXT_TO_IMAGE,
|
|
||||||
Capability.CONTROL_IMAGES,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ModelInfo(
|
|
||||||
name="flux-pro-1.0-depth",
|
|
||||||
provider="BlackForestLabs",
|
|
||||||
type="image",
|
|
||||||
capabilities=[
|
|
||||||
Capability.TEXT_TO_IMAGE,
|
|
||||||
Capability.CONTROL_IMAGES,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ModelInfo(
|
|
||||||
name="flux-pro-1.0-fill",
|
|
||||||
provider="BlackForestLabs",
|
|
||||||
type="image",
|
|
||||||
capabilities=[
|
|
||||||
Capability.TEXT_TO_IMAGE,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ModelInfo(
|
|
||||||
name="flux-pro-1.0-expand",
|
|
||||||
provider="BlackForestLabs",
|
|
||||||
type="image",
|
|
||||||
capabilities=[
|
|
||||||
Capability.TEXT_TO_IMAGE,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
ALL_MODELS: list[ModelInfo] = TEXT_MODELS + IMAGE_MODELS
|
|
||||||
|
|
|
||||||
16
bulkgen/providers/registry.py
Normal file
16
bulkgen/providers/registry.py
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Aggregates models from all registered providers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from bulkgen.providers.models import ModelInfo
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_models() -> list[ModelInfo]:
|
||||||
|
"""Return the merged list of models from all providers."""
|
||||||
|
from bulkgen.providers.blackforest import BlackForestProvider
|
||||||
|
from bulkgen.providers.mistral import MistralProvider
|
||||||
|
|
||||||
|
return (
|
||||||
|
MistralProvider.get_provided_models()
|
||||||
|
+ BlackForestProvider.get_provided_models()
|
||||||
|
)
|
||||||
95
bulkgen/resolve.py
Normal file
95
bulkgen/resolve.py
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from bulkgen.config import (
|
||||||
|
IMAGE_EXTENSIONS,
|
||||||
|
TEXT_EXTENSIONS,
|
||||||
|
Defaults,
|
||||||
|
TargetConfig,
|
||||||
|
TargetType,
|
||||||
|
target_type_from_capabilities,
|
||||||
|
)
|
||||||
|
from bulkgen.providers.models import Capability, ModelInfo
|
||||||
|
|
||||||
|
|
||||||
|
def infer_required_capabilities(
|
||||||
|
target_name: str, target: TargetConfig
|
||||||
|
) -> frozenset[Capability]:
|
||||||
|
"""Infer the capabilities a model must have based on filename and config.
|
||||||
|
|
||||||
|
Raises :class:`ValueError` for unsupported file extensions.
|
||||||
|
"""
|
||||||
|
suffix = Path(target_name).suffix.lower()
|
||||||
|
caps: set[Capability] = set()
|
||||||
|
|
||||||
|
if suffix in IMAGE_EXTENSIONS:
|
||||||
|
caps.add(Capability.TEXT_TO_IMAGE)
|
||||||
|
if target.reference_images:
|
||||||
|
caps.add(Capability.REFERENCE_IMAGES)
|
||||||
|
if target.control_images:
|
||||||
|
caps.add(Capability.CONTROL_IMAGES)
|
||||||
|
elif suffix in TEXT_EXTENSIONS:
|
||||||
|
caps.add(Capability.TEXT_GENERATION)
|
||||||
|
all_input_names = list(target.inputs) + list(target.reference_images)
|
||||||
|
if any(Path(n).suffix.lower() in IMAGE_EXTENSIONS for n in all_input_names):
|
||||||
|
caps.add(Capability.VISION)
|
||||||
|
else:
|
||||||
|
msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
return frozenset(caps)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_model(
|
||||||
|
target_name: str, target: TargetConfig, defaults: Defaults
|
||||||
|
) -> ModelInfo:
|
||||||
|
"""Return the effective model for a target, validated against required capabilities.
|
||||||
|
|
||||||
|
If the target specifies an explicit model, it is validated to have all
|
||||||
|
required capabilities. Otherwise the type-appropriate default is tried
|
||||||
|
first; if it lacks a required capability the first capable model of the
|
||||||
|
same type is selected.
|
||||||
|
|
||||||
|
Raises :class:`ValueError` if no suitable model can be found.
|
||||||
|
"""
|
||||||
|
from bulkgen.providers.registry import get_all_models
|
||||||
|
|
||||||
|
all_models = get_all_models()
|
||||||
|
required = infer_required_capabilities(target_name, target)
|
||||||
|
target_type = target_type_from_capabilities(required)
|
||||||
|
|
||||||
|
if target.model is not None:
|
||||||
|
# Explicit model — look up and validate.
|
||||||
|
for model in all_models:
|
||||||
|
if model.name == target.model:
|
||||||
|
missing = required - frozenset(model.capabilities)
|
||||||
|
if missing:
|
||||||
|
names = ", ".join(sorted(missing))
|
||||||
|
msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
return model
|
||||||
|
msg = f"Unknown model '{target.model}' for target '{target_name}'"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# No explicit model — try the default first, then fall back.
|
||||||
|
default_name = (
|
||||||
|
defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model
|
||||||
|
)
|
||||||
|
for model in all_models:
|
||||||
|
if model.name == default_name:
|
||||||
|
if required <= frozenset(model.capabilities):
|
||||||
|
return model
|
||||||
|
break
|
||||||
|
|
||||||
|
# Default lacks capabilities — find the first capable model of the same type.
|
||||||
|
model_type = "image" if target_type is TargetType.IMAGE else "text"
|
||||||
|
for model in all_models:
|
||||||
|
if model.type == model_type and required <= frozenset(model.capabilities):
|
||||||
|
return model
|
||||||
|
|
||||||
|
names = ", ".join(sorted(required))
|
||||||
|
msg = (
|
||||||
|
f"No model found for target '{target_name}' with required capabilities: {names}"
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
@ -27,6 +27,11 @@ WriteConfig = Callable[[dict[str, object]], ProjectConfig]
|
||||||
class FakeProvider(Provider):
|
class FakeProvider(Provider):
|
||||||
"""A provider that writes a marker file instead of calling an API."""
|
"""A provider that writes a marker file instead of calling an API."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@override
|
||||||
|
def get_provided_models() -> list[ModelInfo]:
|
||||||
|
return []
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
|
|
@ -43,6 +48,11 @@ class FakeProvider(Provider):
|
||||||
class FailingProvider(Provider):
|
class FailingProvider(Provider):
|
||||||
"""A provider that always raises."""
|
"""A provider that always raises."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@override
|
||||||
|
def get_provided_models() -> list[ModelInfo]:
|
||||||
|
return []
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -7,14 +7,7 @@ from pathlib import Path
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from bulkgen.config import (
|
from bulkgen.config import load_config
|
||||||
Defaults,
|
|
||||||
TargetConfig,
|
|
||||||
infer_required_capabilities,
|
|
||||||
load_config,
|
|
||||||
resolve_model,
|
|
||||||
)
|
|
||||||
from bulkgen.providers.models import Capability
|
|
||||||
|
|
||||||
|
|
||||||
class TestLoadConfig:
|
class TestLoadConfig:
|
||||||
|
|
@ -86,136 +79,3 @@ class TestLoadConfig:
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
_ = load_config(config_path)
|
_ = load_config(config_path)
|
||||||
|
|
||||||
|
|
||||||
class TestInferRequiredCapabilities:
|
|
||||||
"""Test capability inference from file extensions and target config."""
|
|
||||||
|
|
||||||
def test_plain_image(self) -> None:
|
|
||||||
target = TargetConfig(prompt="x")
|
|
||||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
|
||||||
{Capability.TEXT_TO_IMAGE}
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
|
|
||||||
def test_image_extensions(self, name: str) -> None:
|
|
||||||
target = TargetConfig(prompt="x")
|
|
||||||
caps = infer_required_capabilities(name, target)
|
|
||||||
assert Capability.TEXT_TO_IMAGE in caps
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
|
|
||||||
def test_case_insensitive(self, name: str) -> None:
|
|
||||||
target = TargetConfig(prompt="x")
|
|
||||||
caps = infer_required_capabilities(name, target)
|
|
||||||
assert Capability.TEXT_TO_IMAGE in caps
|
|
||||||
|
|
||||||
def test_image_with_reference_images(self) -> None:
|
|
||||||
target = TargetConfig(prompt="x", reference_images=["ref.png"])
|
|
||||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
|
||||||
{Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_image_with_control_images(self) -> None:
|
|
||||||
target = TargetConfig(prompt="x", control_images=["ctrl.png"])
|
|
||||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
|
||||||
{Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_image_with_both(self) -> None:
|
|
||||||
target = TargetConfig(
|
|
||||||
prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
|
|
||||||
)
|
|
||||||
assert infer_required_capabilities("out.png", target) == frozenset(
|
|
||||||
{
|
|
||||||
Capability.TEXT_TO_IMAGE,
|
|
||||||
Capability.REFERENCE_IMAGES,
|
|
||||||
Capability.CONTROL_IMAGES,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
|
|
||||||
def test_text_extensions(self, name: str) -> None:
|
|
||||||
target = TargetConfig(prompt="x")
|
|
||||||
caps = infer_required_capabilities(name, target)
|
|
||||||
assert caps == frozenset({Capability.TEXT_GENERATION})
|
|
||||||
|
|
||||||
def test_text_with_text_inputs(self) -> None:
|
|
||||||
target = TargetConfig(prompt="x", inputs=["data.txt"])
|
|
||||||
assert infer_required_capabilities("out.md", target) == frozenset(
|
|
||||||
{Capability.TEXT_GENERATION}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_text_with_image_input(self) -> None:
|
|
||||||
target = TargetConfig(prompt="x", inputs=["photo.png"])
|
|
||||||
assert infer_required_capabilities("out.txt", target) == frozenset(
|
|
||||||
{Capability.TEXT_GENERATION, Capability.VISION}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_text_with_image_reference(self) -> None:
|
|
||||||
target = TargetConfig(prompt="x", reference_images=["ref.jpg"])
|
|
||||||
assert infer_required_capabilities("out.md", target) == frozenset(
|
|
||||||
{Capability.TEXT_GENERATION, Capability.VISION}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_unsupported_extension_raises(self) -> None:
|
|
||||||
target = TargetConfig(prompt="x")
|
|
||||||
with pytest.raises(ValueError, match="unsupported extension"):
|
|
||||||
_ = infer_required_capabilities("data.csv", target)
|
|
||||||
|
|
||||||
def test_no_extension_raises(self) -> None:
|
|
||||||
target = TargetConfig(prompt="x")
|
|
||||||
with pytest.raises(ValueError, match="unsupported extension"):
|
|
||||||
_ = infer_required_capabilities("Makefile", target)
|
|
||||||
|
|
||||||
|
|
||||||
class TestResolveModel:
|
|
||||||
"""Test model resolution with capability validation."""
|
|
||||||
|
|
||||||
def test_explicit_model_wins(self) -> None:
|
|
||||||
target = TargetConfig(prompt="x", model="mistral-small-latest")
|
|
||||||
result = resolve_model("out.txt", target, Defaults())
|
|
||||||
assert result.name == "mistral-small-latest"
|
|
||||||
|
|
||||||
def test_default_text_model(self) -> None:
|
|
||||||
target = TargetConfig(prompt="x")
|
|
||||||
defaults = Defaults(text_model="mistral-large-latest")
|
|
||||||
result = resolve_model("out.md", target, defaults)
|
|
||||||
assert result.name == "mistral-large-latest"
|
|
||||||
|
|
||||||
def test_default_image_model(self) -> None:
|
|
||||||
target = TargetConfig(prompt="x")
|
|
||||||
defaults = Defaults(image_model="flux-dev")
|
|
||||||
result = resolve_model("out.png", target, defaults)
|
|
||||||
assert result.name == "flux-dev"
|
|
||||||
|
|
||||||
def test_unknown_model_raises(self) -> None:
|
|
||||||
target = TargetConfig(prompt="x", model="nonexistent-model")
|
|
||||||
with pytest.raises(ValueError, match="Unknown model"):
|
|
||||||
_ = resolve_model("out.txt", target, Defaults())
|
|
||||||
|
|
||||||
def test_explicit_model_missing_capability_raises(self) -> None:
|
|
||||||
# flux-dev does not support reference images
|
|
||||||
target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"])
|
|
||||||
with pytest.raises(ValueError, match="lacks required capabilities"):
|
|
||||||
_ = resolve_model("out.png", target, Defaults())
|
|
||||||
|
|
||||||
def test_default_fallback_for_reference_images(self) -> None:
|
|
||||||
# Default flux-dev lacks reference_images, should fall back to a capable model
|
|
||||||
target = TargetConfig(prompt="x", reference_images=["r.png"])
|
|
||||||
defaults = Defaults(image_model="flux-dev")
|
|
||||||
result = resolve_model("out.png", target, defaults)
|
|
||||||
assert Capability.REFERENCE_IMAGES in result.capabilities
|
|
||||||
|
|
||||||
def test_default_fallback_for_vision(self) -> None:
|
|
||||||
# Default mistral-large-latest lacks vision, should fall back to a pixtral model
|
|
||||||
target = TargetConfig(prompt="x", inputs=["photo.png"])
|
|
||||||
defaults = Defaults(text_model="mistral-large-latest")
|
|
||||||
result = resolve_model("out.txt", target, defaults)
|
|
||||||
assert Capability.VISION in result.capabilities
|
|
||||||
|
|
||||||
def test_default_preferred_when_capable(self) -> None:
|
|
||||||
# Default flux-2-pro already supports reference_images, should be used directly
|
|
||||||
target = TargetConfig(prompt="x", reference_images=["r.png"])
|
|
||||||
defaults = Defaults(image_model="flux-2-pro")
|
|
||||||
result = resolve_model("out.png", target, defaults)
|
|
||||||
assert result.name == "flux-2-pro"
|
|
||||||
|
|
|
||||||
|
|
@ -14,17 +14,18 @@ import pytest
|
||||||
|
|
||||||
from bulkgen.config import TargetConfig
|
from bulkgen.config import TargetConfig
|
||||||
from bulkgen.providers.bfl import BFLResult
|
from bulkgen.providers.bfl import BFLResult
|
||||||
from bulkgen.providers.image import ImageProvider
|
from bulkgen.providers.blackforest import BlackForestProvider
|
||||||
from bulkgen.providers.image import (
|
from bulkgen.providers.blackforest import (
|
||||||
_encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage]
|
_encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage]
|
||||||
)
|
)
|
||||||
from bulkgen.providers.models import ALL_MODELS, ModelInfo
|
from bulkgen.providers.mistral import MistralProvider
|
||||||
from bulkgen.providers.text import TextProvider
|
from bulkgen.providers.models import ModelInfo
|
||||||
|
from bulkgen.providers.registry import get_all_models
|
||||||
|
|
||||||
|
|
||||||
def _model(name: str) -> ModelInfo:
|
def _model(name: str) -> ModelInfo:
|
||||||
"""Look up a ModelInfo by name."""
|
"""Look up a ModelInfo by name."""
|
||||||
for m in ALL_MODELS:
|
for m in get_all_models():
|
||||||
if m.name == name:
|
if m.name == name:
|
||||||
return m
|
return m
|
||||||
msg = f"Unknown test model: {name}"
|
msg = f"Unknown test model: {name}"
|
||||||
|
|
@ -69,8 +70,8 @@ def _make_text_response(content: str | None) -> MagicMock:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
class TestImageProvider:
|
class TestBlackForestProvider:
|
||||||
"""Test ImageProvider with mocked BFL client and HTTP."""
|
"""Test BlackForestProvider with mocked BFL client and HTTP."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def image_bytes(self) -> bytes:
|
def image_bytes(self) -> bytes:
|
||||||
|
|
@ -83,13 +84,13 @@ class TestImageProvider:
|
||||||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("bulkgen.providers.image.BFLClient") as mock_cls,
|
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
|
||||||
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
|
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
|
||||||
):
|
):
|
||||||
mock_cls.return_value.generate = AsyncMock(return_value=bfl_result)
|
mock_cls.return_value.generate = AsyncMock(return_value=bfl_result)
|
||||||
mock_http_cls.return_value = mock_http
|
mock_http_cls.return_value = mock_http
|
||||||
|
|
||||||
provider = ImageProvider(api_key="test-key")
|
provider = BlackForestProvider(api_key="test-key")
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="out.png",
|
target_name="out.png",
|
||||||
target_config=target_config,
|
target_config=target_config,
|
||||||
|
|
@ -109,14 +110,14 @@ class TestImageProvider:
|
||||||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("bulkgen.providers.image.BFLClient") as mock_cls,
|
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
|
||||||
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
|
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
|
||||||
):
|
):
|
||||||
mock_generate = AsyncMock(return_value=bfl_result)
|
mock_generate = AsyncMock(return_value=bfl_result)
|
||||||
mock_cls.return_value.generate = mock_generate
|
mock_cls.return_value.generate = mock_generate
|
||||||
mock_http_cls.return_value = mock_http
|
mock_http_cls.return_value = mock_http
|
||||||
|
|
||||||
provider = ImageProvider(api_key="test-key")
|
provider = BlackForestProvider(api_key="test-key")
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="banner.png",
|
target_name="banner.png",
|
||||||
target_config=target_config,
|
target_config=target_config,
|
||||||
|
|
@ -140,14 +141,14 @@ class TestImageProvider:
|
||||||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("bulkgen.providers.image.BFLClient") as mock_cls,
|
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
|
||||||
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
|
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
|
||||||
):
|
):
|
||||||
mock_generate = AsyncMock(return_value=bfl_result)
|
mock_generate = AsyncMock(return_value=bfl_result)
|
||||||
mock_cls.return_value.generate = mock_generate
|
mock_cls.return_value.generate = mock_generate
|
||||||
mock_http_cls.return_value = mock_http
|
mock_http_cls.return_value = mock_http
|
||||||
|
|
||||||
provider = ImageProvider(api_key="test-key")
|
provider = BlackForestProvider(api_key="test-key")
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="out.png",
|
target_name="out.png",
|
||||||
target_config=target_config,
|
target_config=target_config,
|
||||||
|
|
@ -175,14 +176,14 @@ class TestImageProvider:
|
||||||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("bulkgen.providers.image.BFLClient") as mock_cls,
|
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
|
||||||
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
|
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
|
||||||
):
|
):
|
||||||
mock_generate = AsyncMock(return_value=bfl_result)
|
mock_generate = AsyncMock(return_value=bfl_result)
|
||||||
mock_cls.return_value.generate = mock_generate
|
mock_cls.return_value.generate = mock_generate
|
||||||
mock_http_cls.return_value = mock_http
|
mock_http_cls.return_value = mock_http
|
||||||
|
|
||||||
provider = ImageProvider(api_key="test-key")
|
provider = BlackForestProvider(api_key="test-key")
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="out.png",
|
target_name="out.png",
|
||||||
target_config=target_config,
|
target_config=target_config,
|
||||||
|
|
@ -206,14 +207,14 @@ class TestImageProvider:
|
||||||
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
bfl_result, mock_http = _make_bfl_mocks(image_bytes)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("bulkgen.providers.image.BFLClient") as mock_cls,
|
patch("bulkgen.providers.blackforest.BFLClient") as mock_cls,
|
||||||
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
|
patch("bulkgen.providers.blackforest.httpx.AsyncClient") as mock_http_cls,
|
||||||
):
|
):
|
||||||
mock_generate = AsyncMock(return_value=bfl_result)
|
mock_generate = AsyncMock(return_value=bfl_result)
|
||||||
mock_cls.return_value.generate = mock_generate
|
mock_cls.return_value.generate = mock_generate
|
||||||
mock_http_cls.return_value = mock_http
|
mock_http_cls.return_value = mock_http
|
||||||
|
|
||||||
provider = ImageProvider(api_key="test-key")
|
provider = BlackForestProvider(api_key="test-key")
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="out.png",
|
target_name="out.png",
|
||||||
target_config=target_config,
|
target_config=target_config,
|
||||||
|
|
@ -229,14 +230,14 @@ class TestImageProvider:
|
||||||
async def test_image_no_sample_url_raises(self, project_dir: Path) -> None:
|
async def test_image_no_sample_url_raises(self, project_dir: Path) -> None:
|
||||||
target_config = TargetConfig(prompt="x")
|
target_config = TargetConfig(prompt="x")
|
||||||
|
|
||||||
with patch("bulkgen.providers.image.BFLClient") as mock_cls:
|
with patch("bulkgen.providers.blackforest.BFLClient") as mock_cls:
|
||||||
from bulkgen.providers.bfl import BFLError
|
from bulkgen.providers.bfl import BFLError
|
||||||
|
|
||||||
mock_cls.return_value.generate = AsyncMock(
|
mock_cls.return_value.generate = AsyncMock(
|
||||||
side_effect=BFLError("BFL task test ready but no sample URL: {}")
|
side_effect=BFLError("BFL task test ready but no sample URL: {}")
|
||||||
)
|
)
|
||||||
|
|
||||||
provider = ImageProvider(api_key="test-key")
|
provider = BlackForestProvider(api_key="test-key")
|
||||||
with pytest.raises(BFLError, match="no sample URL"):
|
with pytest.raises(BFLError, match="no sample URL"):
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="fail.png",
|
target_name="fail.png",
|
||||||
|
|
@ -255,17 +256,17 @@ class TestImageProvider:
|
||||||
assert base64.b64decode(encoded) == data
|
assert base64.b64decode(encoded) == data
|
||||||
|
|
||||||
|
|
||||||
class TestTextProvider:
|
class TestMistralProvider:
|
||||||
"""Test TextProvider with mocked Mistral client."""
|
"""Test MistralProvider with mocked Mistral client."""
|
||||||
|
|
||||||
async def test_basic_text_generation(self, project_dir: Path) -> None:
|
async def test_basic_text_generation(self, project_dir: Path) -> None:
|
||||||
target_config = TargetConfig(prompt="Write a poem")
|
target_config = TargetConfig(prompt="Write a poem")
|
||||||
response = _make_text_response("Roses are red...")
|
response = _make_text_response("Roses are red...")
|
||||||
|
|
||||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||||
mock_cls.return_value = _make_mistral_mock(response)
|
mock_cls.return_value = _make_mistral_mock(response)
|
||||||
|
|
||||||
provider = TextProvider(api_key="test-key")
|
provider = MistralProvider(api_key="test-key")
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="poem.txt",
|
target_name="poem.txt",
|
||||||
target_config=target_config,
|
target_config=target_config,
|
||||||
|
|
@ -283,11 +284,11 @@ class TestTextProvider:
|
||||||
target_config = TargetConfig(prompt="Summarize", inputs=["source.txt"])
|
target_config = TargetConfig(prompt="Summarize", inputs=["source.txt"])
|
||||||
response = _make_text_response("Summary: ...")
|
response = _make_text_response("Summary: ...")
|
||||||
|
|
||||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||||
mock_client = _make_mistral_mock(response)
|
mock_client = _make_mistral_mock(response)
|
||||||
mock_cls.return_value = mock_client
|
mock_cls.return_value = mock_client
|
||||||
|
|
||||||
provider = TextProvider(api_key="test-key")
|
provider = MistralProvider(api_key="test-key")
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="summary.md",
|
target_name="summary.md",
|
||||||
target_config=target_config,
|
target_config=target_config,
|
||||||
|
|
@ -307,11 +308,11 @@ class TestTextProvider:
|
||||||
target_config = TargetConfig(prompt="Describe this image", inputs=["photo.png"])
|
target_config = TargetConfig(prompt="Describe this image", inputs=["photo.png"])
|
||||||
response = _make_text_response("A beautiful photo")
|
response = _make_text_response("A beautiful photo")
|
||||||
|
|
||||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||||
mock_client = _make_mistral_mock(response)
|
mock_client = _make_mistral_mock(response)
|
||||||
mock_cls.return_value = mock_client
|
mock_cls.return_value = mock_client
|
||||||
|
|
||||||
provider = TextProvider(api_key="test-key")
|
provider = MistralProvider(api_key="test-key")
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="desc.txt",
|
target_name="desc.txt",
|
||||||
target_config=target_config,
|
target_config=target_config,
|
||||||
|
|
@ -332,10 +333,10 @@ class TestTextProvider:
|
||||||
response = MagicMock()
|
response = MagicMock()
|
||||||
response.choices = []
|
response.choices = []
|
||||||
|
|
||||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||||
mock_cls.return_value = _make_mistral_mock(response)
|
mock_cls.return_value = _make_mistral_mock(response)
|
||||||
|
|
||||||
provider = TextProvider(api_key="test-key")
|
provider = MistralProvider(api_key="test-key")
|
||||||
with pytest.raises(RuntimeError, match="no choices"):
|
with pytest.raises(RuntimeError, match="no choices"):
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="fail.txt",
|
target_name="fail.txt",
|
||||||
|
|
@ -349,10 +350,10 @@ class TestTextProvider:
|
||||||
target_config = TargetConfig(prompt="x")
|
target_config = TargetConfig(prompt="x")
|
||||||
response = _make_text_response(None)
|
response = _make_text_response(None)
|
||||||
|
|
||||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||||
mock_cls.return_value = _make_mistral_mock(response)
|
mock_cls.return_value = _make_mistral_mock(response)
|
||||||
|
|
||||||
provider = TextProvider(api_key="test-key")
|
provider = MistralProvider(api_key="test-key")
|
||||||
with pytest.raises(RuntimeError, match="empty content"):
|
with pytest.raises(RuntimeError, match="empty content"):
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="fail.txt",
|
target_name="fail.txt",
|
||||||
|
|
@ -372,11 +373,11 @@ class TestTextProvider:
|
||||||
)
|
)
|
||||||
response = _make_text_response("Combined")
|
response = _make_text_response("Combined")
|
||||||
|
|
||||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||||
mock_client = _make_mistral_mock(response)
|
mock_client = _make_mistral_mock(response)
|
||||||
mock_cls.return_value = mock_client
|
mock_cls.return_value = mock_client
|
||||||
|
|
||||||
provider = TextProvider(api_key="test-key")
|
provider = MistralProvider(api_key="test-key")
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="out.md",
|
target_name="out.md",
|
||||||
target_config=target_config,
|
target_config=target_config,
|
||||||
|
|
@ -403,11 +404,11 @@ class TestTextProvider:
|
||||||
)
|
)
|
||||||
response = _make_text_response("A stylized image")
|
response = _make_text_response("A stylized image")
|
||||||
|
|
||||||
with patch("bulkgen.providers.text.Mistral") as mock_cls:
|
with patch("bulkgen.providers.mistral.Mistral") as mock_cls:
|
||||||
mock_client = _make_mistral_mock(response)
|
mock_client = _make_mistral_mock(response)
|
||||||
mock_cls.return_value = mock_client
|
mock_cls.return_value = mock_client
|
||||||
|
|
||||||
provider = TextProvider(api_key="test-key")
|
provider = MistralProvider(api_key="test-key")
|
||||||
await provider.generate(
|
await provider.generate(
|
||||||
target_name="desc.txt",
|
target_name="desc.txt",
|
||||||
target_config=target_config,
|
target_config=target_config,
|
||||||
|
|
|
||||||
145
tests/test_resolve.py
Normal file
145
tests/test_resolve.py
Normal file
|
|
@ -0,0 +1,145 @@
|
||||||
|
"""Integration tests for bulkgen.config."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from bulkgen.config import (
|
||||||
|
Defaults,
|
||||||
|
TargetConfig,
|
||||||
|
)
|
||||||
|
from bulkgen.providers.models import Capability
|
||||||
|
from bulkgen.resolve import infer_required_capabilities, resolve_model
|
||||||
|
|
||||||
|
|
||||||
|
class TestInferRequiredCapabilities:
|
||||||
|
"""Test capability inference from file extensions and target config."""
|
||||||
|
|
||||||
|
def test_plain_image(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x")
|
||||||
|
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||||
|
{Capability.TEXT_TO_IMAGE}
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("name", ["out.png", "out.jpg", "out.jpeg", "out.webp"])
|
||||||
|
def test_image_extensions(self, name: str) -> None:
|
||||||
|
target = TargetConfig(prompt="x")
|
||||||
|
caps = infer_required_capabilities(name, target)
|
||||||
|
assert Capability.TEXT_TO_IMAGE in caps
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("name", ["OUT.PNG", "OUT.JPG"])
|
||||||
|
def test_case_insensitive(self, name: str) -> None:
|
||||||
|
target = TargetConfig(prompt="x")
|
||||||
|
caps = infer_required_capabilities(name, target)
|
||||||
|
assert Capability.TEXT_TO_IMAGE in caps
|
||||||
|
|
||||||
|
def test_image_with_reference_images(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x", reference_images=["ref.png"])
|
||||||
|
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||||
|
{Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_image_with_control_images(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x", control_images=["ctrl.png"])
|
||||||
|
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||||
|
{Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_image_with_both(self) -> None:
|
||||||
|
target = TargetConfig(
|
||||||
|
prompt="x", reference_images=["ref.png"], control_images=["ctrl.png"]
|
||||||
|
)
|
||||||
|
assert infer_required_capabilities("out.png", target) == frozenset(
|
||||||
|
{
|
||||||
|
Capability.TEXT_TO_IMAGE,
|
||||||
|
Capability.REFERENCE_IMAGES,
|
||||||
|
Capability.CONTROL_IMAGES,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("name", ["doc.md", "doc.txt"])
|
||||||
|
def test_text_extensions(self, name: str) -> None:
|
||||||
|
target = TargetConfig(prompt="x")
|
||||||
|
caps = infer_required_capabilities(name, target)
|
||||||
|
assert caps == frozenset({Capability.TEXT_GENERATION})
|
||||||
|
|
||||||
|
def test_text_with_text_inputs(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x", inputs=["data.txt"])
|
||||||
|
assert infer_required_capabilities("out.md", target) == frozenset(
|
||||||
|
{Capability.TEXT_GENERATION}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_text_with_image_input(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x", inputs=["photo.png"])
|
||||||
|
assert infer_required_capabilities("out.txt", target) == frozenset(
|
||||||
|
{Capability.TEXT_GENERATION, Capability.VISION}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_text_with_image_reference(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x", reference_images=["ref.jpg"])
|
||||||
|
assert infer_required_capabilities("out.md", target) == frozenset(
|
||||||
|
{Capability.TEXT_GENERATION, Capability.VISION}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_unsupported_extension_raises(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x")
|
||||||
|
with pytest.raises(ValueError, match="unsupported extension"):
|
||||||
|
_ = infer_required_capabilities("data.csv", target)
|
||||||
|
|
||||||
|
def test_no_extension_raises(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x")
|
||||||
|
with pytest.raises(ValueError, match="unsupported extension"):
|
||||||
|
_ = infer_required_capabilities("Makefile", target)
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveModel:
|
||||||
|
"""Test model resolution with capability validation."""
|
||||||
|
|
||||||
|
def test_explicit_model_wins(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x", model="mistral-small-latest")
|
||||||
|
result = resolve_model("out.txt", target, Defaults())
|
||||||
|
assert result.name == "mistral-small-latest"
|
||||||
|
|
||||||
|
def test_default_text_model(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x")
|
||||||
|
defaults = Defaults(text_model="mistral-large-latest")
|
||||||
|
result = resolve_model("out.md", target, defaults)
|
||||||
|
assert result.name == "mistral-large-latest"
|
||||||
|
|
||||||
|
def test_default_image_model(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x")
|
||||||
|
defaults = Defaults(image_model="flux-dev")
|
||||||
|
result = resolve_model("out.png", target, defaults)
|
||||||
|
assert result.name == "flux-dev"
|
||||||
|
|
||||||
|
def test_unknown_model_raises(self) -> None:
|
||||||
|
target = TargetConfig(prompt="x", model="nonexistent-model")
|
||||||
|
with pytest.raises(ValueError, match="Unknown model"):
|
||||||
|
_ = resolve_model("out.txt", target, Defaults())
|
||||||
|
|
||||||
|
def test_explicit_model_missing_capability_raises(self) -> None:
|
||||||
|
# flux-dev does not support reference images
|
||||||
|
target = TargetConfig(prompt="x", model="flux-dev", reference_images=["r.png"])
|
||||||
|
with pytest.raises(ValueError, match="lacks required capabilities"):
|
||||||
|
_ = resolve_model("out.png", target, Defaults())
|
||||||
|
|
||||||
|
def test_default_fallback_for_reference_images(self) -> None:
|
||||||
|
# Default flux-dev lacks reference_images, should fall back to a capable model
|
||||||
|
target = TargetConfig(prompt="x", reference_images=["r.png"])
|
||||||
|
defaults = Defaults(image_model="flux-dev")
|
||||||
|
result = resolve_model("out.png", target, defaults)
|
||||||
|
assert Capability.REFERENCE_IMAGES in result.capabilities
|
||||||
|
|
||||||
|
def test_default_fallback_for_vision(self) -> None:
|
||||||
|
# Default mistral-large-latest lacks vision, should fall back to a pixtral model
|
||||||
|
target = TargetConfig(prompt="x", inputs=["photo.png"])
|
||||||
|
defaults = Defaults(text_model="mistral-large-latest")
|
||||||
|
result = resolve_model("out.txt", target, defaults)
|
||||||
|
assert Capability.VISION in result.capabilities
|
||||||
|
|
||||||
|
def test_default_preferred_when_capable(self) -> None:
|
||||||
|
# Default flux-2-pro already supports reference_images, should be used directly
|
||||||
|
target = TargetConfig(prompt="x", reference_images=["r.png"])
|
||||||
|
defaults = Defaults(image_model="flux-2-pro")
|
||||||
|
result = resolve_model("out.png", target, defaults)
|
||||||
|
assert result.name == "flux-2-pro"
|
||||||
Loading…
Add table
Add a link
Reference in a new issue