feat: add prompt placeholder substitution with {filename} syntax
This commit is contained in:
parent
760eac5a7b
commit
3de3614433
6 changed files with 274 additions and 33 deletions
|
|
@ -13,6 +13,7 @@ import httpx
|
|||
|
||||
from hokusai.config import DownloadTargetConfig, GenerateTargetConfig, ProjectConfig
|
||||
from hokusai.graph import build_graph, get_build_order, get_subgraph_for_target
|
||||
from hokusai.prompt import extract_placeholder_files, resolve_prompt
|
||||
from hokusai.providers import Provider
|
||||
from hokusai.providers.blackforest import BlackForestProvider
|
||||
from hokusai.providers.mistral import MistralProvider
|
||||
|
|
@ -57,19 +58,6 @@ class BuildResult:
|
|||
total_targets: int = 0
|
||||
|
||||
|
||||
def _resolve_prompt(prompt_value: str, project_dir: Path) -> str:
|
||||
"""Resolve a prompt: read from file if the path exists, otherwise use as-is."""
|
||||
if "\n" in prompt_value:
|
||||
return prompt_value
|
||||
try:
|
||||
candidate = project_dir / prompt_value
|
||||
if candidate.is_file():
|
||||
return candidate.read_text()
|
||||
except OSError:
|
||||
pass
|
||||
return prompt_value
|
||||
|
||||
|
||||
def _collect_dep_files(
|
||||
target_name: str, config: ProjectConfig, project_dir: Path
|
||||
) -> list[Path]:
|
||||
|
|
@ -80,6 +68,7 @@ def _collect_dep_files(
|
|||
deps: list[str] = list(target_cfg.inputs)
|
||||
deps.extend(target_cfg.reference_images)
|
||||
deps.extend(target_cfg.control_images)
|
||||
deps.extend(extract_placeholder_files(target_cfg.prompt))
|
||||
return [project_dir / d for d in deps]
|
||||
|
||||
|
||||
|
|
@ -101,13 +90,14 @@ def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str,
|
|||
|
||||
|
||||
def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]:
|
||||
"""Collect all dependency names (inputs + reference_images + control_images)."""
|
||||
"""Collect all dependency names (inputs + reference_images + control_images + placeholders)."""
|
||||
target_cfg = config.targets[target_name]
|
||||
if not isinstance(target_cfg, GenerateTargetConfig):
|
||||
return []
|
||||
deps: list[str] = list(target_cfg.inputs)
|
||||
deps.extend(target_cfg.reference_images)
|
||||
deps.extend(target_cfg.control_images)
|
||||
deps.extend(extract_placeholder_files(target_cfg.prompt))
|
||||
return deps
|
||||
|
||||
|
||||
|
|
@ -162,7 +152,7 @@ async def _build_single_target(
|
|||
return
|
||||
|
||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
resolved_prompt = resolve_prompt(target_cfg.prompt, project_dir)
|
||||
|
||||
provider = provider_index[model_info.name]
|
||||
await provider.generate(
|
||||
|
|
@ -276,7 +266,7 @@ def _is_dirty(
|
|||
)
|
||||
|
||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
resolved_prompt = resolve_prompt(target_cfg.prompt, project_dir)
|
||||
dep_files = _collect_dep_files(target_name, config, project_dir)
|
||||
extra = _collect_extra_params(target_name, config)
|
||||
|
||||
|
|
@ -350,7 +340,7 @@ def _process_outcomes(
|
|||
model_name = _DOWNLOAD_MODEL_SENTINEL
|
||||
else:
|
||||
model_info = resolve_model(name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
resolved_prompt = resolve_prompt(target_cfg.prompt, project_dir)
|
||||
model_name = model_info.name
|
||||
|
||||
dep_files = _collect_dep_files(name, config, project_dir)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue