fix: pass all reference images to OpenAI images.edit endpoint
This commit is contained in:
parent
61f30a8bb1
commit
3bfad87dce
2 changed files with 99 additions and 9 deletions
|
|
@ -20,6 +20,7 @@ from bulkgen.providers.blackforest import (
|
|||
)
|
||||
from bulkgen.providers.mistral import MistralProvider
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
from bulkgen.providers.openai_image import OpenAIImageProvider
|
||||
from bulkgen.providers.registry import get_all_models
|
||||
|
||||
|
||||
|
|
@ -422,3 +423,92 @@ class TestMistralProvider:
|
|||
assert isinstance(chunks, list)
|
||||
assert chunks[0].text == "Describe the style"
|
||||
assert chunks[1].image_url.url.startswith("data:image/png;base64,")
|
||||
|
||||
|
||||
def _make_openai_mock(b64_data: str) -> AsyncMock:
|
||||
"""Return a mock AsyncOpenAI client that returns b64 image data."""
|
||||
image = MagicMock()
|
||||
image.b64_json = b64_data
|
||||
image.url = None
|
||||
|
||||
response = MagicMock()
|
||||
response.data = [image]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.images.generate = AsyncMock(return_value=response)
|
||||
mock_client.images.edit = AsyncMock(return_value=response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
return mock_client
|
||||
|
||||
|
||||
class TestOpenAIImageProvider:
|
||||
"""Test OpenAIImageProvider with mocked OpenAI client."""
|
||||
|
||||
@pytest.fixture
|
||||
def image_bytes(self) -> bytes:
|
||||
return b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
|
||||
|
||||
async def test_single_reference_image(
|
||||
self, project_dir: Path, image_bytes: bytes
|
||||
) -> None:
|
||||
ref = project_dir / "ref.png"
|
||||
_ = ref.write_bytes(b"reference data")
|
||||
|
||||
target_config = TargetConfig(prompt="Edit this", reference_images=["ref.png"])
|
||||
b64 = base64.b64encode(image_bytes).decode()
|
||||
mock_client = _make_openai_mock(b64)
|
||||
|
||||
with patch("bulkgen.providers.openai_image.AsyncOpenAI") as mock_cls:
|
||||
mock_cls.return_value = mock_client
|
||||
|
||||
provider = OpenAIImageProvider(api_key="test-key")
|
||||
await provider.generate(
|
||||
target_name="out.png",
|
||||
target_config=target_config,
|
||||
resolved_prompt="Edit this",
|
||||
resolved_model=_model("gpt-image-1"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
call_args = mock_client.images.edit.call_args
|
||||
# Single reference image should be passed as raw bytes
|
||||
assert call_args.kwargs["image"] == b"reference data"
|
||||
|
||||
output = project_dir / "out.png"
|
||||
assert output.exists()
|
||||
assert output.read_bytes() == image_bytes
|
||||
|
||||
async def test_multiple_reference_images(
|
||||
self, project_dir: Path, image_bytes: bytes
|
||||
) -> None:
|
||||
ref1 = project_dir / "ref1.png"
|
||||
ref2 = project_dir / "ref2.png"
|
||||
_ = ref1.write_bytes(b"ref1 data")
|
||||
_ = ref2.write_bytes(b"ref2 data")
|
||||
|
||||
target_config = TargetConfig(
|
||||
prompt="Combine", reference_images=["ref1.png", "ref2.png"]
|
||||
)
|
||||
b64 = base64.b64encode(image_bytes).decode()
|
||||
mock_client = _make_openai_mock(b64)
|
||||
|
||||
with patch("bulkgen.providers.openai_image.AsyncOpenAI") as mock_cls:
|
||||
mock_cls.return_value = mock_client
|
||||
|
||||
provider = OpenAIImageProvider(api_key="test-key")
|
||||
await provider.generate(
|
||||
target_name="out.png",
|
||||
target_config=target_config,
|
||||
resolved_prompt="Combine",
|
||||
resolved_model=_model("gpt-image-1"),
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
call_args = mock_client.images.edit.call_args
|
||||
# Multiple reference images should be passed as a list of bytes
|
||||
image_arg: list[bytes] = call_args.kwargs["image"]
|
||||
assert isinstance(image_arg, list)
|
||||
assert len(image_arg) == 2
|
||||
assert image_arg[0] == b"ref1 data"
|
||||
assert image_arg[1] == b"ref2 data"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue