"""OpenAI image generation provider.""" from __future__ import annotations import base64 from pathlib import Path from typing import Literal, override import httpx from openai import AsyncOpenAI from openai.types.images_response import ImagesResponse from bulkgen.config import TargetConfig from bulkgen.providers import Provider from bulkgen.providers.models import Capability, ModelInfo _SIZE = Literal[ "auto", "1024x1024", "1024x1536", "1536x1024", "1024x1792", "1792x1024", "256x256", "512x512", ] _VALID_SIZES: frozenset[str] = frozenset( { "auto", "1024x1024", "1024x1536", "1536x1024", "1024x1792", "1792x1024", "256x256", "512x512", } ) def _build_size(width: int | None, height: int | None) -> _SIZE | None: """Convert width/height to an OpenAI size string, or *None* for the default.""" if width is None and height is None: return None w = width or 1024 h = height or 1024 size = f"{w}x{h}" if size not in _VALID_SIZES: msg = f"Unsupported OpenAI image size '{size}'. Valid sizes: {', '.join(sorted(_VALID_SIZES))}" raise ValueError(msg) return size # pyright: ignore[reportReturnType] class OpenAIImageProvider(Provider): """Generates images via the OpenAI API.""" _api_key: str def __init__(self, api_key: str) -> None: self._api_key = api_key @staticmethod @override def get_provided_models() -> list[ModelInfo]: return [ ModelInfo( name="gpt-image-1.5", provider="OpenAI", type="image", capabilities=[ Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES, ], ), ModelInfo( name="gpt-image-1", provider="OpenAI", type="image", capabilities=[ Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES, ], ), ModelInfo( name="gpt-image-1-mini", provider="OpenAI", type="image", capabilities=[ Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES, ], ), ModelInfo( name="dall-e-3", provider="OpenAI", type="image", capabilities=[Capability.TEXT_TO_IMAGE], ), ModelInfo( name="dall-e-2", provider="OpenAI", type="image", capabilities=[Capability.TEXT_TO_IMAGE], ), ] @override async def generate( self, target_name: str, target_config: TargetConfig, resolved_prompt: str, resolved_model: ModelInfo, project_dir: Path, ) -> None: output_path = project_dir / target_name size = _build_size(target_config.width, target_config.height) async with AsyncOpenAI(api_key=self._api_key) as client: if target_config.reference_images: response = await _generate_edit( client, resolved_prompt, resolved_model.name, target_config.reference_images, project_dir, size, ) else: response = await _generate_new( client, resolved_prompt, resolved_model.name, size, ) image_data = _extract_image_bytes(response, resolved_model.name) _ = output_path.write_bytes(image_data) async def _generate_new( client: AsyncOpenAI, prompt: str, model: str, size: _SIZE | None, ) -> ImagesResponse: """Generate a new image from a text prompt. gpt-image-* models return b64 by default and reject ``response_format``, so we only pass it for DALL-E models. """ # gpt-image-* returns b64 by default; DALL-E defaults to url. if model.startswith("gpt-image-"): if size is not None: return await client.images.generate( prompt=prompt, model=model, n=1, size=size, ) return await client.images.generate(prompt=prompt, model=model, n=1) if size is not None: return await client.images.generate( prompt=prompt, model=model, n=1, response_format="b64_json", size=size, ) return await client.images.generate( prompt=prompt, model=model, n=1, response_format="b64_json", ) async def _generate_edit( client: AsyncOpenAI, prompt: str, model: str, reference_images: list[str], project_dir: Path, size: _SIZE | None, ) -> ImagesResponse: """Generate an image using reference images via the edits endpoint. gpt-image-* models accept up to 16 images and return b64 by default (they reject ``response_format``). DALL-E 2 accepts only one image. """ images = [(project_dir / name).read_bytes() for name in reference_images] image: bytes | list[bytes] = images[0] if len(images) == 1 else images if model.startswith("gpt-image-"): if size is not None: return await client.images.edit( image=image, prompt=prompt, model=model, n=1, size=size, # pyright: ignore[reportArgumentType] ) return await client.images.edit( image=image, prompt=prompt, model=model, n=1, ) if size is not None: return await client.images.edit( image=image, prompt=prompt, model=model, n=1, response_format="b64_json", size=size, # pyright: ignore[reportArgumentType] ) return await client.images.edit( image=image, prompt=prompt, model=model, n=1, response_format="b64_json", ) def _extract_image_bytes(response: ImagesResponse, model: str) -> bytes: """Extract image bytes from an OpenAI images response.""" if not response.data: msg = f"OpenAI {model} returned no images" raise RuntimeError(msg) image = response.data[0] if image.b64_json is not None: return base64.b64decode(image.b64_json) if image.url is not None: resp = httpx.get(image.url) _ = resp.raise_for_status() return resp.content msg = f"OpenAI {model} returned neither b64_json nor url" raise RuntimeError(msg)