"""BlackForestLabs image generation provider.""" from __future__ import annotations import base64 from pathlib import Path from typing import override import httpx from bulkgen.config import TargetConfig from bulkgen.providers import Provider from bulkgen.providers.bfl import BFLClient from bulkgen.providers.models import ModelInfo def _encode_image_b64(path: Path) -> str: """Read an image file and return its base64-encoded representation.""" return base64.b64encode(path.read_bytes()).decode("ascii") # Parameter names for reference images, keyed by model prefix. _INPUT_IMAGE_KEYS = ["input_image"] + [f"input_image_{i}" for i in range(2, 9)] _IMAGE_PROMPT_KEYS = ["image_prompt"] def _ref_image_keys(model: str) -> list[str]: """Return the ordered API parameter names for reference images.""" if model.startswith("flux-2-"): return _INPUT_IMAGE_KEYS # up to 8 if model.startswith("flux-kontext-"): return _INPUT_IMAGE_KEYS[:4] # up to 4 return _IMAGE_PROMPT_KEYS # flux 1.x: single image_prompt def _add_reference_images( inputs: dict[str, object], reference_images: list[str], model: str, project_dir: Path, ) -> None: """Encode reference images and add them under the correct API keys.""" keys = _ref_image_keys(model) for key, ref_name in zip(keys, reference_images, strict=False): inputs[key] = _encode_image_b64(project_dir / ref_name) class ImageProvider(Provider): """Generates images via the BlackForestLabs API.""" _client: BFLClient def __init__(self, api_key: str) -> None: self._client = BFLClient(api_key=api_key) @override async def generate( self, target_name: str, target_config: TargetConfig, resolved_prompt: str, resolved_model: ModelInfo, project_dir: Path, ) -> None: output_path = project_dir / target_name inputs: dict[str, object] = {"prompt": resolved_prompt} if target_config.width is not None: inputs["width"] = target_config.width if target_config.height is not None: inputs["height"] = target_config.height if target_config.reference_images: _add_reference_images( inputs, target_config.reference_images, resolved_model.name, project_dir ) for control_name in target_config.control_images: ctrl_path = project_dir / control_name inputs["control_image"] = _encode_image_b64(ctrl_path) result = await self._client.generate(resolved_model.name, inputs) async with httpx.AsyncClient() as http: response = await http.get(result.sample_url) _ = response.raise_for_status() _ = output_path.write_bytes(response.content)