"""BlackForestLabs image generation provider.""" from __future__ import annotations import base64 from pathlib import Path from typing import override import httpx from bulkgen.config import TargetConfig from bulkgen.providers import Provider from bulkgen.providers.bfl import BFLClient def _encode_image_b64(path: Path) -> str: """Read an image file and return its base64-encoded representation.""" return base64.b64encode(path.read_bytes()).decode("ascii") # Parameter names for reference images, keyed by model prefix. _INPUT_IMAGE_KEYS = ["input_image"] + [f"input_image_{i}" for i in range(2, 9)] _IMAGE_PROMPT_KEYS = ["image_prompt"] def _ref_image_keys(model: str) -> list[str]: """Return the ordered API parameter names for reference images.""" if model.startswith("flux-2-"): return _INPUT_IMAGE_KEYS # up to 8 if model.startswith("flux-kontext-"): return _INPUT_IMAGE_KEYS[:4] # up to 4 return _IMAGE_PROMPT_KEYS # flux 1.x: single image_prompt def _add_reference_images( inputs: dict[str, object], reference_images: list[str], model: str, project_dir: Path, ) -> None: """Encode reference images and add them under the correct API keys.""" keys = _ref_image_keys(model) for key, ref_name in zip(keys, reference_images, strict=False): inputs[key] = _encode_image_b64(project_dir / ref_name) class ImageProvider(Provider): """Generates images via the BlackForestLabs API.""" _client: BFLClient def __init__(self, api_key: str) -> None: self._client = BFLClient(api_key=api_key) @override async def generate( self, target_name: str, target_config: TargetConfig, resolved_prompt: str, resolved_model: str, project_dir: Path, ) -> None: output_path = project_dir / target_name inputs: dict[str, object] = {"prompt": resolved_prompt} if target_config.width is not None: inputs["width"] = target_config.width if target_config.height is not None: inputs["height"] = target_config.height if target_config.reference_images: _add_reference_images( inputs, target_config.reference_images, resolved_model, project_dir ) for control_name in target_config.control_images: ctrl_path = project_dir / control_name inputs["control_image"] = _encode_image_b64(ctrl_path) result = await self._client.generate(resolved_model, inputs) async with httpx.AsyncClient() as http: response = await http.get(result.sample_url) _ = response.raise_for_status() _ = output_path.write_bytes(response.content)