# hokusai/bulkgen/providers/text.py
"""Mistral text generation provider."""
from __future__ import annotations
import base64
import mimetypes
from pathlib import Path
from typing import override
from mistralai import Mistral, models
from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig
from bulkgen.providers import Provider
from bulkgen.providers.models import ModelInfo
def _image_to_data_url(path: Path) -> str:
"""Read an image file and return a ``data:`` URL with base64-encoded content."""
mime = mimetypes.guess_type(path.name)[0] or "image/png"
b64 = base64.b64encode(path.read_bytes()).decode("ascii")
return f"data:{mime};base64,{b64}"
class TextProvider(Provider):
    """Generates text via the Mistral API.

    Text and image inputs from the target configuration are attached to a
    single user message; the first completion choice is written to the
    target file inside the project directory.
    """

    # Mistral API key used for every request.
    _api_key: str

    def __init__(self, api_key: str) -> None:
        self._api_key = api_key

    @override
    async def generate(
        self,
        target_name: str,
        target_config: TargetConfig,
        resolved_prompt: str,
        resolved_model: ModelInfo,
        project_dir: Path,
    ) -> None:
        """Generate text for *target_name* and write it into *project_dir*.

        Raises:
            RuntimeError: If the API returns no choices or empty content.
        """
        output_path = project_dir / target_name
        all_input_names = list(target_config.inputs) + list(
            target_config.reference_images
        )
        # Use a multimodal message only when at least one input is an image.
        has_images = any(
            (project_dir / name).suffix.lower() in IMAGE_EXTENSIONS
            for name in all_input_names
        )
        if has_images:
            message = _build_multimodal_message(
                resolved_prompt, all_input_names, project_dir
            )
        else:
            message = _build_text_message(resolved_prompt, all_input_names, project_dir)
        async with Mistral(api_key=self._api_key) as client:
            response = await client.chat.complete_async(
                model=resolved_model.name,
                messages=[message],
            )
        if not response.choices:
            msg = f"Mistral API returned no choices for target '{target_name}'"
            raise RuntimeError(msg)
        content = response.choices[0].message.content
        if content is None:
            msg = f"Mistral API returned empty content for target '{target_name}'"
            raise RuntimeError(msg)
        # Content may be a list of chunks rather than a string; fall back to
        # str() in that case.
        text = content if isinstance(content, str) else str(content)
        # Explicit UTF-8 so the output file does not depend on the locale's
        # preferred encoding (non-ASCII output would raise on some platforms).
        _ = output_path.write_text(text, encoding="utf-8")
def _build_text_message(
    prompt: str,
    input_names: list[str],
    project_dir: Path,
) -> models.UserMessage:
    """Build a plain-text message (no images).

    Each input file's contents are appended to the prompt under a
    ``--- Contents of <name> ---`` header.
    """
    parts: list[str] = [prompt]
    for name in input_names:
        # Explicit UTF-8 so reads do not depend on the locale's preferred
        # encoding.
        file_content = (project_dir / name).read_text(encoding="utf-8")
        parts.append(f"\n--- Contents of {name} ---\n{file_content}")
    return models.UserMessage(content="\n".join(parts))
def _build_multimodal_message(
    prompt: str,
    input_names: list[str],
    project_dir: Path,
) -> models.UserMessage:
    """Build a multimodal message with text and image chunks.

    Image inputs become base64 ``data:`` URL chunks; every other input is
    inlined as a text chunk under a ``--- Contents of <name> ---`` header.
    """
    chunks: list[models.TextChunk | models.ImageURLChunk] = [
        models.TextChunk(text=prompt),
    ]
    for name in input_names:
        input_path = project_dir / name
        if input_path.suffix.lower() in IMAGE_EXTENSIONS:
            data_url = _image_to_data_url(input_path)
            chunks.append(models.ImageURLChunk(image_url=models.ImageURL(url=data_url)))
        else:
            # Explicit UTF-8 so reads do not depend on the locale's preferred
            # encoding.
            file_content = input_path.read_text(encoding="utf-8")
            chunks.append(
                models.TextChunk(text=f"\n--- Contents of {name} ---\n{file_content}")
            )
    return models.UserMessage(content=chunks)  # pyright: ignore[reportArgumentType]