refactor: move model definitions into providers and extract resolve module
- Rename ImageProvider to BlackForestProvider, TextProvider to MistralProvider - Add get_provided_models() abstract method to Provider base class - Move model lists from models.py into each provider's get_provided_models() - Add providers/registry.py to aggregate models from all providers - Extract infer_required_capabilities and resolve_model from config.py to resolve.py - Update tests to use new names and import paths
This commit is contained in:
parent
dc6a75f5c4
commit
d0dac5b1bf
13 changed files with 432 additions and 390 deletions
|
|
@ -12,6 +12,11 @@ from bulkgen.providers.models import ModelInfo
|
|||
class Provider(abc.ABC):
|
||||
"""Abstract base for generation providers."""
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def get_provided_models() -> list[ModelInfo]:
|
||||
"""Return the models this provider supports."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def generate(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import httpx
|
|||
from bulkgen.config import TargetConfig
|
||||
from bulkgen.providers import Provider
|
||||
from bulkgen.providers.bfl import BFLClient
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
from bulkgen.providers.models import Capability, ModelInfo
|
||||
|
||||
|
||||
def _encode_image_b64(path: Path) -> str:
|
||||
|
|
@ -45,7 +45,7 @@ def _add_reference_images(
|
|||
inputs[key] = _encode_image_b64(project_dir / ref_name)
|
||||
|
||||
|
||||
class ImageProvider(Provider):
|
||||
class BlackForestProvider(Provider):
|
||||
"""Generates images via the BlackForestLabs API."""
|
||||
|
||||
_client: BFLClient
|
||||
|
|
@ -53,6 +53,72 @@ class ImageProvider(Provider):
|
|||
def __init__(self, api_key: str) -> None:
|
||||
self._client = BFLClient(api_key=api_key)
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def get_provided_models() -> list[ModelInfo]:
|
||||
return [
|
||||
ModelInfo(
|
||||
name="flux-dev",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro-1.1",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro-1.1-ultra",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-2-pro",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-kontext-pro",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro-1.0-canny",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro-1.0-depth",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro-1.0-fill",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro-1.0-expand",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||
),
|
||||
]
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
self,
|
||||
|
|
@ -11,7 +11,7 @@ from mistralai import Mistral, models
|
|||
|
||||
from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig
|
||||
from bulkgen.providers import Provider
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
from bulkgen.providers.models import Capability, ModelInfo
|
||||
|
||||
|
||||
def _image_to_data_url(path: Path) -> str:
|
||||
|
|
@ -21,7 +21,7 @@ def _image_to_data_url(path: Path) -> str:
|
|||
return f"data:{mime};base64,{b64}"
|
||||
|
||||
|
||||
class TextProvider(Provider):
|
||||
class MistralProvider(Provider):
|
||||
"""Generates text via the Mistral API."""
|
||||
|
||||
_api_key: str
|
||||
|
|
@ -29,6 +29,36 @@ class TextProvider(Provider):
|
|||
def __init__(self, api_key: str) -> None:
|
||||
self._api_key = api_key
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def get_provided_models() -> list[ModelInfo]:
|
||||
return [
|
||||
ModelInfo(
|
||||
name="mistral-large-latest",
|
||||
provider="Mistral",
|
||||
type="text",
|
||||
capabilities=[Capability.TEXT_GENERATION],
|
||||
),
|
||||
ModelInfo(
|
||||
name="mistral-small-latest",
|
||||
provider="Mistral",
|
||||
type="text",
|
||||
capabilities=[Capability.TEXT_GENERATION],
|
||||
),
|
||||
ModelInfo(
|
||||
name="pixtral-large-latest",
|
||||
provider="Mistral",
|
||||
type="text",
|
||||
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
|
||||
),
|
||||
ModelInfo(
|
||||
name="pixtral-12b-latest",
|
||||
provider="Mistral",
|
||||
type="text",
|
||||
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
|
||||
),
|
||||
]
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
self,
|
||||
|
|
@ -98,6 +128,8 @@ def _build_multimodal_message(
|
|||
models.TextChunk(text=prompt),
|
||||
]
|
||||
|
||||
from bulkgen.config import IMAGE_EXTENSIONS
|
||||
|
||||
for name in input_names:
|
||||
input_path = project_dir / name
|
||||
suffix = input_path.suffix.lower()
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Registry of supported models and their capabilities."""
|
||||
"""Model types and capability definitions for AI providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -25,106 +25,3 @@ class ModelInfo:
|
|||
provider: str
|
||||
type: Literal["text", "image"]
|
||||
capabilities: list[Capability]
|
||||
|
||||
|
||||
TEXT_MODELS: list[ModelInfo] = [
|
||||
ModelInfo(
|
||||
name="mistral-large-latest",
|
||||
provider="Mistral",
|
||||
type="text",
|
||||
capabilities=[Capability.TEXT_GENERATION],
|
||||
),
|
||||
ModelInfo(
|
||||
name="mistral-small-latest",
|
||||
provider="Mistral",
|
||||
type="text",
|
||||
capabilities=[Capability.TEXT_GENERATION],
|
||||
),
|
||||
ModelInfo(
|
||||
name="pixtral-large-latest",
|
||||
provider="Mistral",
|
||||
type="text",
|
||||
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
|
||||
),
|
||||
ModelInfo(
|
||||
name="pixtral-12b-latest",
|
||||
provider="Mistral",
|
||||
type="text",
|
||||
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
|
||||
),
|
||||
]
|
||||
|
||||
IMAGE_MODELS: list[ModelInfo] = [
|
||||
ModelInfo(
|
||||
name="flux-dev",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro-1.1",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro-1.1-ultra",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-2-pro",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-kontext-pro",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro-1.0-canny",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[
|
||||
Capability.TEXT_TO_IMAGE,
|
||||
Capability.CONTROL_IMAGES,
|
||||
],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro-1.0-depth",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[
|
||||
Capability.TEXT_TO_IMAGE,
|
||||
Capability.CONTROL_IMAGES,
|
||||
],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro-1.0-fill",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[
|
||||
Capability.TEXT_TO_IMAGE,
|
||||
],
|
||||
),
|
||||
ModelInfo(
|
||||
name="flux-pro-1.0-expand",
|
||||
provider="BlackForestLabs",
|
||||
type="image",
|
||||
capabilities=[
|
||||
Capability.TEXT_TO_IMAGE,
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
ALL_MODELS: list[ModelInfo] = TEXT_MODELS + IMAGE_MODELS
|
||||
|
|
|
|||
16
bulkgen/providers/registry.py
Normal file
16
bulkgen/providers/registry.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""Aggregates models from all registered providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
|
||||
|
||||
def get_all_models() -> list[ModelInfo]:
|
||||
"""Return the merged list of models from all providers."""
|
||||
from bulkgen.providers.blackforest import BlackForestProvider
|
||||
from bulkgen.providers.mistral import MistralProvider
|
||||
|
||||
return (
|
||||
MistralProvider.get_provided_models()
|
||||
+ BlackForestProvider.get_provided_models()
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue