"""OpenAI image generation provider.""" from __future__ import annotations import base64 from pathlib import Path from typing import Literal, override import httpx from openai import AsyncOpenAI from openai.types.images_response import ImagesResponse from bulkgen.config import TargetConfig from bulkgen.providers import Provider from bulkgen.providers.models import Capability, ModelInfo _SIZE = Literal[ "auto", "1024x1024", "1024x1536", "1536x1024", "1024x1792", "1792x1024", "256x256", "512x512", ] _VALID_SIZES: frozenset[str] = frozenset( { "auto", "1024x1024", "1024x1536", "1536x1024", "1024x1792", "1792x1024", "256x256", "512x512", } ) def _build_size(width: int | None, height: int | None) -> _SIZE | None: """Convert width/height to an OpenAI size string, or *None* for the default.""" if width is None and height is None: return None w = width or 1024 h = height or 1024 size = f"{w}x{h}" if size not in _VALID_SIZES: msg = f"Unsupported OpenAI image size '{size}'. Valid sizes: {', '.join(sorted(_VALID_SIZES))}" raise ValueError(msg) return size # pyright: ignore[reportReturnType] class OpenAIImageProvider(Provider): """Generates images via the OpenAI API.""" _api_key: str def __init__(self, api_key: str) -> None: self._api_key = api_key @staticmethod @override def get_provided_models() -> list[ModelInfo]: return [ ModelInfo( name="gpt-image-1", provider="OpenAI", type="image", capabilities=[ Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES, ], ), ModelInfo( name="dall-e-3", provider="OpenAI", type="image", capabilities=[Capability.TEXT_TO_IMAGE], ), ModelInfo( name="dall-e-2", provider="OpenAI", type="image", capabilities=[Capability.TEXT_TO_IMAGE], ), ] @override async def generate( self, target_name: str, target_config: TargetConfig, resolved_prompt: str, resolved_model: ModelInfo, project_dir: Path, ) -> None: output_path = project_dir / target_name size = _build_size(target_config.width, target_config.height) async with AsyncOpenAI(api_key=self._api_key) as client: if target_config.reference_images: response = await _generate_edit( client, resolved_prompt, resolved_model.name, target_config.reference_images, project_dir, size, ) else: response = await _generate_new( client, resolved_prompt, resolved_model.name, size, ) image_data = _extract_image_bytes(response, resolved_model.name) _ = output_path.write_bytes(image_data) async def _generate_new( client: AsyncOpenAI, prompt: str, model: str, size: _SIZE | None, ) -> ImagesResponse: """Generate a new image from a text prompt.""" if size is not None: return await client.images.generate( prompt=prompt, model=model, n=1, response_format="b64_json", size=size, ) return await client.images.generate( prompt=prompt, model=model, n=1, response_format="b64_json", ) async def _generate_edit( client: AsyncOpenAI, prompt: str, model: str, reference_images: list[str], project_dir: Path, size: _SIZE | None, ) -> ImagesResponse: """Generate an image using a reference image via the edits endpoint.""" ref_path = project_dir / reference_images[0] image_bytes = ref_path.read_bytes() if size is not None: return await client.images.edit( image=image_bytes, prompt=prompt, model=model, n=1, response_format="b64_json", size=size, # pyright: ignore[reportArgumentType] ) return await client.images.edit( image=image_bytes, prompt=prompt, model=model, n=1, response_format="b64_json", ) def _extract_image_bytes(response: ImagesResponse, model: str) -> bytes: """Extract image bytes from an OpenAI images response.""" if not response.data: msg = f"OpenAI {model} returned no images" raise RuntimeError(msg) image = response.data[0] if image.b64_json is not None: return base64.b64decode(image.b64_json) if image.url is not None: resp = httpx.get(image.url) _ = resp.raise_for_status() return resp.content msg = f"OpenAI {model} returned neither b64_json nor url" raise RuntimeError(msg)