"""Async client for the BlackForestLabs image generation API."""

from __future__ import annotations

import asyncio
import time
from dataclasses import dataclass
from typing import cast

import httpx

_BASE_URL = "https://api.bfl.ai"
_DEFAULT_POLL_INTERVAL = 1.0
_DEFAULT_TIMEOUT = 300.0

# Task statuses that end polling without producing an image. "Error" and
# "Failed" were already handled; the moderation statuses are also terminal in
# the BFL get_result API, so continuing to poll on them would only run out the
# timeout and report a misleading "timed out" error. Any other status (e.g.
# "Pending") keeps polling. A tuple (not a set) is used so that an unhashable
# JSON value in "status" compares cleanly instead of raising TypeError.
_FAILED_STATUSES = ("Error", "Failed", "Request Moderated", "Content Moderated")


@dataclass
class BFLResult:
    """Successful generation result from the BFL API."""

    task_id: str
    sample_url: str


class BFLError(Exception):
    """Error returned by the BFL API."""


def _json_object(response: httpx.Response, context: str) -> dict[str, object]:
    """Decode *response* as a JSON object, raising :class:`BFLError` otherwise.

    Args:
        response: The HTTP response whose body should be a JSON object.
        context: Prefix for the error message (e.g. ``"BFL API"``).

    Returns:
        The decoded JSON object.

    Raises:
        BFLError: If the body is not valid JSON or is not a JSON object.
    """
    try:
        payload: object = response.json()
    except ValueError as err:  # JSONDecodeError / UnicodeDecodeError
        msg = (
            f"{context} returned a non-JSON body "
            f"(status {response.status_code}): {response.text}"
        )
        raise BFLError(msg) from err
    if not isinstance(payload, dict):
        msg = f"{context} returned unexpected JSON: {payload}"
        raise BFLError(msg)
    return cast(dict[str, object], payload)


def _extract_sample(task_id: str, body: dict[str, object]) -> str:
    """Return the sample image URL from a "Ready" polling response.

    Raises:
        BFLError: If ``result`` is not a dict or ``result["sample"]`` is not a
            string.
    """
    result = body.get("result")
    if not isinstance(result, dict):
        msg = f"BFL task {task_id} ready but no result dict: {body}"
        raise BFLError(msg)
    sample = cast(dict[str, object], result).get("sample")
    if not isinstance(sample, str):
        msg = f"BFL task {task_id} ready but no sample URL: {result}"
        raise BFLError(msg)
    return sample


class BFLClient:
    """Async client for the BlackForestLabs image generation API.

    Submits generation requests and polls for results.
    """

    _api_key: str
    _base_url: str
    _poll_interval: float
    _timeout: float

    def __init__(
        self,
        api_key: str,
        *,
        base_url: str = _BASE_URL,
        poll_interval: float = _DEFAULT_POLL_INTERVAL,
        timeout: float = _DEFAULT_TIMEOUT,
    ) -> None:
        """Store credentials and polling configuration.

        Args:
            api_key: Value sent in the ``x-key`` header on every request.
            base_url: API root used for the submit request.
            poll_interval: Seconds to wait between polling requests.
            timeout: Wall-clock budget, in seconds, for polling a task.
        """
        self._api_key = api_key
        self._base_url = base_url
        self._poll_interval = poll_interval
        self._timeout = timeout

    async def generate(self, model: str, inputs: dict[str, object]) -> BFLResult:
        """Submit a generation request and poll until the result is ready.

        Args:
            model: BFL model name (e.g. ``flux-pro-1.1``, ``flux-2-pro``).
            inputs: Model-specific parameters (prompt, dimensions, etc.).

        Returns:
            A :class:`BFLResult` containing the task ID and sample image URL.

        Raises:
            BFLError: If the API returns an error or the request times out.
                Transport-level ``httpx`` exceptions are not caught here and
                propagate unchanged.
        """
        headers = {
            "x-key": self._api_key,
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        # One client (connection pool) is shared by the submit and all polls.
        async with httpx.AsyncClient(
            base_url=self._base_url, headers=headers
        ) as client:
            task_id, polling_url = await self._submit(client, model, inputs)
            sample_url = await self._poll(client, task_id, polling_url)
            return BFLResult(task_id=task_id, sample_url=sample_url)

    async def _submit(
        self,
        client: httpx.AsyncClient,
        model: str,
        inputs: dict[str, object],
    ) -> tuple[str, str]:
        """POST the generation request and return (task_id, polling_url).

        Raises:
            BFLError: On a 422 validation error, any other non-200 status, or
                a response body lacking string ``id`` / ``polling_url``.
        """
        response = await client.post(f"/v1/{model}", json=inputs)

        if response.status_code == 422:  # noqa: PLR2004
            # Prefer the structured "detail" field, but never let a malformed
            # error body mask the validation failure itself.
            try:
                detail: object = response.json()
            except ValueError:
                detail = response.text
            if isinstance(detail, dict):
                detail = cast(dict[str, object], detail).get("detail", detail)
            msg = f"BFL validation error for model '{model}': {detail}"
            raise BFLError(msg)

        if response.status_code != 200:  # noqa: PLR2004
            msg = f"BFL API returned status {response.status_code}: {response.text}"
            raise BFLError(msg)

        body = _json_object(response, "BFL API")
        task_id = body.get("id")
        polling_url = body.get("polling_url")
        if not isinstance(task_id, str) or not isinstance(polling_url, str):
            msg = f"BFL API response missing 'id' or 'polling_url': {body}"
            raise BFLError(msg)
        return task_id, polling_url

    async def _poll(
        self,
        client: httpx.AsyncClient,
        task_id: str,
        polling_url: str,
    ) -> str:
        """Poll the task until ready and return the sample image URL.

        The task is polled at least once. The timeout is measured against a
        monotonic clock, so time spent inside HTTP requests counts toward the
        budget, and a non-positive ``poll_interval`` cannot loop forever.

        Raises:
            BFLError: On a non-200 polling response, a failed or moderated
                task, a malformed "Ready" payload, or when the deadline passes.
        """
        deadline = time.monotonic() + self._timeout
        while True:
            response = await client.get(polling_url)
            if response.status_code != 200:  # noqa: PLR2004
                msg = f"BFL polling returned status {response.status_code}: {response.text}"
                raise BFLError(msg)

            body = _json_object(response, "BFL polling")
            status = body.get("status")
            if status == "Ready":
                return _extract_sample(task_id, body)
            if status in _FAILED_STATUSES:
                error_msg = body.get("result", body)
                msg = f"BFL task {task_id} failed: {error_msg}"
                raise BFLError(msg)

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Never sleep past the deadline; one final poll happens after it.
            await asyncio.sleep(min(self._poll_interval, remaining))

        msg = f"BFL task {task_id} timed out after {self._timeout}s"
        raise BFLError(msg)