refactor: pass ModelInfo instead of model name string through provider interface
This commit is contained in:
parent
8e3ed7010f
commit
d15444bdb0
8 changed files with 83 additions and 43 deletions
|
|
@ -121,7 +121,7 @@ async def _build_single_target(
|
|||
"""Build a single target by dispatching to the appropriate provider."""
|
||||
target_cfg = config.targets[target_name]
|
||||
target_type = infer_target_type(target_name)
|
||||
model = resolve_model(target_name, target_cfg, config.defaults)
|
||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
|
||||
provider = providers[target_type]
|
||||
|
|
@ -129,7 +129,7 @@ async def _build_single_target(
|
|||
target_name=target_name,
|
||||
target_config=target_cfg,
|
||||
resolved_prompt=resolved_prompt,
|
||||
resolved_model=model,
|
||||
resolved_model=model_info,
|
||||
project_dir=project_dir,
|
||||
)
|
||||
|
||||
|
|
@ -218,7 +218,7 @@ def _is_dirty(
|
|||
) -> bool:
|
||||
"""Check if a target needs rebuilding."""
|
||||
target_cfg = config.targets[target_name]
|
||||
model = resolve_model(target_name, target_cfg, config.defaults)
|
||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
dep_files = _collect_dep_files(target_name, config, project_dir)
|
||||
extra = _collect_extra_params(target_name, config)
|
||||
|
|
@ -226,7 +226,7 @@ def _is_dirty(
|
|||
return is_target_dirty(
|
||||
target_name,
|
||||
resolved_prompt=resolved_prompt,
|
||||
model=model,
|
||||
model=model_info.name,
|
||||
dep_files=dep_files,
|
||||
extra_params=extra,
|
||||
state=state,
|
||||
|
|
@ -286,7 +286,7 @@ def _process_outcomes(
|
|||
on_progress(BuildEvent.TARGET_FAILED, name, str(error))
|
||||
else:
|
||||
target_cfg = config.targets[name]
|
||||
model = resolve_model(name, target_cfg, config.defaults)
|
||||
model_info = resolve_model(name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
dep_files = _collect_dep_files(name, config, project_dir)
|
||||
extra = _collect_extra_params(name, config)
|
||||
|
|
@ -294,7 +294,7 @@ def _process_outcomes(
|
|||
record_target_state(
|
||||
name,
|
||||
resolved_prompt=resolved_prompt,
|
||||
model=model,
|
||||
model=model_info.name,
|
||||
dep_files=dep_files,
|
||||
extra_params=extra,
|
||||
state=state,
|
||||
|
|
|
|||
|
|
@ -4,11 +4,14 @@ from __future__ import annotations
|
|||
|
||||
import enum
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
from typing import TYPE_CHECKING, Self
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
|
||||
|
||||
class TargetType(enum.Enum):
|
||||
"""The kind of artifact a target produces."""
|
||||
|
|
@ -65,14 +68,29 @@ def infer_target_type(target_name: str) -> TargetType:
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
def resolve_model(
    target_name: str, target: TargetConfig, defaults: Defaults
) -> ModelInfo:
    """Return the effective model for a target (explicit or default by type).

    The model name is ``target.model`` when that is set.  Otherwise the
    target type is inferred from *target_name* and the name falls back to
    ``defaults.image_model`` for image targets or ``defaults.text_model``
    for every other target type.  The name is then looked up in
    ``ALL_MODELS`` and the matching registry entry is returned, so callers
    receive a ``ModelInfo`` rather than a bare model-name string.

    Args:
        target_name: Name of the target; used to infer the target type and
            in the error message.
        target: The target's configuration; its ``model`` attribute, when
            not ``None``, overrides the defaults.
        defaults: Project-wide defaults supplying ``image_model`` and
            ``text_model``.

    Returns:
        The first entry of ``ALL_MODELS`` whose ``name`` equals the
        resolved model name.

    Raises :class:`ValueError` if the resolved model name is not in the registry.
    """
    # Imported inside the function; the module itself only imports ModelInfo
    # under TYPE_CHECKING.  NOTE(review): presumably this avoids a circular
    # import, since the providers package imports from this module - confirm.
    from bulkgen.providers.models import ALL_MODELS

    if target.model is not None:
        model_name = target.model
    else:
        target_type = infer_target_type(target_name)
        model_name = (
            defaults.image_model
            if target_type is TargetType.IMAGE
            else defaults.text_model
        )

    # Resolve the name to its registry entry; the first match wins.
    for model in ALL_MODELS:
        if model.name == model_name:
            return model

    msg = f"Unknown model '{model_name}' for target '{target_name}'"
    raise ValueError(msg)
|
||||
|
||||
|
||||
def load_config(config_path: Path) -> ProjectConfig:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import abc
|
|||
from pathlib import Path
|
||||
|
||||
from bulkgen.config import TargetConfig
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
|
||||
|
||||
class Provider(abc.ABC):
|
||||
|
|
@ -17,7 +18,7 @@ class Provider(abc.ABC):
|
|||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
) -> None:
|
||||
"""Generate the target artifact and write it to *project_dir / target_name*.
|
||||
|
|
@ -26,6 +27,6 @@ class Provider(abc.ABC):
|
|||
target_name: Output filename (relative to project_dir).
|
||||
target_config: The parsed target configuration.
|
||||
resolved_prompt: The fully resolved prompt text.
|
||||
resolved_model: The resolved model name.
|
||||
resolved_model: The resolved model information.
|
||||
project_dir: The project working directory.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import httpx
|
|||
from bulkgen.config import TargetConfig
|
||||
from bulkgen.providers import Provider
|
||||
from bulkgen.providers.bfl import BFLClient
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
|
||||
|
||||
def _encode_image_b64(path: Path) -> str:
|
||||
|
|
@ -58,7 +59,7 @@ class ImageProvider(Provider):
|
|||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
) -> None:
|
||||
output_path = project_dir / target_name
|
||||
|
|
@ -72,14 +73,14 @@ class ImageProvider(Provider):
|
|||
|
||||
if target_config.reference_images:
|
||||
_add_reference_images(
|
||||
inputs, target_config.reference_images, resolved_model, project_dir
|
||||
inputs, target_config.reference_images, resolved_model.name, project_dir
|
||||
)
|
||||
|
||||
for control_name in target_config.control_images:
|
||||
ctrl_path = project_dir / control_name
|
||||
inputs["control_image"] = _encode_image_b64(ctrl_path)
|
||||
|
||||
result = await self._client.generate(resolved_model, inputs)
|
||||
result = await self._client.generate(resolved_model.name, inputs)
|
||||
|
||||
async with httpx.AsyncClient() as http:
|
||||
response = await http.get(result.sample_url)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from mistralai import Mistral, models
|
|||
|
||||
from bulkgen.config import IMAGE_EXTENSIONS, TargetConfig
|
||||
from bulkgen.providers import Provider
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
|
||||
|
||||
def _image_to_data_url(path: Path) -> str:
|
||||
|
|
@ -34,7 +35,7 @@ class TextProvider(Provider):
|
|||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
) -> None:
|
||||
output_path = project_dir / target_name
|
||||
|
|
@ -57,7 +58,7 @@ class TextProvider(Provider):
|
|||
|
||||
async with Mistral(api_key=self._api_key) as client:
|
||||
response = await client.chat.complete_async(
|
||||
model=resolved_model,
|
||||
model=resolved_model.name,
|
||||
messages=[message],
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue