From cf73511876120f97efd1e2e6afe19c7ab182f7a6 Mon Sep 17 00:00:00 2001 From: Konstantin Fickel Date: Sat, 14 Feb 2026 16:44:36 +0100 Subject: [PATCH] refactor: replace blackforest package with custom async BFL client Implement bulkgen/providers/bfl.py with a fully async httpx-based client that supports all current and future BFL models (including flux-2-*). Remove the blackforest dependency and simplify the image provider by eliminating the asyncio.to_thread wrapper. --- bulkgen/providers/bfl.py | 148 +++++++++++++++++++++++++++++++++++++ bulkgen/providers/image.py | 29 +------- pyproject.toml | 2 +- tests/test_providers.py | 52 +++++++------ uv.lock | 37 +--------- 5 files changed, 179 insertions(+), 89 deletions(-) create mode 100644 bulkgen/providers/bfl.py diff --git a/bulkgen/providers/bfl.py b/bulkgen/providers/bfl.py new file mode 100644 index 0000000..1ccbb83 --- /dev/null +++ b/bulkgen/providers/bfl.py @@ -0,0 +1,148 @@ +"""Async client for the BlackForestLabs image generation API.""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import cast + +import httpx + +_BASE_URL = "https://api.bfl.ai" +_DEFAULT_POLL_INTERVAL = 1.0 +_DEFAULT_TIMEOUT = 300.0 + + +@dataclass +class BFLResult: + """Successful generation result from the BFL API.""" + + task_id: str + sample_url: str + + +class BFLError(Exception): + """Error returned by the BFL API.""" + + +class BFLClient: + """Async client for the BlackForestLabs image generation API. + + Submits generation requests and polls for results. + """ + + _api_key: str + _base_url: str + _poll_interval: float + _timeout: float + + def __init__( + self, + api_key: str, + *, + base_url: str = _BASE_URL, + poll_interval: float = _DEFAULT_POLL_INTERVAL, + timeout: float = _DEFAULT_TIMEOUT, + ) -> None: + self._api_key = api_key + self._base_url = base_url + self._poll_interval = poll_interval + self._timeout = timeout + + async def generate(self, model: str, inputs: dict[str, object]) -> BFLResult: + """Submit a generation request and poll until the result is ready. + + Args: + model: BFL model name (e.g. ``flux-pro-1.1``, ``flux-2-pro``). + inputs: Model-specific parameters (prompt, dimensions, etc.). + + Returns: + A :class:`BFLResult` containing the task ID and sample image URL. + + Raises: + BFLError: If the API returns an error or the request times out. + """ + headers = { + "x-key": self._api_key, + "Content-Type": "application/json", + "Accept": "application/json", + } + + async with httpx.AsyncClient( + base_url=self._base_url, headers=headers + ) as client: + task_id, polling_url = await self._submit(client, model, inputs) + sample_url = await self._poll(client, task_id, polling_url) + + return BFLResult(task_id=task_id, sample_url=sample_url) + + async def _submit( + self, + client: httpx.AsyncClient, + model: str, + inputs: dict[str, object], + ) -> tuple[str, str]: + """POST the generation request and return (task_id, polling_url).""" + response = await client.post(f"/v1/{model}", json=inputs) + + if response.status_code == 422: # noqa: PLR2004 + body = cast(dict[str, object], response.json()) + detail = body.get("detail", body) + msg = f"BFL validation error for model '{model}': {detail}" + raise BFLError(msg) + + if response.status_code != 200: # noqa: PLR2004 + msg = f"BFL API returned status {response.status_code}: {response.text}" + raise BFLError(msg) + + body = cast(dict[str, object], response.json()) + task_id = body.get("id") + polling_url = body.get("polling_url") + + if not isinstance(task_id, str) or not isinstance(polling_url, str): + msg = f"BFL API response missing 'id' or 'polling_url': {body}" + raise BFLError(msg) + + return task_id, polling_url + + async def _poll( + self, + client: httpx.AsyncClient, + task_id: str, + polling_url: str, + ) -> str: + """Poll the task until ready and return the sample image URL.""" + elapsed = 0.0 + + while elapsed < self._timeout: + response = await client.get(polling_url) + + if response.status_code != 200: # noqa: PLR2004 + msg = f"BFL polling returned status {response.status_code}: {response.text}" + raise BFLError(msg) + + body = cast(dict[str, object], response.json()) + status = body.get("status") + + if status == "Ready": + result = body.get("result") + if not isinstance(result, dict): + msg = f"BFL task {task_id} ready but no result dict: {body}" + raise BFLError(msg) + result_dict = cast(dict[str, object], result) + sample = result_dict.get("sample") + if not isinstance(sample, str): + msg = f"BFL task {task_id} ready but no sample URL: {result}" + raise BFLError(msg) + return sample + + if status in ("Error", "Failed"): + error_msg = body.get("result", body) + msg = f"BFL task {task_id} failed: {error_msg}" + raise BFLError(msg) + + await asyncio.sleep(self._poll_interval) + elapsed += self._poll_interval + + msg = f"BFL task {task_id} timed out after {self._timeout}s" + raise BFLError(msg) diff --git a/bulkgen/providers/image.py b/bulkgen/providers/image.py index b8fa7fe..79ccab1 100644 --- a/bulkgen/providers/image.py +++ b/bulkgen/providers/image.py @@ -2,24 +2,15 @@ from __future__ import annotations -import asyncio import base64 from pathlib import Path from typing import override import httpx -from blackforest import BFLClient # pyright: ignore[reportMissingTypeStubs] -from blackforest.types.general.client_config import ( # pyright: ignore[reportMissingTypeStubs] - ClientConfig, -) -from blackforest.types.responses.responses import ( # pyright: ignore[reportMissingTypeStubs] - SyncResponse, -) from bulkgen.config import TargetConfig from bulkgen.providers import Provider - -_BFL_SYNC_CONFIG = ClientConfig(sync=True, timeout=300) +from bulkgen.providers.bfl import BFLClient def _encode_image_b64(path: Path) -> str: @@ -61,24 +52,10 @@ class ImageProvider(Provider): ctrl_path = project_dir / control_name inputs["control_image"] = _encode_image_b64(ctrl_path) - result = await asyncio.to_thread( - self._client.generate, resolved_model, inputs, _BFL_SYNC_CONFIG - ) - - if not isinstance(result, SyncResponse): - msg = ( - f"BFL API returned unexpected response type for target '{target_name}'" - ) - raise RuntimeError(msg) - - result_dict: dict[str, str] = result.result # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - image_url = result_dict.get("sample") - if not image_url: - msg = f"BFL API did not return an image URL for target '{target_name}'" - raise RuntimeError(msg) + result = await self._client.generate(resolved_model, inputs) async with httpx.AsyncClient() as http: - response = await http.get(image_url) + response = await http.get(result.sample_url) _ = response.raise_for_status() _ = output_path.write_bytes(response.content) diff --git a/pyproject.toml b/pyproject.toml index aa92e7c..7aa84b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "Bulk-Generate Images with Generative AI" readme = "README.md" requires-python = ">=3.13" dependencies = [ - "blackforest>=0.1.3", + "httpx>=0.27.0", "mistralai>=1.0.0", "networkx>=3.6.1", "pydantic>=2.12.5", diff --git a/tests/test_providers.py b/tests/test_providers.py index 4baf9c5..cb855c0 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -13,6 +13,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from bulkgen.config import TargetConfig +from bulkgen.providers.bfl import BFLResult from bulkgen.providers.image import ImageProvider from bulkgen.providers.image import ( _encode_image_b64 as encode_image_b64, # pyright: ignore[reportPrivateUsage] @@ -22,10 +23,11 @@ from bulkgen.providers.text import TextProvider def _make_bfl_mocks( image_bytes: bytes, -) -> tuple[MagicMock, MagicMock]: - """Return (mock_result, mock_http) for BFL image generation tests.""" - mock_result = MagicMock() - mock_result.result = {"sample": "https://example.com/img.png"} +) -> tuple[BFLResult, MagicMock]: + """Return (bfl_result, mock_http) for BFL image generation tests.""" + bfl_result = BFLResult( + task_id="test-task-id", sample_url="https://example.com/img.png" + ) mock_response = MagicMock() mock_response.content = image_bytes @@ -36,7 +38,7 @@ def _make_bfl_mocks( mock_http.__aenter__ = AsyncMock(return_value=mock_http) mock_http.__aexit__ = AsyncMock(return_value=False) - return mock_result, mock_http + return bfl_result, mock_http def _make_mistral_mock(response: MagicMock) -> AsyncMock: @@ -68,14 +70,13 @@ class TestImageProvider: self, project_dir: Path, image_bytes: bytes ) -> None: target_config = TargetConfig(prompt="A red square") - mock_result, mock_http = _make_bfl_mocks(image_bytes) + bfl_result, mock_http = _make_bfl_mocks(image_bytes) with ( patch("bulkgen.providers.image.BFLClient") as mock_cls, - patch("bulkgen.providers.image.isinstance", return_value=True), patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, ): - mock_cls.return_value.generate.return_value = mock_result + mock_cls.return_value.generate = AsyncMock(return_value=bfl_result) mock_http_cls.return_value = mock_http provider = ImageProvider(api_key="test-key") @@ -95,15 +96,14 @@ class TestImageProvider: self, project_dir: Path, image_bytes: bytes ) -> None: target_config = TargetConfig(prompt="A banner", width=1920, height=480) - mock_result, mock_http = _make_bfl_mocks(image_bytes) + bfl_result, mock_http = _make_bfl_mocks(image_bytes) with ( patch("bulkgen.providers.image.BFLClient") as mock_cls, - patch("bulkgen.providers.image.isinstance", return_value=True), patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, ): - client_instance = mock_cls.return_value - client_instance.generate.return_value = mock_result + mock_generate = AsyncMock(return_value=bfl_result) + mock_cls.return_value.generate = mock_generate mock_http_cls.return_value = mock_http provider = ImageProvider(api_key="test-key") @@ -115,7 +115,7 @@ class TestImageProvider: project_dir=project_dir, ) - call_args = client_instance.generate.call_args + call_args = mock_generate.call_args inputs = call_args[0][1] assert inputs["width"] == 1920 assert inputs["height"] == 480 @@ -127,15 +127,14 @@ class TestImageProvider: _ = ref_path.write_bytes(b"reference image data") target_config = TargetConfig(prompt="Like this", reference_image="ref.png") - mock_result, mock_http = _make_bfl_mocks(image_bytes) + bfl_result, mock_http = _make_bfl_mocks(image_bytes) with ( patch("bulkgen.providers.image.BFLClient") as mock_cls, - patch("bulkgen.providers.image.isinstance", return_value=True), patch("bulkgen.providers.image.httpx.AsyncClient") as mock_http_cls, ): - client_instance = mock_cls.return_value - client_instance.generate.return_value = mock_result + mock_generate = AsyncMock(return_value=bfl_result) + mock_cls.return_value.generate = mock_generate mock_http_cls.return_value = mock_http provider = ImageProvider(api_key="test-key") @@ -147,29 +146,28 @@ class TestImageProvider: project_dir=project_dir, ) - call_args = client_instance.generate.call_args + call_args = mock_generate.call_args inputs = call_args[0][1] assert "image_prompt" in inputs assert inputs["image_prompt"] == encode_image_b64(ref_path) async def test_image_no_sample_url_raises(self, project_dir: Path) -> None: target_config = TargetConfig(prompt="x") - mock_result = MagicMock() - mock_result.result = {} - with ( - patch("bulkgen.providers.image.BFLClient") as mock_cls, - patch("bulkgen.providers.image.isinstance", return_value=True), - ): - mock_cls.return_value.generate.return_value = mock_result + with patch("bulkgen.providers.image.BFLClient") as mock_cls: + from bulkgen.providers.bfl import BFLError + + mock_cls.return_value.generate = AsyncMock( + side_effect=BFLError("BFL task test ready but no sample URL: {}") + ) provider = ImageProvider(api_key="test-key") - with pytest.raises(RuntimeError, match="did not return an image URL"): + with pytest.raises(BFLError, match="no sample URL"): await provider.generate( target_name="fail.png", target_config=target_config, resolved_prompt="x", - resolved_model="flux-pro", + resolved_model="flux-pro-1.1", project_dir=project_dir, ) diff --git a/uv.lock b/uv.lock index 98ff6fa..e65d1ff 100644 --- a/uv.lock +++ b/uv.lock @@ -44,26 +44,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/90/1883cec16d667d944b08e8d8909b9b2f46cc1d2b9731e855e3c71f9b0450/basedpyright-1.38.0-py3-none-any.whl", hash = "sha256:a6c11a343fd12a2152a0d721b0e92f54f2e2e3322ee2562197e27dad952f1a61", size = 12303557, upload-time = "2026-02-11T16:05:44.863Z" }, ] -[[package]] -name = "blackforest" -version = "0.1.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pillow" }, - { name = "pydantic" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b4/e5/e580c9ae37bcf2d70365c452d3e6a34e5f9cf0ced4fcb63430d46530a908/blackforest-0.1.3.tar.gz", hash = "sha256:8e5b069690a036fda90c1f3b8b44189a873522d54852cbea5642f2bcf4df3158", size = 18834, upload-time = "2025-09-10T10:02:12.79Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/ca/02cbaab5a7979fa39547269bbb9bf434013c68115ae633d948697514ef4d/blackforest-0.1.3-py3-none-any.whl", hash = "sha256:8f0fed51e305ed93711e756971bf7b85569f03714504d55e154bbb0c2cf6cfb1", size = 22868, upload-time = "2025-09-10T10:02:11.6Z" }, -] - [[package]] name = "bulkgen" version = "0.1.0" source = { editable = "." } dependencies = [ - { name = "blackforest" }, + { name = "httpx" }, { name = "mistralai" }, { name = "networkx" }, { name = "pydantic" }, @@ -81,7 +67,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "blackforest", specifier = ">=0.1.3" }, + { name = "httpx", specifier = ">=0.27.0" }, { name = "mistralai", specifier = ">=1.0.0" }, { name = "networkx", specifier = ">=3.6.1" }, { name = "pydantic", specifier = ">=2.12.5" }, @@ -423,25 +409,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, ] -[[package]] -name = "pillow" -version = "10.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/cd/74/ad3d526f3bf7b6d3f408b73fde271ec69dfac8b81341a318ce825f2b3812/pillow-10.4.0.tar.gz", hash = "sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06", size = 46555059, upload-time = "2024-07-01T09:48:43.583Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/00/706cebe7c2c12a6318aabe5d354836f54adff7156fd9e1bd6c89f4ba0e98/pillow-10.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3", size = 3525685, upload-time = "2024-07-01T09:46:45.194Z" }, - { url = "https://files.pythonhosted.org/packages/cf/76/f658cbfa49405e5ecbfb9ba42d07074ad9792031267e782d409fd8fe7c69/pillow-10.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb", size = 3374883, upload-time = "2024-07-01T09:46:47.331Z" }, - { url = "https://files.pythonhosted.org/packages/46/2b/99c28c4379a85e65378211971c0b430d9c7234b1ec4d59b2668f6299e011/pillow-10.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70", size = 4339837, upload-time = "2024-07-01T09:46:49.647Z" }, - { url = "https://files.pythonhosted.org/packages/f1/74/b1ec314f624c0c43711fdf0d8076f82d9d802afd58f1d62c2a86878e8615/pillow-10.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be", size = 4455562, upload-time = "2024-07-01T09:46:51.811Z" }, - { url = "https://files.pythonhosted.org/packages/4a/2a/4b04157cb7b9c74372fa867096a1607e6fedad93a44deeff553ccd307868/pillow-10.4.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0", size = 4366761, upload-time = "2024-07-01T09:46:53.961Z" }, - { url = "https://files.pythonhosted.org/packages/ac/7b/8f1d815c1a6a268fe90481232c98dd0e5fa8c75e341a75f060037bd5ceae/pillow-10.4.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc", size = 4536767, upload-time = "2024-07-01T09:46:56.664Z" }, - { url = "https://files.pythonhosted.org/packages/e5/77/05fa64d1f45d12c22c314e7b97398ffb28ef2813a485465017b7978b3ce7/pillow-10.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a", size = 4477989, upload-time = "2024-07-01T09:46:58.977Z" }, - { url = "https://files.pythonhosted.org/packages/12/63/b0397cfc2caae05c3fb2f4ed1b4fc4fc878f0243510a7a6034ca59726494/pillow-10.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309", size = 4610255, upload-time = "2024-07-01T09:47:01.189Z" }, - { url = "https://files.pythonhosted.org/packages/7b/f9/cfaa5082ca9bc4a6de66ffe1c12c2d90bf09c309a5f52b27759a596900e7/pillow-10.4.0-cp313-cp313-win32.whl", hash = "sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060", size = 2235603, upload-time = "2024-07-01T09:47:03.918Z" }, - { url = "https://files.pythonhosted.org/packages/01/6a/30ff0eef6e0c0e71e55ded56a38d4859bf9d3634a94a88743897b5f96936/pillow-10.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea", size = 2554972, upload-time = "2024-07-01T09:47:06.152Z" }, - { url = "https://files.pythonhosted.org/packages/48/2c/2e0a52890f269435eee38b21c8218e102c621fe8d8df8b9dd06fabf879ba/pillow-10.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d", size = 2243375, upload-time = "2024-07-01T09:47:09.065Z" }, -] - [[package]] name = "pluggy" version = "1.6.0"