refactor: replace blackforest package with custom async BFL client
Some checks failed
Continuous Integration / Build Package (push) Successful in 41s
Continuous Integration / Lint, Check & Test (push) Failing after 59s

Implement bulkgen/providers/bfl.py with a fully async httpx-based client
that supports all current and future BFL models (including flux-2-*).
Remove the blackforest dependency and simplify the image provider by
eliminating the asyncio.to_thread wrapper.
This commit is contained in:
Konstantin Fickel 2026-02-14 16:44:36 +01:00
parent fd09d127f2
commit cf73511876
Signed by: kfickel
GPG key ID: A793722F9933C1A5
5 changed files with 179 additions and 89 deletions

148
bulkgen/providers/bfl.py Normal file
View file

@ -0,0 +1,148 @@
"""Async client for the BlackForestLabs image generation API."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import cast
import httpx
_BASE_URL = "https://api.bfl.ai"
_DEFAULT_POLL_INTERVAL = 1.0
_DEFAULT_TIMEOUT = 300.0
@dataclass
class BFLResult:
"""Successful generation result from the BFL API."""
task_id: str
sample_url: str
class BFLError(Exception):
"""Error returned by the BFL API."""
class BFLClient:
"""Async client for the BlackForestLabs image generation API.
Submits generation requests and polls for results.
"""
_api_key: str
_base_url: str
_poll_interval: float
_timeout: float
def __init__(
self,
api_key: str,
*,
base_url: str = _BASE_URL,
poll_interval: float = _DEFAULT_POLL_INTERVAL,
timeout: float = _DEFAULT_TIMEOUT,
) -> None:
self._api_key = api_key
self._base_url = base_url
self._poll_interval = poll_interval
self._timeout = timeout
async def generate(self, model: str, inputs: dict[str, object]) -> BFLResult:
"""Submit a generation request and poll until the result is ready.
Args:
model: BFL model name (e.g. ``flux-pro-1.1``, ``flux-2-pro``).
inputs: Model-specific parameters (prompt, dimensions, etc.).
Returns:
A :class:`BFLResult` containing the task ID and sample image URL.
Raises:
BFLError: If the API returns an error or the request times out.
"""
headers = {
"x-key": self._api_key,
"Content-Type": "application/json",
"Accept": "application/json",
}
async with httpx.AsyncClient(
base_url=self._base_url, headers=headers
) as client:
task_id, polling_url = await self._submit(client, model, inputs)
sample_url = await self._poll(client, task_id, polling_url)
return BFLResult(task_id=task_id, sample_url=sample_url)
async def _submit(
self,
client: httpx.AsyncClient,
model: str,
inputs: dict[str, object],
) -> tuple[str, str]:
"""POST the generation request and return (task_id, polling_url)."""
response = await client.post(f"/v1/{model}", json=inputs)
if response.status_code == 422: # noqa: PLR2004
body = cast(dict[str, object], response.json())
detail = body.get("detail", body)
msg = f"BFL validation error for model '{model}': {detail}"
raise BFLError(msg)
if response.status_code != 200: # noqa: PLR2004
msg = f"BFL API returned status {response.status_code}: {response.text}"
raise BFLError(msg)
body = cast(dict[str, object], response.json())
task_id = body.get("id")
polling_url = body.get("polling_url")
if not isinstance(task_id, str) or not isinstance(polling_url, str):
msg = f"BFL API response missing 'id' or 'polling_url': {body}"
raise BFLError(msg)
return task_id, polling_url
async def _poll(
self,
client: httpx.AsyncClient,
task_id: str,
polling_url: str,
) -> str:
"""Poll the task until ready and return the sample image URL."""
elapsed = 0.0
while elapsed < self._timeout:
response = await client.get(polling_url)
if response.status_code != 200: # noqa: PLR2004
msg = f"BFL polling returned status {response.status_code}: {response.text}"
raise BFLError(msg)
body = cast(dict[str, object], response.json())
status = body.get("status")
if status == "Ready":
result = body.get("result")
if not isinstance(result, dict):
msg = f"BFL task {task_id} ready but no result dict: {body}"
raise BFLError(msg)
result_dict = cast(dict[str, object], result)
sample = result_dict.get("sample")
if not isinstance(sample, str):
msg = f"BFL task {task_id} ready but no sample URL: {result}"
raise BFLError(msg)
return sample
if status in ("Error", "Failed"):
error_msg = body.get("result", body)
msg = f"BFL task {task_id} failed: {error_msg}"
raise BFLError(msg)
await asyncio.sleep(self._poll_interval)
elapsed += self._poll_interval
msg = f"BFL task {task_id} timed out after {self._timeout}s"
raise BFLError(msg)

View file

@ -2,24 +2,15 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import base64 import base64
from pathlib import Path from pathlib import Path
from typing import override from typing import override
import httpx import httpx
from blackforest import BFLClient # pyright: ignore[reportMissingTypeStubs]
from blackforest.types.general.client_config import ( # pyright: ignore[reportMissingTypeStubs]
ClientConfig,
)
from blackforest.types.responses.responses import ( # pyright: ignore[reportMissingTypeStubs]
SyncResponse,
)
from bulkgen.config import TargetConfig from bulkgen.config import TargetConfig
from bulkgen.providers import Provider from bulkgen.providers import Provider
from bulkgen.providers.bfl import BFLClient
_BFL_SYNC_CONFIG = ClientConfig(sync=True, timeout=300)
def _encode_image_b64(path: Path) -> str: def _encode_image_b64(path: Path) -> str:
@ -61,24 +52,10 @@ class ImageProvider(Provider):
ctrl_path = project_dir / control_name ctrl_path = project_dir / control_name
inputs["control_image"] = _encode_image_b64(ctrl_path) inputs["control_image"] = _encode_image_b64(ctrl_path)
result = await asyncio.to_thread( result = await self._client.generate(resolved_model, inputs)
self._client.generate, resolved_model, inputs, _BFL_SYNC_CONFIG
)
if not isinstance(result, SyncResponse):
msg = (
f"BFL API returned unexpected response type for target '{target_name}'"
)
raise RuntimeError(msg)
result_dict: dict[str, str] = result.result # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
image_url = result_dict.get("sample")
if not image_url:
msg = f"BFL API did not return an image URL for target '{target_name}'"
raise RuntimeError(msg)
async with httpx.AsyncClient() as http: async with httpx.AsyncClient() as http:
response = await http.get(image_url) response = await http.get(result.sample_url)
_ = response.raise_for_status() _ = response.raise_for_status()
_ = output_path.write_bytes(response.content) _ = output_path.write_bytes(response.content)

View file

@ -5,7 +5,7 @@ description = "Bulk-Generate Images with Generative AI"
readme = "README.md" readme = "README.md"
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [ dependencies = [
"blackforest>=0.1.3", "httpx>=0.27.0",
"mistralai>=1.0.0", "mistralai>=1.0.0",
"networkx>=3.6.1", "networkx>=3.6.1",
"pydantic>=2.12.5", "pydantic>=2.12.5",

View file

@ -13,6 +13,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from bulkgen.config import TargetConfig from bulkgen.config import TargetConfig
from bulkgen.providers.bfl import BFLResult
from bulkgen.providers.image import ImageProvider from bulkgen.providers.image import ImageProvider
from bulkgen.providers.image import ( from bulkgen.providers.image import (
_encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage] _encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage]
@ -22,10 +23,11 @@ from bulkgen.providers.text import TextProvider
def _make_bfl_mocks( def _make_bfl_mocks(
image_bytes: bytes, image_bytes: bytes,
) -> tuple[MagicMock, MagicMock]: ) -> tuple[BFLResult, MagicMock]:
"""Return (mock_result, mock_http) for BFL image generation tests.""" """Return (bfl_result, mock_http) for BFL image generation tests."""
mock_result = MagicMock() bfl_result = BFLResult(
mock_result.result = {"sample": "https://example.com/img.png"} task_id="test-task-id", sample_url="https://example.com/img.png"
)
mock_response = MagicMock() mock_response = MagicMock()
mock_response.content = image_bytes mock_response.content = image_bytes
@ -36,7 +38,7 @@ def _make_bfl_mocks(
mock_http.__aenter__ = AsyncMock(return_value=mock_http) mock_http.__aenter__ = AsyncMock(return_value=mock_http)
mock_http.__aexit__ = AsyncMock(return_value=False) mock_http.__aexit__ = AsyncMock(return_value=False)
return mock_result, mock_http return bfl_result, mock_http
def _make_mistral_mock(response: MagicMock) -> AsyncMock: def _make_mistral_mock(response: MagicMock) -> AsyncMock:
@ -68,14 +70,13 @@ class TestImageProvider:
self, project_dir: Path, image_bytes: bytes self, project_dir: Path, image_bytes: bytes
) -> None: ) -> None:
target_config = TargetConfig(prompt="A red square") target_config = TargetConfig(prompt="A red square")
mock_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with ( with (
patch("bulkgen.providers.image.BFLClient") as mock_cls, patch("bulkgen.providers.image.BFLClient") as mock_cls,
patch("bulkgen.providers.image.isinstance", return_value=True),
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
): ):
mock_cls.return_value.generate.return_value = mock_result mock_cls.return_value.generate = AsyncMock(return_value=bfl_result)
mock_http_cls.return_value = mock_http mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key") provider = ImageProvider(api_key="test-key")
@ -95,15 +96,14 @@ class TestImageProvider:
self, project_dir: Path, image_bytes: bytes self, project_dir: Path, image_bytes: bytes
) -> None: ) -> None:
target_config = TargetConfig(prompt="A banner", width=1920, height=480) target_config = TargetConfig(prompt="A banner", width=1920, height=480)
mock_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with ( with (
patch("bulkgen.providers.image.BFLClient") as mock_cls, patch("bulkgen.providers.image.BFLClient") as mock_cls,
patch("bulkgen.providers.image.isinstance", return_value=True),
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
): ):
client_instance = mock_cls.return_value mock_generate = AsyncMock(return_value=bfl_result)
client_instance.generate.return_value = mock_result mock_cls.return_value.generate = mock_generate
mock_http_cls.return_value = mock_http mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key") provider = ImageProvider(api_key="test-key")
@ -115,7 +115,7 @@ class TestImageProvider:
project_dir=project_dir, project_dir=project_dir,
) )
call_args = client_instance.generate.call_args call_args = mock_generate.call_args
inputs = call_args[0][1] inputs = call_args[0][1]
assert inputs["width"] == 1920 assert inputs["width"] == 1920
assert inputs["height"] == 480 assert inputs["height"] == 480
@ -127,15 +127,14 @@ class TestImageProvider:
_ = ref_path.write_bytes(b"reference image data") _ = ref_path.write_bytes(b"reference image data")
target_config = TargetConfig(prompt="Like this", reference_image="ref.png") target_config = TargetConfig(prompt="Like this", reference_image="ref.png")
mock_result, mock_http = _make_bfl_mocks(image_bytes) bfl_result, mock_http = _make_bfl_mocks(image_bytes)
with ( with (
patch("bulkgen.providers.image.BFLClient") as mock_cls, patch("bulkgen.providers.image.BFLClient") as mock_cls,
patch("bulkgen.providers.image.isinstance", return_value=True),
patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls,
): ):
client_instance = mock_cls.return_value mock_generate = AsyncMock(return_value=bfl_result)
client_instance.generate.return_value = mock_result mock_cls.return_value.generate = mock_generate
mock_http_cls.return_value = mock_http mock_http_cls.return_value = mock_http
provider = ImageProvider(api_key="test-key") provider = ImageProvider(api_key="test-key")
@ -147,29 +146,28 @@ class TestImageProvider:
project_dir=project_dir, project_dir=project_dir,
) )
call_args = client_instance.generate.call_args call_args = mock_generate.call_args
inputs = call_args[0][1] inputs = call_args[0][1]
assert "image_prompt" in inputs assert "image_prompt" in inputs
assert inputs["image_prompt"] == encode_image_b64(ref_path) assert inputs["image_prompt"] == encode_image_b64(ref_path)
async def test_image_no_sample_url_raises(self, project_dir: Path) -> None: async def test_image_no_sample_url_raises(self, project_dir: Path) -> None:
target_config = TargetConfig(prompt="x") target_config = TargetConfig(prompt="x")
mock_result = MagicMock()
mock_result.result = {}
with ( with patch("bulkgen.providers.image.BFLClient") as mock_cls:
patch("bulkgen.providers.image.BFLClient") as mock_cls, from bulkgen.providers.bfl import BFLError
patch("bulkgen.providers.image.isinstance", return_value=True),
): mock_cls.return_value.generate = AsyncMock(
mock_cls.return_value.generate.return_value = mock_result side_effect=BFLError("BFL task test ready but no sample URL: {}")
)
provider = ImageProvider(api_key="test-key") provider = ImageProvider(api_key="test-key")
with pytest.raises(RuntimeError, match="did not return an image URL"): with pytest.raises(BFLError, match="no sample URL"):
await provider.generate( await provider.generate(
target_name="fail.png", target_name="fail.png",
target_config=target_config, target_config=target_config,
resolved_prompt="x", resolved_prompt="x",
resolved_model="flux-pro", resolved_model="flux-pro-1.1",
project_dir=project_dir, project_dir=project_dir,
) )

37
uv.lock generated
View file

@ -44,26 +44,12 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d5/90/1883cec16d667d944b08e8d8909b9b2f46cc1d2b9731e855e3c71f9b0450/basedpyright-1.38.0-py3-none-any.whl", hash = "sha256:a6c11a343fd12a2152a0d721b0e92f54f2e2e3322ee2562197e27dad952f1a61", size = 12303557, upload-time = "2026-02-11T16:05:44.863Z" }, { url = "https://files.pythonhosted.org/packages/d5/90/1883cec16d667d944b08e8d8909b9b2f46cc1d2b9731e855e3c71f9b0450/basedpyright-1.38.0-py3-none-any.whl", hash = "sha256:a6c11a343fd12a2152a0d721b0e92f54f2e2e3322ee2562197e27dad952f1a61", size = 12303557, upload-time = "2026-02-11T16:05:44.863Z" },
] ]
[[package]]
name = "blackforest"
version = "0.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pillow" },
{ name = "pydantic" },
{ name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b4/e5/e580c9ae37bcf2d70365c452d3e6a34e5f9cf0ced4fcb63430d46530a908/blackforest-0.1.3.tar.gz", hash = "sha256:8e5b069690a036fda90c1f3b8b44189a873522d54852cbea5642f2bcf4df3158", size = 18834, upload-time = "2025-09-10T10:02:12.79Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8b/ca/02cbaab5a7979fa39547269bbb9bf434013c68115ae633d948697514ef4d/blackforest-0.1.3-py3-none-any.whl", hash = "sha256:8f0fed51e305ed93711e756971bf7b85569f03714504d55e154bbb0c2cf6cfb1", size = 22868, upload-time = "2025-09-10T10:02:11.6Z" },
]
[[package]] [[package]]
name = "bulkgen" name = "bulkgen"
version = "0.1.0" version = "0.1.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "blackforest" }, { name = "httpx" },
{ name = "mistralai" }, { name = "mistralai" },
{ name = "networkx" }, { name = "networkx" },
{ name = "pydantic" }, { name = "pydantic" },
@ -81,7 +67,7 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "blackforest", specifier = ">=0.1.3" }, { name = "httpx", specifier = ">=0.27.0" },
{ name = "mistralai", specifier = ">=1.0.0" }, { name = "mistralai", specifier = ">=1.0.0" },
{ name = "networkx", specifier = ">=3.6.1" }, { name = "networkx", specifier = ">=3.6.1" },
{ name = "pydantic", specifier = ">=2.12.5" }, { name = "pydantic", specifier = ">=2.12.5" },
@ -423,25 +409,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, { url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" },
] ]
[[package]]
name = "pillow"
version = "10.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/cd/74/ad3d526f3bf7b6d3f408b73fde271ec69dfac8b81341a318ce825f2b3812/pillow-10.4.0.tar.gz", hash = "sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06", size = 46555059, upload-time = "2024-07-01T09:48:43.583Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c3/00/706cebe7c2c12a6318aabe5d354836f54adff7156fd9e1bd6c89f4ba0e98/pillow-10.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3", size = 3525685, upload-time = "2024-07-01T09:46:45.194Z" },
{ url = "https://files.pythonhosted.org/packages/cf/76/f658cbfa49405e5ecbfb9ba42d07074ad9792031267e782d409fd8fe7c69/pillow-10.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb", size = 3374883, upload-time = "2024-07-01T09:46:47.331Z" },
{ url = "https://files.pythonhosted.org/packages/46/2b/99c28c4379a85e65378211971c0b430d9c7234b1ec4d59b2668f6299e011/pillow-10.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70", size = 4339837, upload-time = "2024-07-01T09:46:49.647Z" },
{ url = "https://files.pythonhosted.org/packages/f1/74/b1ec314f624c0c43711fdf0d8076f82d9d802afd58f1d62c2a86878e8615/pillow-10.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be", size = 4455562, upload-time = "2024-07-01T09:46:51.811Z" },
{ url = "https://files.pythonhosted.org/packages/4a/2a/4b04157cb7b9c74372fa867096a1607e6fedad93a44deeff553ccd307868/pillow-10.4.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0", size = 4366761, upload-time = "2024-07-01T09:46:53.961Z" },
{ url = "https://files.pythonhosted.org/packages/ac/7b/8f1d815c1a6a268fe90481232c98dd0e5fa8c75e341a75f060037bd5ceae/pillow-10.4.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc", size = 4536767, upload-time = "2024-07-01T09:46:56.664Z" },
{ url = "https://files.pythonhosted.org/packages/e5/77/05fa64d1f45d12c22c314e7b97398ffb28ef2813a485465017b7978b3ce7/pillow-10.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a", size = 4477989, upload-time = "2024-07-01T09:46:58.977Z" },
{ url = "https://files.pythonhosted.org/packages/12/63/b0397cfc2caae05c3fb2f4ed1b4fc4fc878f0243510a7a6034ca59726494/pillow-10.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309", size = 4610255, upload-time = "2024-07-01T09:47:01.189Z" },
{ url = "https://files.pythonhosted.org/packages/7b/f9/cfaa5082ca9bc4a6de66ffe1c12c2d90bf09c309a5f52b27759a596900e7/pillow-10.4.0-cp313-cp313-win32.whl", hash = "sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060", size = 2235603, upload-time = "2024-07-01T09:47:03.918Z" },
{ url = "https://files.pythonhosted.org/packages/01/6a/30ff0eef6e0c0e71e55ded56a38d4859bf9d3634a94a88743897b5f96936/pillow-10.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea", size = 2554972, upload-time = "2024-07-01T09:47:06.152Z" },
{ url = "https://files.pythonhosted.org/packages/48/2c/2e0a52890f269435eee38b21c8218e102c621fe8d8df8b9dd06fabf879ba/pillow-10.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d", size = 2243375, upload-time = "2024-07-01T09:47:09.065Z" },
]
[[package]] [[package]]
name = "pluggy" name = "pluggy"
version = "1.6.0" version = "1.6.0"