fix: pass all reference images to OpenAI images.edit endpoint
This commit is contained in:
parent
61f30a8bb1
commit
3bfad87dce
2 changed files with 99 additions and 9 deletions
|
|
@@ -185,25 +185,25 @@ async def _generate_edit(
|
||||||
project_dir: Path,
|
project_dir: Path,
|
||||||
size: _SIZE | None,
|
size: _SIZE | None,
|
||||||
) -> ImagesResponse:
|
) -> ImagesResponse:
|
||||||
"""Generate an image using a reference image via the edits endpoint.
|
"""Generate an image using reference images via the edits endpoint.
|
||||||
|
|
||||||
gpt-image-* models return b64 by default and reject ``response_format``,
|
gpt-image-* models accept up to 16 images and return b64 by default
|
||||||
so we only pass it for DALL-E models.
|
(they reject ``response_format``). DALL-E 2 accepts only one image.
|
||||||
"""
|
"""
|
||||||
ref_path = project_dir / reference_images[0]
|
images = [(project_dir / name).read_bytes() for name in reference_images]
|
||||||
image_bytes = ref_path.read_bytes()
|
image: bytes | list[bytes] = images[0] if len(images) == 1 else images
|
||||||
|
|
||||||
if model.startswith("gpt-image-"):
|
if model.startswith("gpt-image-"):
|
||||||
if size is not None:
|
if size is not None:
|
||||||
return await client.images.edit(
|
return await client.images.edit(
|
||||||
image=image_bytes,
|
image=image,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
n=1,
|
n=1,
|
||||||
size=size, # pyright: ignore[reportArgumentType]
|
size=size, # pyright: ignore[reportArgumentType]
|
||||||
)
|
)
|
||||||
return await client.images.edit(
|
return await client.images.edit(
|
||||||
image=image_bytes,
|
image=image,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
n=1,
|
n=1,
|
||||||
|
|
@@ -211,7 +211,7 @@ async def _generate_edit(
|
||||||
|
|
||||||
if size is not None:
|
if size is not None:
|
||||||
return await client.images.edit(
|
return await client.images.edit(
|
||||||
image=image_bytes,
|
image=image,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
n=1,
|
n=1,
|
||||||
|
|
@@ -219,7 +219,7 @@ async def _generate_edit(
|
||||||
size=size, # pyright: ignore[reportArgumentType]
|
size=size, # pyright: ignore[reportArgumentType]
|
||||||
)
|
)
|
||||||
return await client.images.edit(
|
return await client.images.edit(
|
||||||
image=image_bytes,
|
image=image,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
n=1,
|
n=1,
|
||||||
|
|
|
||||||
|
|
@@ -20,6 +20,7 @@ from bulkgen.providers.blackforest import (
|
||||||
)
|
)
|
||||||
from bulkgen.providers.mistral import MistralProvider
|
from bulkgen.providers.mistral import MistralProvider
|
||||||
from bulkgen.providers.models import ModelInfo
|
from bulkgen.providers.models import ModelInfo
|
||||||
|
from bulkgen.providers.openai_image import OpenAIImageProvider
|
||||||
from bulkgen.providers.registry import get_all_models
|
from bulkgen.providers.registry import get_all_models
|
||||||
|
|
||||||
|
|
||||||
|
|
@@ -422,3 +423,92 @@ class TestMistralProvider:
|
||||||
assert isinstance(chunks, list)
|
assert isinstance(chunks, list)
|
||||||
assert chunks[0].text == "Describe the style"
|
assert chunks[0].text == "Describe the style"
|
||||||
assert chunks[1].image_url.url.startswith("data:image/png;base64,")
|
assert chunks[1].image_url.url.startswith("data:image/png;base64,")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_openai_mock(b64_data: str) -> AsyncMock:
|
||||||
|
"""Return a mock AsyncOpenAI client that returns b64 image data."""
|
||||||
|
image = MagicMock()
|
||||||
|
image.b64_json = b64_data
|
||||||
|
image.url = None
|
||||||
|
|
||||||
|
response = MagicMock()
|
||||||
|
response.data = [image]
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.images.generate = AsyncMock(return_value=response)
|
||||||
|
mock_client.images.edit = AsyncMock(return_value=response)
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIImageProvider:
|
||||||
|
"""Test OpenAIImageProvider with mocked OpenAI client."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def image_bytes(self) -> bytes:
|
||||||
|
return b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
|
||||||
|
|
||||||
|
async def test_single_reference_image(
|
||||||
|
self, project_dir: Path, image_bytes: bytes
|
||||||
|
) -> None:
|
||||||
|
ref = project_dir / "ref.png"
|
||||||
|
_ = ref.write_bytes(b"reference data")
|
||||||
|
|
||||||
|
target_config = TargetConfig(prompt="Edit this", reference_images=["ref.png"])
|
||||||
|
b64 = base64.b64encode(image_bytes).decode()
|
||||||
|
mock_client = _make_openai_mock(b64)
|
||||||
|
|
||||||
|
with patch("bulkgen.providers.openai_image.AsyncOpenAI") as mock_cls:
|
||||||
|
mock_cls.return_value = mock_client
|
||||||
|
|
||||||
|
provider = OpenAIImageProvider(api_key="test-key")
|
||||||
|
await provider.generate(
|
||||||
|
target_name="out.png",
|
||||||
|
target_config=target_config,
|
||||||
|
resolved_prompt="Edit this",
|
||||||
|
resolved_model=_model("gpt-image-1"),
|
||||||
|
project_dir=project_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = mock_client.images.edit.call_args
|
||||||
|
# Single reference image should be passed as raw bytes
|
||||||
|
assert call_args.kwargs["image"] == b"reference data"
|
||||||
|
|
||||||
|
output = project_dir / "out.png"
|
||||||
|
assert output.exists()
|
||||||
|
assert output.read_bytes() == image_bytes
|
||||||
|
|
||||||
|
async def test_multiple_reference_images(
|
||||||
|
self, project_dir: Path, image_bytes: bytes
|
||||||
|
) -> None:
|
||||||
|
ref1 = project_dir / "ref1.png"
|
||||||
|
ref2 = project_dir / "ref2.png"
|
||||||
|
_ = ref1.write_bytes(b"ref1 data")
|
||||||
|
_ = ref2.write_bytes(b"ref2 data")
|
||||||
|
|
||||||
|
target_config = TargetConfig(
|
||||||
|
prompt="Combine", reference_images=["ref1.png", "ref2.png"]
|
||||||
|
)
|
||||||
|
b64 = base64.b64encode(image_bytes).decode()
|
||||||
|
mock_client = _make_openai_mock(b64)
|
||||||
|
|
||||||
|
with patch("bulkgen.providers.openai_image.AsyncOpenAI") as mock_cls:
|
||||||
|
mock_cls.return_value = mock_client
|
||||||
|
|
||||||
|
provider = OpenAIImageProvider(api_key="test-key")
|
||||||
|
await provider.generate(
|
||||||
|
target_name="out.png",
|
||||||
|
target_config=target_config,
|
||||||
|
resolved_prompt="Combine",
|
||||||
|
resolved_model=_model("gpt-image-1"),
|
||||||
|
project_dir=project_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = mock_client.images.edit.call_args
|
||||||
|
# Multiple reference images should be passed as a list of bytes
|
||||||
|
image_arg: list[bytes] = call_args.kwargs["image"]
|
||||||
|
assert isinstance(image_arg, list)
|
||||||
|
assert len(image_arg) == 2
|
||||||
|
assert image_arg[0] == b"ref1 data"
|
||||||
|
assert image_arg[1] == b"ref2 data"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue