Implement bulkgen/providers/bfl.py with a fully async httpx-based client that supports all current and future BFL models (including flux-2-*). Remove the blackforest dependency and simplify the image provider by eliminating the asyncio.to_thread wrapper.
148 lines
4.6 KiB
Python
148 lines
4.6 KiB
Python
"""Async client for the BlackForestLabs image generation API."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from dataclasses import dataclass
|
|
from typing import cast
|
|
|
|
import httpx
|
|
|
|
# Root of the BFL REST API; submission paths such as ``/v1/<model>`` are
# resolved against it.
_BASE_URL: str = "https://api.bfl.ai"
# Seconds to wait between two consecutive polling requests.
_DEFAULT_POLL_INTERVAL: float = 1.0
# Seconds to keep polling a submitted task before giving up.
_DEFAULT_TIMEOUT: float = 300.0
|
|
|
|
|
|
@dataclass
class BFLResult:
    """Outcome of a BFL generation task that reached the "Ready" state.

    Attributes:
        task_id: Identifier the API returned (as ``id``) when the request
            was submitted.
        sample_url: Image URL read from the task's ``result["sample"]`` field.
    """

    task_id: str  # ``id`` from the submission response
    sample_url: str  # ``result["sample"]`` from the final poll response
|
|
|
|
|
|
class BFLError(Exception):
    """Raised when the BFL API rejects a request, a task fails, or polling times out."""
|
|
|
|
|
|
class BFLClient:
    """Async client for the BlackForestLabs image generation API.

    Submits generation requests and polls for results. Each call to
    :meth:`generate` opens a fresh :class:`httpx.AsyncClient` and closes it
    before returning.
    """

    # Poll statuses after which a task can never reach "Ready", so polling
    # stops immediately instead of running out the whole timeout. Besides
    # "Error"/"Failed", the BFL API reports moderation rejections with the
    # "Request Moderated" / "Content Moderated" statuses. A tuple (not a set)
    # keeps the membership test safe even for a non-hashable "status" value.
    _FAILED_STATUSES = ("Error", "Failed", "Request Moderated", "Content Moderated")

    _api_key: str
    _base_url: str
    _poll_interval: float
    _timeout: float

    def __init__(
        self,
        api_key: str,
        *,
        base_url: str = _BASE_URL,
        poll_interval: float = _DEFAULT_POLL_INTERVAL,
        timeout: float = _DEFAULT_TIMEOUT,
    ) -> None:
        """Store the connection settings.

        Args:
            api_key: BFL API key, sent in the ``x-key`` header.
            base_url: API root that submission paths are resolved against.
            poll_interval: Seconds to sleep between polling requests.
            timeout: Overall wall-clock budget, in seconds, for polling a task.
        """
        self._api_key = api_key
        self._base_url = base_url
        self._poll_interval = poll_interval
        self._timeout = timeout

    async def generate(self, model: str, inputs: dict[str, object]) -> BFLResult:
        """Submit a generation request and poll until the result is ready.

        Args:
            model: BFL model name (e.g. ``flux-pro-1.1``, ``flux-2-pro``);
                the request is POSTed to ``/v1/<model>``.
            inputs: Model-specific parameters (prompt, dimensions, etc.),
                sent as the JSON request body.

        Returns:
            A :class:`BFLResult` containing the task ID and sample image URL.

        Raises:
            BFLError: If the API rejects the request, reports a failed or
                moderated task, returns a malformed response, or the task is
                not ready within ``timeout`` seconds.
            httpx.HTTPError: Transport-level failures (connection errors,
                per-request timeouts) are not wrapped and propagate as-is.
        """
        # Headers are set on the client so the polling GETs carry them too.
        headers = {
            "x-key": self._api_key,
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

        async with httpx.AsyncClient(
            base_url=self._base_url, headers=headers
        ) as client:
            task_id, polling_url = await self._submit(client, model, inputs)
            sample_url = await self._poll(client, task_id, polling_url)

        return BFLResult(task_id=task_id, sample_url=sample_url)

    @staticmethod
    def _json_object(response: httpx.Response, context: str) -> dict[str, object]:
        """Decode *response* as a JSON object, raising BFLError if it is not one."""
        try:
            body: object = response.json()
        except ValueError as err:  # JSONDecodeError / UnicodeDecodeError
            msg = (
                f"{context} returned a non-JSON body "
                f"(status {response.status_code}): {response.text}"
            )
            raise BFLError(msg) from err

        if not isinstance(body, dict):
            msg = (
                f"{context} returned unexpected JSON "
                f"(status {response.status_code}): {body!r}"
            )
            raise BFLError(msg)

        return cast(dict[str, object], body)

    async def _submit(
        self,
        client: httpx.AsyncClient,
        model: str,
        inputs: dict[str, object],
    ) -> tuple[str, str]:
        """POST the generation request and return (task_id, polling_url).

        Raises:
            BFLError: On a validation error (422), any other non-200 status,
                or a success body lacking ``id`` / ``polling_url``.
        """
        response = await client.post(f"/v1/{model}", json=inputs)

        if response.status_code == 422:  # noqa: PLR2004
            # Prefer the structured "detail" field, but never let a
            # non-JSON error body mask the validation failure itself.
            try:
                payload: object = response.json()
            except ValueError:
                payload = response.text
            if isinstance(payload, dict):
                detail = cast(dict[str, object], payload).get("detail", payload)
            else:
                detail = payload
            msg = f"BFL validation error for model '{model}': {detail}"
            raise BFLError(msg)

        if response.status_code != 200:  # noqa: PLR2004
            msg = f"BFL API returned status {response.status_code}: {response.text}"
            raise BFLError(msg)

        body = self._json_object(response, "BFL API")
        task_id = body.get("id")
        polling_url = body.get("polling_url")

        if not isinstance(task_id, str) or not isinstance(polling_url, str):
            msg = f"BFL API response missing 'id' or 'polling_url': {body}"
            raise BFLError(msg)

        return task_id, polling_url

    async def _poll(
        self,
        client: httpx.AsyncClient,
        task_id: str,
        polling_url: str,
    ) -> str:
        """Poll the task until ready and return the sample image URL.

        Raises:
            BFLError: On a non-200 polling response, a malformed body, a
                terminal failure status, or when ``timeout`` seconds of wall
                clock time elapse without the task becoming ready.
        """
        loop = asyncio.get_running_loop()
        # Measure against the loop's monotonic clock: time spent inside each
        # HTTP request counts toward the budget, and a non-positive
        # poll_interval can no longer keep the loop running forever.
        deadline = loop.time() + self._timeout

        while loop.time() < deadline:
            response = await client.get(polling_url)

            if response.status_code != 200:  # noqa: PLR2004
                msg = f"BFL polling returned status {response.status_code}: {response.text}"
                raise BFLError(msg)

            body = self._json_object(response, "BFL polling")
            status = body.get("status")

            if status == "Ready":
                result = body.get("result")
                if not isinstance(result, dict):
                    msg = f"BFL task {task_id} ready but no result dict: {body}"
                    raise BFLError(msg)
                result_dict = cast(dict[str, object], result)
                sample = result_dict.get("sample")
                if not isinstance(sample, str):
                    msg = f"BFL task {task_id} ready but no sample URL: {result}"
                    raise BFLError(msg)
                return sample

            if status in self._FAILED_STATUSES:
                # "result" may be present but null; fall back to the whole
                # body (which includes the status) so the message is useful.
                result = body.get("result")
                error_msg = body if result is None else result
                msg = f"BFL task {task_id} failed: {error_msg}"
                raise BFLError(msg)

            # Any other status (e.g. "Pending") means: keep waiting.
            await asyncio.sleep(self._poll_interval)

        msg = f"BFL task {task_id} timed out after {self._timeout}s"
        raise BFLError(msg)
|