hokusai/bulkgen/providers/image.py

89 lines
2.8 KiB
Python

"""BlackForestLabs image generation provider."""
from __future__ import annotations
import base64
from pathlib import Path
from typing import override
import httpx
from bulkgen.config import TargetConfig
from bulkgen.providers import Provider
from bulkgen.providers.bfl import BFLClient
from bulkgen.providers.models import ModelInfo
def _encode_image_b64(path: Path) -> str:
    """Return the bytes of the file at *path* as base64 text.

    The whole file is read into memory, base64-encoded, and the resulting
    bytes are decoded as ASCII so the value can be used as a string.
    """
    raw_bytes = path.read_bytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("ascii")
# API parameter names that carry reference images.
# `_INPUT_IMAGE_KEYS` is "input_image", "input_image_2", ..., "input_image_8";
# `_IMAGE_PROMPT_KEYS` is the single "image_prompt" name.
# `_ref_image_keys` chooses between them based on the model name prefix.
_INPUT_IMAGE_KEYS = ["input_image", *(f"input_image_{n}" for n in range(2, 9))]
_IMAGE_PROMPT_KEYS = ["image_prompt"]
def _ref_image_keys(model: str) -> list[str]:
    """Return the ordered API parameter names for reference images.

    The choice depends only on the prefix of *model*:

    * ``flux-2-``: all of ``_INPUT_IMAGE_KEYS`` (up to 8 images).
    * ``flux-kontext-``: the first four of ``_INPUT_IMAGE_KEYS``.
    * anything else: ``_IMAGE_PROMPT_KEYS`` (flux 1.x: single ``image_prompt``).
    """
    if model.startswith("flux-2-"):
        selected = _INPUT_IMAGE_KEYS
    elif model.startswith("flux-kontext-"):
        selected = _INPUT_IMAGE_KEYS[:4]
    else:
        selected = _IMAGE_PROMPT_KEYS
    return selected
def _add_reference_images(
    inputs: dict[str, object],
    reference_images: list[str],
    model: str,
    project_dir: Path,
) -> None:
    """Encode reference images and store them in *inputs* under the API keys.

    Args:
        inputs: Request parameters; mutated in place.
        reference_images: Image file names, resolved against *project_dir*.
        model: Model name used to pick the parameter names.
        project_dir: Directory the image names are relative to.

    Images are paired with parameter names in order. Pairing stops at the
    shorter of the two sequences, so images beyond the number of names
    available for the model are not added.
    """
    param_names = _ref_image_keys(model)
    for param_name, image_name in zip(param_names, reference_images, strict=False):
        image_path = project_dir / image_name
        inputs[param_name] = _encode_image_b64(image_path)
class ImageProvider(Provider):
    """Provider that produces image files through the BlackForestLabs API."""

    # Client used to submit generation requests.
    _client: BFLClient

    def __init__(self, api_key: str) -> None:
        """Create the underlying ``BFLClient`` with the given *api_key*."""
        self._client = BFLClient(api_key=api_key)

    @override
    async def generate(
        self,
        target_name: str,
        target_config: TargetConfig,
        resolved_prompt: str,
        resolved_model: ModelInfo,
        project_dir: Path,
    ) -> None:
        """Generate one image and write it to ``project_dir / target_name``.

        Args:
            target_name: Output file path, relative to *project_dir*.
            target_config: Supplies the optional ``width`` and ``height``,
                plus ``reference_images`` and ``control_images`` file names.
            resolved_prompt: Text sent as the ``prompt`` parameter.
            resolved_model: Its ``name`` is passed to the client and selects
                the reference-image parameter names.
            project_dir: Base directory for input images and the output file.

        Raises:
            httpx.HTTPStatusError: If downloading the generated sample
                returns an error status.
        """
        request: dict[str, object] = {"prompt": resolved_prompt}

        # Dimensions are only sent when the target sets them explicitly.
        width, height = target_config.width, target_config.height
        if width is not None:
            request["width"] = width
        if height is not None:
            request["height"] = height

        reference_names = target_config.reference_images
        if reference_names:
            _add_reference_images(
                request, reference_names, resolved_model.name, project_dir
            )

        # NOTE(review): every control image is read and encoded, but each one
        # is stored under the same "control_image" key, so only the last
        # listed image ends up in the request.
        for control_name in target_config.control_images:
            request["control_image"] = _encode_image_b64(project_dir / control_name)

        generation = await self._client.generate(resolved_model.name, request)

        # The client result carries a URL; fetch the image bytes from it.
        async with httpx.AsyncClient() as downloader:
            download = await downloader.get(generation.sample_url)
            _ = download.raise_for_status()

        destination = project_dir / target_name
        _ = destination.write_bytes(download.content)