refactor: pass ModelInfo instead of model name string through provider interface
This commit is contained in:
parent
8e3ed7010f
commit
d15444bdb0
8 changed files with 83 additions and 43 deletions
|
|
@ -18,6 +18,7 @@ from bulkgen.builder import (
|
|||
)
|
||||
from bulkgen.config import ProjectConfig, TargetConfig, TargetType
|
||||
from bulkgen.providers import Provider
|
||||
from bulkgen.providers.models import ModelInfo
|
||||
from bulkgen.state import load_state
|
||||
|
||||
WriteConfig = Callable[[dict[str, object]], ProjectConfig]
|
||||
|
|
@ -32,7 +33,7 @@ class FakeProvider(Provider):
|
|||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
) -> None:
|
||||
output = project_dir / target_name
|
||||
|
|
@ -48,7 +49,7 @@ class FailingProvider(Provider):
|
|||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
) -> None:
|
||||
msg = f"Simulated failure for {target_name}"
|
||||
|
|
@ -247,7 +248,7 @@ class TestRunBuild:
|
|||
target_name: str,
|
||||
target_config: TargetConfig,
|
||||
resolved_prompt: str,
|
||||
resolved_model: str,
|
||||
resolved_model: ModelInfo,
|
||||
project_dir: Path,
|
||||
) -> None:
|
||||
if target_name == "fail.txt":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue