hokusai/hokusai/prompt.py

85 lines
2.7 KiB
Python

"""Prompt resolution and placeholder substitution."""
from __future__ import annotations
import re
from pathlib import Path
_PLACEHOLDER_RE = re.compile(r"(\\*)\{([^}]+)\}")
"""Match ``{filename}`` with optional leading backslashes.
Groups:
1. Zero or more backslashes immediately before the ``{``
2. The filename between the braces
"""
def _process_match(match: re.Match[str], project_dir: Path, *, substitute: bool) -> str:
"""Process a single placeholder match.
When *substitute* is True the placeholder is replaced with the file
contents; when False the original text is returned unchanged (used by
:func:`extract_placeholder_files` to merely detect references).
"""
backslashes = match.group(1)
filename = match.group(2)
n_bs = len(backslashes)
# Odd number of backslashes → the brace is escaped.
# Halve the backslashes (floor-division) and keep literal {filename}.
if n_bs % 2 == 1:
return "\\" * (n_bs // 2) + "{" + filename + "}"
# Even number (including zero) → substitute.
prefix = "\\" * (n_bs // 2)
if substitute:
content = (project_dir / filename).read_text()
return prefix + content
return match.group(0)
def substitute_placeholders(text: str, project_dir: Path) -> str:
    """Expand every ``{filename}`` placeholder in *text* with that file's contents.

    Files are looked up relative to *project_dir*.  Backslashes written
    immediately before the opening brace control escaping:

    * ``{file.txt}``     → contents of *file.txt*
    * ``\\{file.txt}``   → literal ``{file.txt}``
    * ``\\\\{file.txt}`` → literal ``\\`` + contents of *file.txt*

    Errors raised while reading a referenced file (for example
    :class:`FileNotFoundError`) propagate to the caller.
    """
    def _expand(found):
        # Callback for re.sub: compute the replacement for one placeholder.
        return _process_match(found, project_dir, substitute=True)

    return _PLACEHOLDER_RE.sub(_expand, text)
def extract_placeholder_files(prompt_value: str) -> list[str]:
    """List the filenames referenced by unescaped ``{file}`` placeholders.

    A placeholder counts only when it is preceded by an even number of
    backslashes (zero included); an odd count means the brace is escaped.
    Filenames are returned in order of appearance, duplicates included.
    No file is read.
    """
    return [
        found.group(2)
        for found in _PLACEHOLDER_RE.finditer(prompt_value)
        if len(found.group(1)) % 2 == 0
    ]
def resolve_prompt(prompt_value: str, project_dir: Path) -> str:
    """Resolve a prompt string.

    1. If *prompt_value* is a single-line string naming an existing file
       (joined onto *project_dir*), read and return that file's contents
       verbatim.  No placeholder processing is applied to them.
    2. Otherwise, substitute ``{filename}`` placeholders in *prompt_value*
       with file contents and return the result.

    Raises:
        OSError: if the named prompt file exists but cannot be read, or if a
            placeholder references a file that cannot be read.
    """
    if "\n" not in prompt_value:
        candidate = project_dir / prompt_value
        try:
            is_prompt_file = candidate.is_file()
        except OSError:
            # Inline prompts are not paths; the OS may reject them outright
            # (e.g. a name that is too long).  Treat that as "not a file".
            is_prompt_file = False
        if is_prompt_file:
            # Read outside the try: a real but unreadable prompt file must
            # surface as an error instead of its path being used as text.
            return candidate.read_text()
    return substitute_placeholders(prompt_value, project_dir)