refactor: move model definitions into providers and extract resolve module
Some checks failed
Continuous Integration / Build Package (push) Successful in 30s
Continuous Integration / Lint, Check & Test (push) Failing after 38s

- Rename ImageProvider to BlackForestProvider, TextProvider to MistralProvider
- Add get_provided_models() abstract method to Provider base class
- Move model lists from models.py into each provider's get_provided_models()
- Add providers/registry.py to aggregate models from all providers
- Extract infer_required_capabilities and resolve_model from config.py to resolve.py
- Update tests to use new names and import paths
This commit is contained in:
Konstantin Fickel 2026-02-15 11:03:57 +01:00
parent dc6a75f5c4
commit d0dac5b1bf
Signed by: kfickel
GPG key ID: A793722F9933C1A5
13 changed files with 432 additions and 390 deletions

View file

@ -12,14 +12,13 @@ from pathlib import Path
from bulkgen.config import (
ProjectConfig,
TargetType,
infer_required_capabilities,
resolve_model,
target_type_from_capabilities,
)
from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target
from bulkgen.providers import Provider
from bulkgen.providers.image import ImageProvider
from bulkgen.providers.text import TextProvider
from bulkgen.providers.blackforest import BlackForestProvider
from bulkgen.providers.mistral import MistralProvider
from bulkgen.resolve import infer_required_capabilities, resolve_model
from bulkgen.state import (
BuildState,
is_target_dirty,
@ -106,10 +105,10 @@ def _create_providers() -> dict[TargetType, Provider]:
providers: dict[TargetType, Provider] = {}
bfl_key = os.environ.get("BFL_API_KEY", "")
if bfl_key:
providers[TargetType.IMAGE] = ImageProvider(api_key=bfl_key)
providers[TargetType.IMAGE] = BlackForestProvider(api_key=bfl_key)
mistral_key = os.environ.get("MISTRAL_API_KEY", "")
if mistral_key:
providers[TargetType.TEXT] = TextProvider(api_key=mistral_key)
providers[TargetType.TEXT] = MistralProvider(api_key=mistral_key)
return providers

View file

@ -13,7 +13,7 @@ import typer
from bulkgen.builder import BuildEvent, BuildResult, run_build
from bulkgen.config import ProjectConfig, load_config
from bulkgen.graph import build_graph, get_build_order
from bulkgen.providers.models import ALL_MODELS
from bulkgen.providers.registry import get_all_models
app = typer.Typer(name="bulkgen", help="AI artifact build tool.")
@ -182,9 +182,10 @@ def graph() -> None:
@app.command()
def models() -> None:
"""List available models and their capabilities."""
name_width = max(len(m.name) for m in ALL_MODELS)
provider_width = max(len(m.provider) for m in ALL_MODELS)
type_width = max(len(m.type) for m in ALL_MODELS)
all_models = get_all_models()
name_width = max(len(m.name) for m in all_models)
provider_width = max(len(m.provider) for m in all_models)
type_width = max(len(m.type) for m in all_models)
header_name = "Model".ljust(name_width)
header_provider = "Provider".ljust(provider_width)
@ -210,7 +211,7 @@ def models() -> None:
+ "" * len(header_caps)
)
for model in ALL_MODELS:
for model in all_models:
name_col = model.name.ljust(name_width)
provider_col = model.provider.ljust(provider_width)
type_col = model.type.ljust(type_width)

View file

@ -4,13 +4,15 @@ from __future__ import annotations
import enum
from pathlib import Path
from typing import TYPE_CHECKING, Self
from typing import Self
import yaml
from pydantic import BaseModel, model_validator
if TYPE_CHECKING:
from bulkgen.providers.models import Capability, ModelInfo
from bulkgen.providers.models import Capability
IMAGE_EXTENSIONS: frozenset[str] = frozenset({".png", ".jpg", ".jpeg", ".webp"})
TEXT_EXTENSIONS: frozenset[str] = frozenset({".md", ".txt"})
class TargetType(enum.Enum):
@ -20,10 +22,6 @@ class TargetType(enum.Enum):
TEXT = "text"
IMAGE_EXTENSIONS: frozenset[str] = frozenset({".png", ".jpg", ".jpeg", ".webp"})
TEXT_EXTENSIONS: frozenset[str] = frozenset({".md", ".txt"})
class Defaults(BaseModel):
"""Default model names, applied when a target does not specify its own."""
@ -57,36 +55,6 @@ class ProjectConfig(BaseModel):
return self
def infer_required_capabilities(
target_name: str, target: TargetConfig
) -> frozenset[Capability]:
"""Infer the capabilities a model must have based on filename and config.
Raises :class:`ValueError` for unsupported file extensions.
"""
from bulkgen.providers.models import Capability
suffix = Path(target_name).suffix.lower()
caps: set[Capability] = set()
if suffix in IMAGE_EXTENSIONS:
caps.add(Capability.TEXT_TO_IMAGE)
if target.reference_images:
caps.add(Capability.REFERENCE_IMAGES)
if target.control_images:
caps.add(Capability.CONTROL_IMAGES)
elif suffix in TEXT_EXTENSIONS:
caps.add(Capability.TEXT_GENERATION)
all_input_names = list(target.inputs) + list(target.reference_images)
if any(Path(n).suffix.lower() in IMAGE_EXTENSIONS for n in all_input_names):
caps.add(Capability.VISION)
else:
msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
raise ValueError(msg)
return frozenset(caps)
def target_type_from_capabilities(capabilities: frozenset[Capability]) -> TargetType:
"""Derive the target type from a set of required capabilities."""
from bulkgen.providers.models import Capability
@ -96,59 +64,6 @@ def target_type_from_capabilities(capabilities: frozenset[Capability]) -> Target
return TargetType.TEXT
def resolve_model(
target_name: str, target: TargetConfig, defaults: Defaults
) -> ModelInfo:
"""Return the effective model for a target, validated against required capabilities.
If the target specifies an explicit model, it is validated to have all
required capabilities. Otherwise the type-appropriate default is tried
first; if it lacks a required capability the first capable model of the
same type is selected.
Raises :class:`ValueError` if no suitable model can be found.
"""
from bulkgen.providers.models import ALL_MODELS
required = infer_required_capabilities(target_name, target)
target_type = target_type_from_capabilities(required)
if target.model is not None:
# Explicit model — look up and validate.
for model in ALL_MODELS:
if model.name == target.model:
missing = required - frozenset(model.capabilities)
if missing:
names = ", ".join(sorted(missing))
msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}"
raise ValueError(msg)
return model
msg = f"Unknown model '{target.model}' for target '{target_name}'"
raise ValueError(msg)
# No explicit model — try the default first, then fall back.
default_name = (
defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model
)
for model in ALL_MODELS:
if model.name == default_name:
if required <= frozenset(model.capabilities):
return model
break
# Default lacks capabilities — find the first capable model of the same type.
model_type = "image" if target_type is TargetType.IMAGE else "text"
for model in ALL_MODELS:
if model.type == model_type and required <= frozenset(model.capabilities):
return model
names = ", ".join(sorted(required))
msg = (
f"No model found for target '{target_name}' with required capabilities: {names}"
)
raise ValueError(msg)
def load_config(config_path: Path) -> ProjectConfig:
"""Load and validate a ``.bulkgen.yaml`` file."""
with config_path.open() as f:

View file

@ -12,6 +12,11 @@ from bulkgen.providers.models import ModelInfo
class Provider(abc.ABC):
"""Abstract base for generation providers."""
@staticmethod
@abc.abstractmethod
def get_provided_models() -> list[ModelInfo]:
"""Return the models this provider supports."""
@abc.abstractmethod
async def generate(
self,

View file

@ -11,7 +11,7 @@ import httpx
from bulkgen.config import TargetConfig
from bulkgen.providers import Provider
from bulkgen.providers.bfl import BFLClient
from bulkgen.providers.models import ModelInfo
from bulkgen.providers.models import Capability, ModelInfo
def _encode_image_b64(path: Path) -> str:
@ -45,7 +45,7 @@ def _add_reference_images(
inputs[key] = _encode_image_b64(project_dir / ref_name)
class ImageProvider(Provider):
class BlackForestProvider(Provider):
"""Generates images via the BlackForestLabs API."""
_client: BFLClient
@ -53,6 +53,72 @@ class ImageProvider(Provider):
def __init__(self, api_key: str) -> None:
self._client = BFLClient(api_key=api_key)
@staticmethod
@override
def get_provided_models() -> list[ModelInfo]:
return [
ModelInfo(
name="flux-dev",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro-1.1",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro-1.1-ultra",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-2-pro",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
),
ModelInfo(
name="flux-kontext-pro",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
),
ModelInfo(
name="flux-pro-1.0-canny",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES],
),
ModelInfo(
name="flux-pro-1.0-depth",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES],
),
ModelInfo(
name="flux-pro-1.0-fill",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro-1.0-expand",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
]
@override
async def generate(
self,

View file

@ -11,7 +11,7 @@ from mistralai import Mistral, models
from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig
from bulkgen.providers import Provider
from bulkgen.providers.models import ModelInfo
from bulkgen.providers.models import Capability, ModelInfo
def _image_to_data_url(path: Path) -> str:
@ -21,7 +21,7 @@ def _image_to_data_url(path: Path) -> str:
return f"data:{mime};base64,{b64}"
class TextProvider(Provider):
class MistralProvider(Provider):
"""Generates text via the Mistral API."""
_api_key: str
@ -29,6 +29,36 @@ class TextProvider(Provider):
def __init__(self, api_key: str) -> None:
self._api_key = api_key
@staticmethod
@override
def get_provided_models() -> list[ModelInfo]:
return [
ModelInfo(
name="mistral-large-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION],
),
ModelInfo(
name="mistral-small-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION],
),
ModelInfo(
name="pixtral-large-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
),
ModelInfo(
name="pixtral-12b-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
),
]
@override
async def generate(
self,
@ -98,6 +128,8 @@ def _build_multimodal_message(
models.TextChunk(text=prompt),
]
from bulkgen.config import IMAGE_EXTENSIONS
for name in input_names:
input_path = project_dir / name
suffix = input_path.suffix.lower()

View file

@ -1,4 +1,4 @@
"""Registry of supported models and their capabilities."""
"""Model types and capability definitions for AI providers."""
from __future__ import annotations
@ -25,106 +25,3 @@ class ModelInfo:
provider: str
type: Literal["text", "image"]
capabilities: list[Capability]
TEXT_MODELS: list[ModelInfo] = [
ModelInfo(
name="mistral-large-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION],
),
ModelInfo(
name="mistral-small-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION],
),
ModelInfo(
name="pixtral-large-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
),
ModelInfo(
name="pixtral-12b-latest",
provider="Mistral",
type="text",
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
),
]
IMAGE_MODELS: list[ModelInfo] = [
ModelInfo(
name="flux-dev",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro-1.1",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-pro-1.1-ultra",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE],
),
ModelInfo(
name="flux-2-pro",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
),
ModelInfo(
name="flux-kontext-pro",
provider="BlackForestLabs",
type="image",
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
),
ModelInfo(
name="flux-pro-1.0-canny",
provider="BlackForestLabs",
type="image",
capabilities=[
Capability.TEXT_TO_IMAGE,
Capability.CONTROL_IMAGES,
],
),
ModelInfo(
name="flux-pro-1.0-depth",
provider="BlackForestLabs",
type="image",
capabilities=[
Capability.TEXT_TO_IMAGE,
Capability.CONTROL_IMAGES,
],
),
ModelInfo(
name="flux-pro-1.0-fill",
provider="BlackForestLabs",
type="image",
capabilities=[
Capability.TEXT_TO_IMAGE,
],
),
ModelInfo(
name="flux-pro-1.0-expand",
provider="BlackForestLabs",
type="image",
capabilities=[
Capability.TEXT_TO_IMAGE,
],
),
]
ALL_MODELS: list[ModelInfo] = TEXT_MODELS + IMAGE_MODELS

View file

@ -0,0 +1,16 @@
"""Aggregates models from all registered providers."""
from __future__ import annotations
from bulkgen.providers.models import ModelInfo
def get_all_models() -> list[ModelInfo]:
"""Return the merged list of models from all providers."""
from bulkgen.providers.blackforest import BlackForestProvider
from bulkgen.providers.mistral import MistralProvider
return (
MistralProvider.get_provided_models()
+ BlackForestProvider.get_provided_models()
)

95
bulkgen/resolve.py Normal file
View file

@ -0,0 +1,95 @@
from __future__ import annotations
from pathlib import Path
from bulkgen.config import (
IMAGE_EXTENSIONS,
TEXT_EXTENSIONS,
Defaults,
TargetConfig,
TargetType,
target_type_from_capabilities,
)
from bulkgen.providers.models import Capability, ModelInfo
def infer_required_capabilities(
target_name: str, target: TargetConfig
) -> frozenset[Capability]:
"""Infer the capabilities a model must have based on filename and config.
Raises :class:`ValueError` for unsupported file extensions.
"""
suffix = Path(target_name).suffix.lower()
caps: set[Capability] = set()
if suffix in IMAGE_EXTENSIONS:
caps.add(Capability.TEXT_TO_IMAGE)
if target.reference_images:
caps.add(Capability.REFERENCE_IMAGES)
if target.control_images:
caps.add(Capability.CONTROL_IMAGES)
elif suffix in TEXT_EXTENSIONS:
caps.add(Capability.TEXT_GENERATION)
all_input_names = list(target.inputs) + list(target.reference_images)
if any(Path(n).suffix.lower() in IMAGE_EXTENSIONS for n in all_input_names):
caps.add(Capability.VISION)
else:
msg = f"Cannot infer target type for '{target_name}': unsupported extension '{suffix}'"
raise ValueError(msg)
return frozenset(caps)
def resolve_model(
target_name: str, target: TargetConfig, defaults: Defaults
) -> ModelInfo:
"""Return the effective model for a target, validated against required capabilities.
If the target specifies an explicit model, it is validated to have all
required capabilities. Otherwise the type-appropriate default is tried
first; if it lacks a required capability the first capable model of the
same type is selected.
Raises :class:`ValueError` if no suitable model can be found.
"""
from bulkgen.providers.registry import get_all_models
all_models = get_all_models()
required = infer_required_capabilities(target_name, target)
target_type = target_type_from_capabilities(required)
if target.model is not None:
# Explicit model — look up and validate.
for model in all_models:
if model.name == target.model:
missing = required - frozenset(model.capabilities)
if missing:
names = ", ".join(sorted(missing))
msg = f"Model '{target.model}' for target '{target_name}' lacks required capabilities: {names}"
raise ValueError(msg)
return model
msg = f"Unknown model '{target.model}' for target '{target_name}'"
raise ValueError(msg)
# No explicit model — try the default first, then fall back.
default_name = (
defaults.image_model if target_type is TargetType.IMAGE else defaults.text_model
)
for model in all_models:
if model.name == default_name:
if required <= frozenset(model.capabilities):
return model
break
# Default lacks capabilities — find the first capable model of the same type.
model_type = "image" if target_type is TargetType.IMAGE else "text"
for model in all_models:
if model.type == model_type and required <= frozenset(model.capabilities):
return model
names = ", ".join(sorted(required))
msg = (
f"No model found for target '{target_name}' with required capabilities: {names}"
)
raise ValueError(msg)