feat: add prompt placeholder substitution with {filename} syntax
This commit is contained in:
parent
760eac5a7b
commit
3de3614433
6 changed files with 274 additions and 33 deletions
|
|
@ -13,6 +13,7 @@ import httpx
|
|||
|
||||
from hokusai.config import DownloadTargetConfig, GenerateTargetConfig, ProjectConfig
|
||||
from hokusai.graph import build_graph, get_build_order, get_subgraph_for_target
|
||||
from hokusai.prompt import extract_placeholder_files, resolve_prompt
|
||||
from hokusai.providers import Provider
|
||||
from hokusai.providers.blackforest import BlackForestProvider
|
||||
from hokusai.providers.mistral import MistralProvider
|
||||
|
|
@ -57,19 +58,6 @@ class BuildResult:
|
|||
total_targets: int = 0
|
||||
|
||||
|
||||
def _resolve_prompt(prompt_value: str, project_dir: Path) -> str:
|
||||
"""Resolve a prompt: read from file if the path exists, otherwise use as-is."""
|
||||
if "\n" in prompt_value:
|
||||
return prompt_value
|
||||
try:
|
||||
candidate = project_dir / prompt_value
|
||||
if candidate.is_file():
|
||||
return candidate.read_text()
|
||||
except OSError:
|
||||
pass
|
||||
return prompt_value
|
||||
|
||||
|
||||
def _collect_dep_files(
|
||||
target_name: str, config: ProjectConfig, project_dir: Path
|
||||
) -> list[Path]:
|
||||
|
|
@ -80,6 +68,7 @@ def _collect_dep_files(
|
|||
deps: list[str] = list(target_cfg.inputs)
|
||||
deps.extend(target_cfg.reference_images)
|
||||
deps.extend(target_cfg.control_images)
|
||||
deps.extend(extract_placeholder_files(target_cfg.prompt))
|
||||
return [project_dir / d for d in deps]
|
||||
|
||||
|
||||
|
|
@ -101,13 +90,14 @@ def _collect_extra_params(target_name: str, config: ProjectConfig) -> dict[str,
|
|||
|
||||
|
||||
def _collect_all_deps(target_name: str, config: ProjectConfig) -> list[str]:
|
||||
"""Collect all dependency names (inputs + reference_images + control_images)."""
|
||||
"""Collect all dependency names (inputs + reference_images + control_images + placeholders)."""
|
||||
target_cfg = config.targets[target_name]
|
||||
if not isinstance(target_cfg, GenerateTargetConfig):
|
||||
return []
|
||||
deps: list[str] = list(target_cfg.inputs)
|
||||
deps.extend(target_cfg.reference_images)
|
||||
deps.extend(target_cfg.control_images)
|
||||
deps.extend(extract_placeholder_files(target_cfg.prompt))
|
||||
return deps
|
||||
|
||||
|
||||
|
|
@ -162,7 +152,7 @@ async def _build_single_target(
|
|||
return
|
||||
|
||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
resolved_prompt = resolve_prompt(target_cfg.prompt, project_dir)
|
||||
|
||||
provider = provider_index[model_info.name]
|
||||
await provider.generate(
|
||||
|
|
@ -276,7 +266,7 @@ def _is_dirty(
|
|||
)
|
||||
|
||||
model_info = resolve_model(target_name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
resolved_prompt = resolve_prompt(target_cfg.prompt, project_dir)
|
||||
dep_files = _collect_dep_files(target_name, config, project_dir)
|
||||
extra = _collect_extra_params(target_name, config)
|
||||
|
||||
|
|
@ -350,7 +340,7 @@ def _process_outcomes(
|
|||
model_name = _DOWNLOAD_MODEL_SENTINEL
|
||||
else:
|
||||
model_info = resolve_model(name, target_cfg, config.defaults)
|
||||
resolved_prompt = _resolve_prompt(target_cfg.prompt, project_dir)
|
||||
resolved_prompt = resolve_prompt(target_cfg.prompt, project_dir)
|
||||
model_name = model_info.name
|
||||
|
||||
dep_files = _collect_dep_files(name, config, project_dir)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from pathlib import Path
|
|||
import networkx as nx
|
||||
|
||||
from hokusai.config import GenerateTargetConfig, ProjectConfig
|
||||
from hokusai.prompt import extract_placeholder_files
|
||||
|
||||
|
||||
def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]:
|
||||
|
|
@ -31,6 +32,7 @@ def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]:
|
|||
deps: list[str] = list(target_cfg.inputs)
|
||||
deps.extend(target_cfg.reference_images)
|
||||
deps.extend(target_cfg.control_images)
|
||||
deps.extend(extract_placeholder_files(target_cfg.prompt))
|
||||
|
||||
for dep in deps:
|
||||
if dep not in target_names and not (project_dir / dep).exists():
|
||||
|
|
|
|||
85
hokusai/prompt.py
Normal file
85
hokusai/prompt.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
"""Prompt resolution and placeholder substitution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
_PLACEHOLDER_RE = re.compile(r"(\\*)\{([^}]+)\}")
|
||||
"""Match ``{filename}`` with optional leading backslashes.
|
||||
|
||||
Groups:
|
||||
1. Zero or more backslashes immediately before the ``{``
|
||||
2. The filename between the braces
|
||||
"""
|
||||
|
||||
|
||||
def _process_match(match: re.Match[str], project_dir: Path, *, substitute: bool) -> str:
|
||||
"""Process a single placeholder match.
|
||||
|
||||
When *substitute* is True the placeholder is replaced with the file
|
||||
contents; when False the original text is returned unchanged (used by
|
||||
:func:`extract_placeholder_files` to merely detect references).
|
||||
"""
|
||||
backslashes = match.group(1)
|
||||
filename = match.group(2)
|
||||
n_bs = len(backslashes)
|
||||
|
||||
# Odd number of backslashes → the brace is escaped.
|
||||
# Halve the backslashes (floor-division) and keep literal {filename}.
|
||||
if n_bs % 2 == 1:
|
||||
return "\\" * (n_bs // 2) + "{" + filename + "}"
|
||||
|
||||
# Even number (including zero) → substitute.
|
||||
prefix = "\\" * (n_bs // 2)
|
||||
if substitute:
|
||||
content = (project_dir / filename).read_text()
|
||||
return prefix + content
|
||||
return match.group(0)
|
||||
|
||||
|
||||
def substitute_placeholders(text: str, project_dir: Path) -> str:
|
||||
"""Replace ``{filename}`` placeholders with the referenced file contents.
|
||||
|
||||
Escaping rules:
|
||||
|
||||
* ``{file.txt}`` → contents of *file.txt*
|
||||
* ``\\{file.txt}`` → literal ``{file.txt}``
|
||||
* ``\\\\{file.txt}`` → literal ``\\`` + contents of *file.txt*
|
||||
"""
|
||||
return _PLACEHOLDER_RE.sub(
|
||||
lambda m: _process_match(m, project_dir, substitute=True), text
|
||||
)
|
||||
|
||||
|
||||
def extract_placeholder_files(prompt_value: str) -> list[str]:
|
||||
"""Return filenames referenced as ``{file}`` placeholders in *prompt_value*.
|
||||
|
||||
Only non-escaped placeholders are returned (even number of leading
|
||||
backslashes, including zero).
|
||||
"""
|
||||
files: list[str] = []
|
||||
for match in _PLACEHOLDER_RE.finditer(prompt_value):
|
||||
n_bs = len(match.group(1))
|
||||
if n_bs % 2 == 0:
|
||||
files.append(match.group(2))
|
||||
return files
|
||||
|
||||
|
||||
def resolve_prompt(prompt_value: str, project_dir: Path) -> str:
|
||||
"""Resolve a prompt string.
|
||||
|
||||
1. If *prompt_value* is a single-line string that refers to an existing
|
||||
file, read the file contents and return them (no further placeholder
|
||||
processing).
|
||||
2. Otherwise, substitute ``{filename}`` placeholders with file contents
|
||||
and return the result.
|
||||
"""
|
||||
if "\n" not in prompt_value:
|
||||
try:
|
||||
candidate = project_dir / prompt_value
|
||||
if candidate.is_file():
|
||||
return candidate.read_text()
|
||||
except OSError:
|
||||
pass
|
||||
return substitute_placeholders(prompt_value, project_dir)
|
||||
|
|
@ -13,7 +13,6 @@ from hokusai.builder import (
|
|||
_collect_all_deps, # pyright: ignore[reportPrivateUsage]
|
||||
_collect_dep_files, # pyright: ignore[reportPrivateUsage]
|
||||
_collect_extra_params, # pyright: ignore[reportPrivateUsage]
|
||||
_resolve_prompt, # pyright: ignore[reportPrivateUsage]
|
||||
run_build,
|
||||
)
|
||||
from hokusai.config import GenerateTargetConfig, ProjectConfig
|
||||
|
|
@ -112,21 +111,6 @@ def _fake_providers() -> list[Provider]:
|
|||
return [FakeTextProvider(), FakeImageProvider()]
|
||||
|
||||
|
||||
class TestResolvePrompt:
|
||||
"""Test prompt resolution (file vs inline)."""
|
||||
|
||||
def test_inline_prompt(self, project_dir: Path) -> None:
|
||||
assert _resolve_prompt("Just a string", project_dir) == "Just a string"
|
||||
|
||||
def test_file_prompt(self, project_dir: Path, prompt_file: Path) -> None:
|
||||
result = _resolve_prompt(prompt_file.name, project_dir)
|
||||
assert result == "This prompt comes from a file"
|
||||
|
||||
def test_nonexistent_file_treated_as_inline(self, project_dir: Path) -> None:
|
||||
result = _resolve_prompt("no_such_file.txt", project_dir)
|
||||
assert result == "no_such_file.txt"
|
||||
|
||||
|
||||
class TestCollectHelpers:
|
||||
"""Test dependency collection helpers."""
|
||||
|
||||
|
|
@ -568,3 +552,47 @@ class TestDownloadTarget:
|
|||
assert "description.txt" in result.built
|
||||
assert (project_dir / "fish.png").read_bytes() == b"fake fish image"
|
||||
assert (project_dir / "description.txt").exists()
|
||||
|
||||
|
||||
class TestPlaceholderPrompts:
|
||||
"""Tests for prompt placeholder substitution in builds."""
|
||||
|
||||
async def test_placeholder_in_prompt_triggers_rebuild(
|
||||
self, project_dir: Path, write_config: WriteConfig
|
||||
) -> None:
|
||||
_ = (project_dir / "style.txt").write_text("impressionist")
|
||||
config = write_config(
|
||||
{
|
||||
"targets": {
|
||||
"out.txt": {"prompt": "Paint in {style.txt} style"},
|
||||
}
|
||||
}
|
||||
)
|
||||
with patch("hokusai.builder._create_providers", return_value=_fake_providers()):
|
||||
r1 = await run_build(config, project_dir, _PROJECT)
|
||||
assert r1.built == ["out.txt"]
|
||||
content1 = (project_dir / "out.txt").read_text()
|
||||
assert "impressionist" in content1
|
||||
|
||||
# Change the placeholder file
|
||||
_ = (project_dir / "style.txt").write_text("cubist")
|
||||
r2 = await run_build(config, project_dir, _PROJECT)
|
||||
assert r2.built == ["out.txt"]
|
||||
|
||||
async def test_placeholder_deps_in_collect_all(
|
||||
self, write_config: WriteConfig
|
||||
) -> None:
|
||||
config = write_config(
|
||||
{
|
||||
"targets": {
|
||||
"out.txt": {
|
||||
"prompt": "Use {a.txt} and {b.txt}",
|
||||
"inputs": ["c.txt"],
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
deps = _collect_all_deps("out.txt", config)
|
||||
assert "a.txt" in deps
|
||||
assert "b.txt" in deps
|
||||
assert "c.txt" in deps
|
||||
|
|
|
|||
|
|
@ -194,3 +194,30 @@ class TestGetSubgraphForTarget:
|
|||
|
||||
assert sub.has_edge("input.txt", "summary.md")
|
||||
assert sub.has_edge("summary.md", "final.txt")
|
||||
|
||||
def test_placeholder_files_in_graph(
|
||||
self, project_dir: Path, write_config: WriteConfig
|
||||
) -> None:
|
||||
_ = (project_dir / "style.txt").write_text("impressionist")
|
||||
config = write_config(
|
||||
{
|
||||
"targets": {
|
||||
"out.txt": {"prompt": "Paint in {style.txt} style"},
|
||||
}
|
||||
}
|
||||
)
|
||||
graph = build_graph(config, project_dir)
|
||||
assert graph.has_edge("style.txt", "out.txt")
|
||||
|
||||
def test_escaped_placeholder_not_in_graph(
|
||||
self, project_dir: Path, write_config: WriteConfig
|
||||
) -> None:
|
||||
config = write_config(
|
||||
{
|
||||
"targets": {
|
||||
"out.txt": {"prompt": "Literal \\{not_a_dep.txt}"},
|
||||
}
|
||||
}
|
||||
)
|
||||
graph = build_graph(config, project_dir)
|
||||
assert "not_a_dep.txt" not in graph.nodes
|
||||
|
|
|
|||
109
tests/test_prompt.py
Normal file
109
tests/test_prompt.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
"""Tests for hokusai.prompt — prompt resolution and placeholder substitution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from hokusai.prompt import (
|
||||
extract_placeholder_files,
|
||||
resolve_prompt,
|
||||
substitute_placeholders,
|
||||
)
|
||||
|
||||
|
||||
class TestSubstitutePlaceholders:
|
||||
"""Test placeholder substitution in prompt text."""
|
||||
|
||||
def test_no_placeholders(self, project_dir: Path) -> None:
|
||||
assert substitute_placeholders("Hello world", project_dir) == "Hello world"
|
||||
|
||||
def test_single_placeholder(self, project_dir: Path) -> None:
|
||||
_ = (project_dir / "name.txt").write_text("World")
|
||||
result = substitute_placeholders("Hello {name.txt}", project_dir)
|
||||
assert result == "Hello World"
|
||||
|
||||
def test_multiple_placeholders(self, project_dir: Path) -> None:
|
||||
_ = (project_dir / "first.txt").write_text("Alice")
|
||||
_ = (project_dir / "last.txt").write_text("Smith")
|
||||
result = substitute_placeholders("Dear {first.txt} {last.txt}", project_dir)
|
||||
assert result == "Dear Alice Smith"
|
||||
|
||||
def test_escaped_placeholder(self, project_dir: Path) -> None:
|
||||
_ = (project_dir / "name.txt").write_text("World")
|
||||
result = substitute_placeholders("Hello \\{name.txt}", project_dir)
|
||||
assert result == "Hello {name.txt}"
|
||||
|
||||
def test_double_escaped_placeholder(self, project_dir: Path) -> None:
|
||||
_ = (project_dir / "name.txt").write_text("World")
|
||||
result = substitute_placeholders("Hello \\\\{name.txt}", project_dir)
|
||||
assert result == "Hello \\World"
|
||||
|
||||
def test_triple_escaped_placeholder(self, project_dir: Path) -> None:
|
||||
_ = (project_dir / "name.txt").write_text("World")
|
||||
result = substitute_placeholders("Hello \\\\\\{name.txt}", project_dir)
|
||||
assert result == "Hello \\{name.txt}"
|
||||
|
||||
def test_placeholder_at_start(self, project_dir: Path) -> None:
|
||||
_ = (project_dir / "greeting.txt").write_text("Hi there")
|
||||
result = substitute_placeholders("{greeting.txt}!", project_dir)
|
||||
assert result == "Hi there!"
|
||||
|
||||
def test_placeholder_with_multiline_content(self, project_dir: Path) -> None:
|
||||
_ = (project_dir / "block.txt").write_text("line1\nline2\nline3")
|
||||
result = substitute_placeholders("Start: {block.txt} :End", project_dir)
|
||||
assert result == "Start: line1\nline2\nline3 :End"
|
||||
|
||||
|
||||
class TestExtractPlaceholderFiles:
|
||||
"""Test extraction of placeholder filenames from prompts."""
|
||||
|
||||
def test_no_placeholders(self) -> None:
|
||||
assert extract_placeholder_files("Hello world") == []
|
||||
|
||||
def test_single_placeholder(self) -> None:
|
||||
assert extract_placeholder_files("Hello {name.txt}") == ["name.txt"]
|
||||
|
||||
def test_multiple_placeholders(self) -> None:
|
||||
result = extract_placeholder_files("{a.txt} and {b.txt}")
|
||||
assert result == ["a.txt", "b.txt"]
|
||||
|
||||
def test_escaped_placeholder_ignored(self) -> None:
|
||||
assert extract_placeholder_files("\\{name.txt}") == []
|
||||
|
||||
def test_double_escaped_included(self) -> None:
|
||||
assert extract_placeholder_files("\\\\{name.txt}") == ["name.txt"]
|
||||
|
||||
def test_triple_escaped_ignored(self) -> None:
|
||||
assert extract_placeholder_files("\\\\\\{name.txt}") == []
|
||||
|
||||
|
||||
class TestResolvePrompt:
|
||||
"""Test full prompt resolution (file loading + placeholder substitution)."""
|
||||
|
||||
def test_inline_prompt_no_placeholders(self, project_dir: Path) -> None:
|
||||
assert resolve_prompt("Just a string", project_dir) == "Just a string"
|
||||
|
||||
def test_file_prompt(self, project_dir: Path) -> None:
|
||||
_ = (project_dir / "prompt.txt").write_text("From file")
|
||||
result = resolve_prompt("prompt.txt", project_dir)
|
||||
assert result == "From file"
|
||||
|
||||
def test_nonexistent_file_treated_as_inline(self, project_dir: Path) -> None:
|
||||
result = resolve_prompt("no_such_file.txt", project_dir)
|
||||
assert result == "no_such_file.txt"
|
||||
|
||||
def test_inline_with_placeholder(self, project_dir: Path) -> None:
|
||||
_ = (project_dir / "style.txt").write_text("impressionist")
|
||||
result = resolve_prompt("Paint in {style.txt} style", project_dir)
|
||||
assert result == "Paint in impressionist style"
|
||||
|
||||
def test_multiline_prompt_with_placeholder(self, project_dir: Path) -> None:
|
||||
_ = (project_dir / "detail.txt").write_text("vivid colours")
|
||||
result = resolve_prompt("First line\nWith {detail.txt}", project_dir)
|
||||
assert result == "First line\nWith vivid colours"
|
||||
|
||||
def test_file_prompt_no_placeholder_processing(self, project_dir: Path) -> None:
|
||||
"""When the prompt is a file path, the file is loaded verbatim."""
|
||||
_ = (project_dir / "prompt.txt").write_text("Literal {not_a_file.txt}")
|
||||
result = resolve_prompt("prompt.txt", project_dir)
|
||||
assert result == "Literal {not_a_file.txt}"
|
||||
Loading…
Add table
Add a link
Reference in a new issue