"""BlackForestLabs image generation provider.""" from __future__ import annotations import asyncio import base64 from pathlib import Path from typing import override import httpx from blackforest import BFLClient # pyright: ignore[reportMissingTypeStubs] from blackforest.types.general.client_config import ( # pyright: ignore[reportMissingTypeStubs] ClientConfig, ) from blackforest.types.responses.responses import ( # pyright: ignore[reportMissingTypeStubs] SyncResponse, ) from bulkgen.config import TargetConfig from bulkgen.providers import Provider _BFL_SYNC_CONFIG = ClientConfig(sync=True, timeout=300) def _encode_image_b64(path: Path) -> str: """Read an image file and return its base64-encoded representation.""" return base64.b64encode(path.read_bytes()).decode("ascii") class ImageProvider(Provider): """Generates images via the BlackForestLabs API.""" _client: BFLClient def __init__(self, api_key: str) -> None: self._client = BFLClient(api_key=api_key) @override async def generate( self, target_name: str, target_config: TargetConfig, resolved_prompt: str, resolved_model: str, project_dir: Path, ) -> None: output_path = project_dir / target_name inputs: dict[str, object] = {"prompt": resolved_prompt} if target_config.width is not None: inputs["width"] = target_config.width if target_config.height is not None: inputs["height"] = target_config.height if target_config.reference_image is not None: ref_path = project_dir / target_config.reference_image inputs["image_prompt"] = _encode_image_b64(ref_path) for control_name in target_config.control_images: ctrl_path = project_dir / control_name inputs["control_image"] = _encode_image_b64(ctrl_path) result = await asyncio.to_thread( self._client.generate, resolved_model, inputs, _BFL_SYNC_CONFIG ) if not isinstance(result, SyncResponse): msg = ( f"BFL API returned unexpected response type for target '{target_name}'" ) raise RuntimeError(msg) result_dict: dict[str, str] = result.result # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] image_url = result_dict.get("sample") if not image_url: msg = f"BFL API did not return an image URL for target '{target_name}'" raise RuntimeError(msg) async with httpx.AsyncClient() as http: response = await http.get(image_url) _ = response.raise_for_status() _ = output_path.write_bytes(response.content)