refactor: replace blackforest package with custom async BFL client
Implement bulkgen/providers/bfl.py with a fully async httpx-based client that supports all current and future BFL models (including flux-2-*). Remove the blackforest dependency and simplify the image provider by eliminating the asyncio.to_thread wrapper.
This commit is contained in:
parent
fd09d127f2
commit
cf73511876
5 changed files with 179 additions and 89 deletions
148
bulkgen/providers/bfl.py
Normal file
148
bulkgen/providers/bfl.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""Async client for the BlackForestLabs image generation API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
|
||||
_BASE_URL = "https://api.bfl.ai"
|
||||
_DEFAULT_POLL_INTERVAL = 1.0
|
||||
_DEFAULT_TIMEOUT = 300.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class BFLResult:
|
||||
"""Successful generation result from the BFL API."""
|
||||
|
||||
task_id: str
|
||||
sample_url: str
|
||||
|
||||
|
||||
class BFLError(Exception):
|
||||
"""Error returned by the BFL API."""
|
||||
|
||||
|
||||
class BFLClient:
|
||||
"""Async client for the BlackForestLabs image generation API.
|
||||
|
||||
Submits generation requests and polls for results.
|
||||
"""
|
||||
|
||||
_api_key: str
|
||||
_base_url: str
|
||||
_poll_interval: float
|
||||
_timeout: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
*,
|
||||
base_url: str = _BASE_URL,
|
||||
poll_interval: float = _DEFAULT_POLL_INTERVAL,
|
||||
timeout: float = _DEFAULT_TIMEOUT,
|
||||
) -> None:
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
self._poll_interval = poll_interval
|
||||
self._timeout = timeout
|
||||
|
||||
async def generate(self, model: str, inputs: dict[str, object]) -> BFLResult:
|
||||
"""Submit a generation request and poll until the result is ready.
|
||||
|
||||
Args:
|
||||
model: BFL model name (e.g. ``flux-pro-1.1``, ``flux-2-pro``).
|
||||
inputs: Model-specific parameters (prompt, dimensions, etc.).
|
||||
|
||||
Returns:
|
||||
A :class:`BFLResult` containing the task ID and sample image URL.
|
||||
|
||||
Raises:
|
||||
BFLError: If the API returns an error or the request times out.
|
||||
"""
|
||||
headers = {
|
||||
"x-key": self._api_key,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
base_url=self._base_url, headers=headers
|
||||
) as client:
|
||||
task_id, polling_url = await self._submit(client, model, inputs)
|
||||
sample_url = await self._poll(client, task_id, polling_url)
|
||||
|
||||
return BFLResult(task_id=task_id, sample_url=sample_url)
|
||||
|
||||
async def _submit(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
model: str,
|
||||
inputs: dict[str, object],
|
||||
) -> tuple[str, str]:
|
||||
"""POST the generation request and return (task_id, polling_url)."""
|
||||
response = await client.post(f"/v1/{model}", json=inputs)
|
||||
|
||||
if response.status_code == 422: # noqa: PLR2004
|
||||
body = cast(dict[str, object], response.json())
|
||||
detail = body.get("detail", body)
|
||||
msg = f"BFL validation error for model '{model}': {detail}"
|
||||
raise BFLError(msg)
|
||||
|
||||
if response.status_code != 200: # noqa: PLR2004
|
||||
msg = f"BFL API returned status {response.status_code}: {response.text}"
|
||||
raise BFLError(msg)
|
||||
|
||||
body = cast(dict[str, object], response.json())
|
||||
task_id = body.get("id")
|
||||
polling_url = body.get("polling_url")
|
||||
|
||||
if not isinstance(task_id, str) or not isinstance(polling_url, str):
|
||||
msg = f"BFL API response missing 'id' or 'polling_url': {body}"
|
||||
raise BFLError(msg)
|
||||
|
||||
return task_id, polling_url
|
||||
|
||||
async def _poll(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
task_id: str,
|
||||
polling_url: str,
|
||||
) -> str:
|
||||
"""Poll the task until ready and return the sample image URL."""
|
||||
elapsed = 0.0
|
||||
|
||||
while elapsed < self._timeout:
|
||||
response = await client.get(polling_url)
|
||||
|
||||
if response.status_code != 200: # noqa: PLR2004
|
||||
msg = f"BFL polling returned status {response.status_code}: {response.text}"
|
||||
raise BFLError(msg)
|
||||
|
||||
body = cast(dict[str, object], response.json())
|
||||
status = body.get("status")
|
||||
|
||||
if status == "Ready":
|
||||
result = body.get("result")
|
||||
if not isinstance(result, dict):
|
||||
msg = f"BFL task {task_id} ready but no result dict: {body}"
|
||||
raise BFLError(msg)
|
||||
result_dict = cast(dict[str, object], result)
|
||||
sample = result_dict.get("sample")
|
||||
if not isinstance(sample, str):
|
||||
msg = f"BFL task {task_id} ready but no sample URL: {result}"
|
||||
raise BFLError(msg)
|
||||
return sample
|
||||
|
||||
if status in ("Error", "Failed"):
|
||||
error_msg = body.get("result", body)
|
||||
msg = f"BFL task {task_id} failed: {error_msg}"
|
||||
raise BFLError(msg)
|
||||
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
elapsed += self._poll_interval
|
||||
|
||||
msg = f"BFL task {task_id} timed out after {self._timeout}s"
|
||||
raise BFLError(msg)
|
||||
Loading…
Add table
Add a link
Reference in a new issue