feat: add OpenAI as provider for text and image generation
- Add openai_text.py: text generation via OpenAI chat completions API (gpt-4o, gpt-4o-mini, gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, o3-mini) - Add openai_image.py: image generation via OpenAI images API (gpt-image-1 with reference image support, dall-e-3, dall-e-2) - Refactor builder provider dispatch from TargetType to model-name index to support multiple providers per target type - Fix circular import between config.py and providers/__init__.py using TYPE_CHECKING guard - Fix stale default model assertions in tests - Add openai>=1.0.0 dependency
This commit is contained in:
parent
d0dac5b1bf
commit
870023865d
9 changed files with 571 additions and 58 deletions
|
|
@ -16,21 +16,40 @@ from bulkgen.builder import (
|
|||
_resolve_prompt, # pyright: ignore[reportPrivateUsage]
|
||||
run_build,
|
||||
)
|
||||
from bulkgen.config import ProjectConfig, TargetConfig, TargetType
|
||||
from bulkgen.config import ProjectConfig, TargetConfig
|
||||
from bulkgen.providers import Provider
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
from bulkgen.providers.models import Capability, ModelInfo
|
||||
from bulkgen.state import load_state
|
||||
|
||||
WriteConfig = Callable[[dict[str, object]], ProjectConfig]
|
||||
|
||||
|
||||
class FakeProvider(Provider):
|
||||
"""A provider that writes a marker file instead of calling an API."""
|
||||
_FAKE_TEXT_MODELS = [
|
||||
ModelInfo(
|
||||
name="pixtral-large-latest",
|
||||
provider="Fake",
|
||||
type="text",
|
||||
capabilities=[Capability.TEXT_GENERATION, Capability.VISION],
|
||||
),
|
||||
]
|
||||
|
||||
_FAKE_IMAGE_MODELS = [
|
||||
ModelInfo(
|
||||
name="flux-2-pro",
|
||||
provider="Fake",
|
||||
type="image",
|
||||
capabilities=[Capability.TEXT_TO_IMAGE, Capability.REFERENCE_IMAGES],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class FakeTextProvider(Provider):
|
||||
"""A text provider that writes a marker file instead of calling an API."""
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def get_provided_models() -> list[ModelInfo]:
|
||||
return []
|
||||
return _FAKE_TEXT_MODELS
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
|
|
@ -45,13 +64,34 @@ class FakeProvider(Provider):
|
|||
_ = output.write_text(f"generated:{target_name}:{resolved_prompt}")
|
||||
|
||||
|
||||
class FailingProvider(Provider):
|
||||
"""A provider that always raises."""
|
||||
class FakeImageProvider(Provider):
|
||||
"""An image provider that writes a marker file instead of calling an API."""
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def get_provided_models() -> list[ModelInfo]:
|
||||
return []
|
||||
return _FAKE_IMAGE_MODELS
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
self,
|
||||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
) -> None:
|
||||
output = project_dir / target_name
|
||||
_ = output.write_text(f"generated:{target_name}:{resolved_prompt}")
|
||||
|
||||
|
||||
class FailingTextProvider(Provider):
|
||||
"""A text provider that always raises."""
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def get_provided_models() -> list[ModelInfo]:
|
||||
return _FAKE_TEXT_MODELS
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
|
|
@ -66,11 +106,8 @@ class FailingProvider(Provider):
|
|||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
def _fake_providers() -> dict[TargetType, Provider]:
|
||||
return {
|
||||
TargetType.TEXT: FakeProvider(),
|
||||
TargetType.IMAGE: FakeProvider(),
|
||||
}
|
||||
def _fake_providers() -> list[Provider]:
|
||||
return [FakeTextProvider(), FakeImageProvider()]
|
||||
|
||||
|
||||
class TestResolvePrompt:
|
||||
|
|
@ -251,8 +288,8 @@ class TestRunBuild:
|
|||
}
|
||||
}
|
||||
)
|
||||
fail_provider = FailingProvider()
|
||||
fake_provider = FakeProvider()
|
||||
fail_provider = FailingTextProvider()
|
||||
fake_provider = FakeTextProvider()
|
||||
|
||||
async def selective_generate(
|
||||
target_name: str,
|
||||
|
|
@ -278,15 +315,13 @@ class TestRunBuild:
|
|||
project_dir,
|
||||
)
|
||||
|
||||
routing_provider = FakeProvider()
|
||||
routing_provider = FakeTextProvider()
|
||||
routing_provider.generate = selective_generate # type: ignore[assignment]
|
||||
|
||||
providers_dict: dict[TargetType, Provider] = {
|
||||
TargetType.TEXT: routing_provider,
|
||||
TargetType.IMAGE: routing_provider,
|
||||
}
|
||||
|
||||
with patch("bulkgen.builder._create_providers", return_value=providers_dict):
|
||||
with patch(
|
||||
"bulkgen.builder._create_providers",
|
||||
return_value=[routing_provider, FakeImageProvider()],
|
||||
):
|
||||
result = await run_build(config, project_dir)
|
||||
|
||||
assert "fail.txt" in result.failed
|
||||
|
|
@ -304,11 +339,10 @@ class TestRunBuild:
|
|||
}
|
||||
)
|
||||
|
||||
with patch("bulkgen.builder._create_providers") as mock_cp:
|
||||
mock_cp.return_value = {
|
||||
TargetType.TEXT: FailingProvider(),
|
||||
TargetType.IMAGE: FakeProvider(),
|
||||
}
|
||||
with patch(
|
||||
"bulkgen.builder._create_providers",
|
||||
return_value=[FailingTextProvider(), FakeImageProvider()],
|
||||
):
|
||||
result = await run_build(config, project_dir)
|
||||
|
||||
assert "base.txt" in result.failed
|
||||
|
|
@ -320,12 +354,12 @@ class TestRunBuild:
|
|||
) -> None:
|
||||
with patch(
|
||||
"bulkgen.builder._create_providers",
|
||||
return_value={},
|
||||
return_value=[],
|
||||
):
|
||||
result = await run_build(simple_text_config, project_dir)
|
||||
|
||||
assert "output.txt" in result.failed
|
||||
assert "MISTRAL_API_KEY" in result.failed["output.txt"]
|
||||
assert "No provider available" in result.failed["output.txt"]
|
||||
|
||||
async def test_state_saved_after_each_generation(
|
||||
self, project_dir: Path, write_config: WriteConfig
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue