feat: replace infer_target_type with capability-based model selection and validation
This commit is contained in:
parent
d15444bdb0
commit
e7270a118d
3 changed files with 179 additions and 37 deletions
|
|
@ -12,8 +12,9 @@ from pathlib import Path
|
|||
from bulkgen.config import (
|
||||
ProjectConfig,
|
||||
TargetType,
|
||||
infer_target_type,
|
||||
infer_required_capabilities,
|
||||
resolve_model,
|
||||
target_type_from_capabilities,
|
||||
)
|
||||
from bulkgen.graph import build_graph, get_build_order, get_subgraph_for_target
|
||||
from bulkgen.providers import Provider
|
||||
|
|
@ -120,8 +121,9 @@ async def _build_single_target(
|
|||
) -> None:
|
||||
"""Build a single target by dispatching to the appropriate provider."""
|
||||
target_cfg = config.targets[target_name]
|
||||
target_type = infer_target_type(target_name)
|
||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||
required = infer_required_capabilities(target_name, target_cfg)
|
||||
target_type = target_type_from_capabilities(required)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
|
||||
provider = providers[target_type]
|
||||
|
|
@ -180,7 +182,7 @@ async def run_build(
|
|||
continue
|
||||
|
||||
if _is_dirty(name, config, project_dir, state):
|
||||
if not _has_provider(name, providers, result, on_progress):
|
||||
if not _has_provider(name, config, providers, result, on_progress):
|
||||
continue
|
||||
dirty_targets.append(name)
|
||||
else:
|
||||
|
|
@ -236,12 +238,15 @@ def _is_dirty(
|
|||
|
||||
def _has_provider(
|
||||
target_name: str,
|
||||
config: ProjectConfig,
|
||||
providers: dict[TargetType, Provider],
|
||||
result: BuildResult,
|
||||
on_progress: ProgressCallback = _noop_callback,
|
||||
) -> bool:
|
||||
"""Check that the required provider is available; record failure if not."""
|
||||
target_type = infer_target_type(target_name)
|
||||
target_cfg = config.targets[target_name]
|
||||
required = infer_required_capabilities(target_name, target_cfg)
|
||||
target_type = target_type_from_capabilities(required)
|
||||
if target_type not in providers:
|
||||
env_var = (
|
||||
"BFL_API_KEY" if target_type is TargetType.IMAGE else "MISTRAL_API_KEY"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue