refactor: replace blackforest package with custom async BFL client
Some checks failed
Continuous Integration / Build Package (push) Successful in 41s
Continuous Integration / Lint, Check & Test (push) Failing after 59s

Implement bulkgen/providers/bfl.py with a fully async httpx-based client
that supports all current and future BFL models (including flux-2-*).
Remove the blackforest dependency and simplify the image provider by
eliminating the asyncio.to_thread wrapper.
This commit is contained in:
Konstantin Fickel 2026-02-14 16:44:36 +01:00
parent fd09d127f2
commit cf73511876
Signed by: kfickel
GPG key ID: A793722F9933C1A5
5 changed files with 179 additions and 89 deletions

148
bulkgen/providers/bfl.py Normal file
View file

@ -0,0 +1,148 @@
"""Async client for the BlackForestLabs image generation API."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import cast
import httpx
_BASE_URL = "https://api.bfl.ai"
_DEFAULT_POLL_INTERVAL = 1.0
_DEFAULT_TIMEOUT = 300.0
@dataclass
class BFLResult:
"""Successful generation result from the BFL API."""
task_id: str
sample_url: str
class BFLError(Exception):
"""Error returned by the BFL API."""
class BFLClient:
"""Async client for the BlackForestLabs image generation API.
Submits generation requests and polls for results.
"""
_api_key: str
_base_url: str
_poll_interval: float
_timeout: float
def __init__(
self,
api_key: str,
*,
base_url: str = _BASE_URL,
poll_interval: float = _DEFAULT_POLL_INTERVAL,
timeout: float = _DEFAULT_TIMEOUT,
) -> None:
self._api_key = api_key
self._base_url = base_url
self._poll_interval = poll_interval
self._timeout = timeout
async def generate(self, model: str, inputs: dict[str, object]) -> BFLResult:
"""Submit a generation request and poll until the result is ready.
Args:
model: BFL model name (e.g. ``flux-pro-1.1``, ``flux-2-pro``).
inputs: Model-specific parameters (prompt, dimensions, etc.).
Returns:
A :class:`BFLResult` containing the task ID and sample image URL.
Raises:
BFLError: If the API returns an error or the request times out.
"""
headers = {
"x-key": self._api_key,
"Content-Type": "application/json",
"Accept": "application/json",
}
async with httpx.AsyncClient(
base_url=self._base_url, headers=headers
) as client:
task_id, polling_url = await self._submit(client, model, inputs)
sample_url = await self._poll(client, task_id, polling_url)
return BFLResult(task_id=task_id, sample_url=sample_url)
async def _submit(
self,
client: httpx.AsyncClient,
model: str,
inputs: dict[str, object],
) -> tuple[str, str]:
"""POST the generation request and return (task_id, polling_url)."""
response = await client.post(f"/v1/{model}", json=inputs)
if response.status_code == 422: # noqa: PLR2004
body = cast(dict[str, object], response.json())
detail = body.get("detail", body)
msg = f"BFL validation error for model '{model}': {detail}"
raise BFLError(msg)
if response.status_code != 200: # noqa: PLR2004
msg = f"BFL API returned status {response.status_code}: {response.text}"
raise BFLError(msg)
body = cast(dict[str, object], response.json())
task_id = body.get("id")
polling_url = body.get("polling_url")
if not isinstance(task_id, str) or not isinstance(polling_url, str):
msg = f"BFL API response missing 'id' or 'polling_url': {body}"
raise BFLError(msg)
return task_id, polling_url
async def _poll(
self,
client: httpx.AsyncClient,
task_id: str,
polling_url: str,
) -> str:
"""Poll the task until ready and return the sample image URL."""
elapsed = 0.0
while elapsed < self._timeout:
response = await client.get(polling_url)
if response.status_code != 200: # noqa: PLR2004
msg = f"BFL polling returned status {response.status_code}: {response.text}"
raise BFLError(msg)
body = cast(dict[str, object], response.json())
status = body.get("status")
if status == "Ready":
result = body.get("result")
if not isinstance(result, dict):
msg = f"BFL task {task_id} ready but no result dict: {body}"
raise BFLError(msg)
result_dict = cast(dict[str, object], result)
sample = result_dict.get("sample")
if not isinstance(sample, str):
msg = f"BFL task {task_id} ready but no sample URL: {result}"
raise BFLError(msg)
return sample
if status in ("Error", "Failed"):
error_msg = body.get("result", body)
msg = f"BFL task {task_id} failed: {error_msg}"
raise BFLError(msg)
await asyncio.sleep(self._poll_interval)
elapsed += self._poll_interval
msg = f"BFL task {task_id} timed out after {self._timeout}s"
raise BFLError(msg)

View file

@ -2,24 +2,15 @@
from __future__ import annotations
import asyncio
import base64
from pathlib import Path
from typing import override
import httpx
from blackforest import BFLClient # pyright: ignore[reportMissingTypeStubs]
from blackforest.types.general.client_config import ( # pyright: ignore[reportMissingTypeStubs]
ClientConfig,
)
from blackforest.types.responses.responses import ( # pyright: ignore[reportMissingTypeStubs]
SyncResponse,
)
from bulkgen.config import TargetConfig
from bulkgen.providers import Provider
_BFL_SYNC_CONFIG = ClientConfig(sync=True, timeout=300)
from bulkgen.providers.bfl import BFLClient
def _encode_image_b64(path: Path) -> str:
@ -61,24 +52,10 @@ class ImageProvider(Provider):
ctrl_path = project_dir / control_name
inputs["control_image"] = _encode_image_b64(ctrl_path)
result = await asyncio.to_thread(
self._client.generate, resolved_model, inputs, _BFL_SYNC_CONFIG
)
if not isinstance(result, SyncResponse):
msg = (
f"BFL API returned unexpected response type for target '{target_name}'"
)
raise RuntimeError(msg)
result_dict: dict[str, str] = result.result # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
image_url = result_dict.get("sample")
if not image_url:
msg = f"BFL API did not return an image URL for target '{target_name}'"
raise RuntimeError(msg)
result = await self._client.generate(resolved_model, inputs)
async with httpx.AsyncClient() as http:
response = await http.get(image_url)
response = await http.get(result.sample_url)
_ = response.raise_for_status()
_ = output_path.write_bytes(response.content)