fix: pass all reference images to OpenAI images.edit endpoint
This commit is contained in:
parent
61f30a8bb1
commit
3bfad87dce
2 changed files with 99 additions and 9 deletions
|
|
@ -185,25 +185,25 @@ async def _generate_edit(
|
|||
project_dir: Path,
|
||||
size: _SIZE | None,
|
||||
) -> ImagesResponse:
|
||||
"""Generate an image using a reference image via the edits endpoint.
|
||||
"""Generate an image using reference images via the edits endpoint.
|
||||
|
||||
gpt-image-* models return b64 by default and reject ``response_format``,
|
||||
so we only pass it for DALL-E models.
|
||||
gpt-image-* models accept up to 16 images and return b64 by default
|
||||
(they reject ``response_format``). DALL-E 2 accepts only one image.
|
||||
"""
|
||||
ref_path = project_dir / reference_images[0]
|
||||
image_bytes = ref_path.read_bytes()
|
||||
images = [(project_dir / name).read_bytes() for name in reference_images]
|
||||
image: bytes | list[bytes] = images[0] if len(images) == 1 else images
|
||||
|
||||
if model.startswith("gpt-image-"):
|
||||
if size is not None:
|
||||
return await client.images.edit(
|
||||
image=image_bytes,
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
n=1,
|
||||
size=size, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
return await client.images.edit(
|
||||
image=image_bytes,
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
n=1,
|
||||
|
|
@ -211,7 +211,7 @@ async def _generate_edit(
|
|||
|
||||
if size is not None:
|
||||
return await client.images.edit(
|
||||
image=image_bytes,
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
n=1,
|
||||
|
|
@ -219,7 +219,7 @@ async def _generate_edit(
|
|||
size=size, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
return await client.images.edit(
|
||||
image=image_bytes,
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
n=1,
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from bulkgen.providers.blackforest import (
|
|||
)
|
||||
from bulkgen.providers.mistral import MistralProvider
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
from bulkgen.providers.openai_image import OpenAIImageProvider
|
||||
from bulkgen.providers.registry import get_all_models
|
||||
|
||||
|
||||
|
|
@ -422,3 +423,92 @@ class TestMistralProvider:
|
|||
assert isinstance(chunks, list)
|
||||
assert chunks[0].text == "Describe the style"
|
||||
assert chunks[1].image_url.url.startswith("data:image/png;base64,")
|
||||
|
||||
|
||||
def _make_openai_mock(b64_data: str) -> AsyncMock:
|
||||
"""Return a mock AsyncOpenAI client that returns b64 image data."""
|
||||
image = MagicMock()
|
||||
image.b64_json = b64_data
|
||||
image.url = None
|
||||
|
||||
response = MagicMock()
|
||||
response.data = [image]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.images.generate = AsyncMock(return_value=response)
|
||||
mock_client.images.edit = AsyncMock(return_value=response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
return mock_client
|
||||
|
||||
|
||||
class TestOpenAIImageProvider:
    """Exercise OpenAIImageProvider against a mocked OpenAI client."""

    @pytest.fixture
    def image_bytes(self) -> bytes:
        # Minimal PNG-like payload: magic header plus zero padding.
        return b"\x89PNG\r\n\x1a\n" + b"\x00" * 100

    async def test_single_reference_image(
        self, project_dir: Path, image_bytes: bytes
    ) -> None:
        # One reference file on disk for the edit call to pick up.
        reference = project_dir / "ref.png"
        _ = reference.write_bytes(b"reference data")

        config = TargetConfig(prompt="Edit this", reference_images=["ref.png"])
        encoded = base64.b64encode(image_bytes).decode()
        client = _make_openai_mock(encoded)

        with patch("bulkgen.providers.openai_image.AsyncOpenAI") as client_cls:
            client_cls.return_value = client

            provider = OpenAIImageProvider(api_key="test-key")
            await provider.generate(
                target_name="out.png",
                target_config=config,
                resolved_prompt="Edit this",
                resolved_model=_model("gpt-image-1"),
                project_dir=project_dir,
            )

        edit_call = client.images.edit.call_args
        # A lone reference image travels as raw bytes, not a list.
        assert edit_call.kwargs["image"] == b"reference data"

        written = project_dir / "out.png"
        assert written.exists()
        assert written.read_bytes() == image_bytes

    async def test_multiple_reference_images(
        self, project_dir: Path, image_bytes: bytes
    ) -> None:
        # Two reference files so the provider must send a list.
        for name, payload in (("ref1.png", b"ref1 data"), ("ref2.png", b"ref2 data")):
            _ = (project_dir / name).write_bytes(payload)

        config = TargetConfig(
            prompt="Combine", reference_images=["ref1.png", "ref2.png"]
        )
        encoded = base64.b64encode(image_bytes).decode()
        client = _make_openai_mock(encoded)

        with patch("bulkgen.providers.openai_image.AsyncOpenAI") as client_cls:
            client_cls.return_value = client

            provider = OpenAIImageProvider(api_key="test-key")
            await provider.generate(
                target_name="out.png",
                target_config=config,
                resolved_prompt="Combine",
                resolved_model=_model("gpt-image-1"),
                project_dir=project_dir,
            )

        edit_call = client.images.edit.call_args
        # Several reference images travel together as a list of bytes,
        # preserving the order declared in the target config.
        image_arg: list[bytes] = edit_call.kwargs["image"]
        assert isinstance(image_arg, list)
        assert len(image_arg) == 2
        assert image_arg[0] == b"ref1 data"
        assert image_arg[1] == b"ref2 data"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue