hokusai/bulkgen/providers/blackforest.py
Konstantin Fickel d0dac5b1bf
Some checks failed
Continuous Integration / Build Package (push) Successful in 30s
Continuous Integration / Lint, Check & Test (push) Failing after 38s
refactor: move model definitions into providers and extract resolve module
- Rename ImageProvider to BlackForestProvider, TextProvider to MistralProvider
- Add get_provided_models() abstract method to Provider base class
- Move model lists from models.py into each provider's get_provided_models()
- Add providers/registry.py to aggregate models from all providers
- Extract infer_required_capabilities and resolve_model from config.py to resolve.py
- Update tests to use new names and import paths
2026-02-15 11:03:57 +01:00

155 lines
5 KiB
Python

"""BlackForestLabs image generation provider."""
from __future__ import annotations
import base64
from pathlib import Path
from typing import override
import httpx
from bulkgen.config import TargetConfig
from bulkgen.providers import Provider
from bulkgen.providers.bfl import BFLClient
from bulkgen.providers.models import Capability, ModelInfo
def _encode_image_b64(path: Path) -> str:
    """Return the bytes of the file at *path* as an ASCII base64 string.

    The file is read in full; the result uses the standard base64 alphabet
    and is suitable for embedding in a JSON request body.
    """
    raw_bytes = path.read_bytes()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("ascii")
# API parameter names used to send reference images. Which list applies, and
# how many of its names are usable, depends on the model family; see
# ``_ref_image_keys``.
_INPUT_IMAGE_KEYS = ["input_image"] + [f"input_image_{i}" for i in range(2, 9)]
_IMAGE_PROMPT_KEYS = ["image_prompt"]


def _ref_image_keys(model: str) -> list[str]:
    """Return the ordered API parameter names for reference images.

    The length of the returned list is the number of reference-image slots
    the model family offers:

    * ``flux-2-*``: ``input_image``, ``input_image_2`` ... ``input_image_8`` (8)
    * ``flux-kontext-*``: ``input_image`` ... ``input_image_4`` (4)
    * anything else (FLUX 1.x): a single ``image_prompt``

    Args:
        model: BlackForestLabs model name, e.g. ``"flux-2-pro"``.

    Returns:
        A new list on every call, so callers may modify it without affecting
        the module-level key tables.
    """
    if model.startswith("flux-2-"):
        # Copy: handing out the module constant would let a caller mutate it.
        return list(_INPUT_IMAGE_KEYS)  # up to 8
    if model.startswith("flux-kontext-"):
        return _INPUT_IMAGE_KEYS[:4]  # up to 4 (slicing already copies)
    return list(_IMAGE_PROMPT_KEYS)  # flux 1.x: single image_prompt
def _add_reference_images(
    inputs: dict[str, object],
    reference_images: list[str],
    model: str,
    project_dir: Path,
) -> None:
    """Base64-encode reference images into *inputs* under their API names.

    Images are paired, in order, with the parameter names that
    ``_ref_image_keys(model)`` returns. *inputs* is updated in place.

    Args:
        inputs: Request payload being assembled; mutated.
        reference_images: Image paths relative to *project_dir*.
        model: Model name used to choose the parameter names.
        project_dir: Directory the image paths are resolved against.
    """
    # strict=False on purpose: images beyond the model's number of
    # reference slots are not sent.
    pairs = zip(_ref_image_keys(model), reference_images, strict=False)
    for param_name, image_name in pairs:
        image_path = project_dir / image_name
        inputs[param_name] = _encode_image_b64(image_path)
class BlackForestProvider(Provider):
    """Generates images via the BlackForestLabs API."""

    # Client used to submit generation requests (see ``bulkgen.providers.bfl``).
    _client: BFLClient

    def __init__(self, api_key: str) -> None:
        """Create a provider whose BFL client authenticates with *api_key*."""
        self._client = BFLClient(api_key=api_key)

    @staticmethod
    @override
    def get_provided_models() -> list[ModelInfo]:
        """Return the BlackForestLabs image models this provider offers.

        Every model supports text-to-image. ``flux-2-pro`` and
        ``flux-kontext-pro`` also accept reference images, and the canny and
        depth variants also accept control images.
        """
        text_only = (Capability.TEXT_TO_IMAGE,)
        with_references = (Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES)
        with_control = (Capability.TEXT_TO_IMAGE, Capability.CONTROL_IMAGES)

        catalogue: list[tuple[str, tuple[Capability, ...]]] = [
            ("flux-dev", text_only),
            ("flux-pro", text_only),
            ("flux-pro-1.1", text_only),
            ("flux-pro-1.1-ultra", text_only),
            ("flux-2-pro", with_references),
            ("flux-kontext-pro", with_references),
            ("flux-pro-1.0-canny", with_control),
            ("flux-pro-1.0-depth", with_control),
            ("flux-pro-1.0-fill", text_only),
            ("flux-pro-1.0-expand", text_only),
        ]
        # list(...) gives each ModelInfo its own capabilities list, as the
        # original one-literal-per-model form did.
        return [
            ModelInfo(
                name=name,
                provider="BlackForestLabs",
                type="image",
                capabilities=list(capabilities),
            )
            for name, capabilities in catalogue
        ]

    @override
    async def generate(
        self,
        target_name: str,
        target_config: TargetConfig,
        resolved_prompt: str,
        resolved_model: ModelInfo,
        project_dir: Path,
    ) -> None:
        """Generate one image and write it to ``project_dir / target_name``.

        Args:
            target_name: Output file name, relative to *project_dir*.
            target_config: Per-target settings. ``width``/``height`` are sent
                only when set; reference and control images are read from
                *project_dir*.
            resolved_prompt: Final prompt text sent to the API.
            resolved_model: Model to run. Its ``name`` is passed to the BFL
                client and chooses the reference-image parameter names.
            project_dir: Directory that input images are read from and that
                the result is written into.

        Raises:
            ValueError: If more than one control image is configured. The
                request carries a single ``control_image`` field, so extra
                images would otherwise be silently dropped.
            httpx.HTTPStatusError: If downloading the generated image fails.
        """
        output_path = project_dir / target_name

        inputs: dict[str, object] = {"prompt": resolved_prompt}
        if target_config.width is not None:
            inputs["width"] = target_config.width
        if target_config.height is not None:
            inputs["height"] = target_config.height

        if target_config.reference_images:
            _add_reference_images(
                inputs, target_config.reference_images, resolved_model.name, project_dir
            )

        # Only one control image fits in the request. Previously a loop
        # overwrote the same key, encoding every file and keeping only the last.
        control_images = list(target_config.control_images)
        if len(control_images) > 1:
            msg = (
                f"Target '{target_name}': BlackForestLabs accepts a single "
                f"control image, but {len(control_images)} were configured."
            )
            raise ValueError(msg)
        if control_images:
            inputs["control_image"] = _encode_image_b64(project_dir / control_images[0])

        result = await self._client.generate(resolved_model.name, inputs)

        # The client result carries a URL (``sample_url``), not the image
        # bytes, so fetch the finished image from it.
        async with httpx.AsyncClient() as http:
            response = await http.get(result.sample_url)
            _ = response.raise_for_status()
        _ = output_path.write_bytes(response.content)