"""Unit tests for hokusai.expand.""" from __future__ import annotations import pytest from hokusai.expand import ( expand_targets, extract_loop_variables, substitute_loop_variables, ) class TestExtractLoopVariables: """Tests for extracting [var] references from strings.""" def test_single_variable(self) -> None: assert extract_loop_variables("image-[a].png") == ["a"] def test_multiple_variables(self) -> None: assert extract_loop_variables("card-[size]-[color].png") == ["size", "color"] def test_no_variables(self) -> None: assert extract_loop_variables("plain.png") == [] def test_escaped_variable(self) -> None: assert extract_loop_variables(r"file-\[a].png") == [] def test_mixed_escaped_and_real(self) -> None: assert extract_loop_variables(r"file-\[a]-[b].png") == ["b"] def test_double_backslash_is_not_escaped(self) -> None: assert extract_loop_variables("file-\\\\[a].png") == ["a"] def test_deduplicates(self) -> None: assert extract_loop_variables("[a]-[a].png") == ["a"] def test_preserves_order(self) -> None: assert extract_loop_variables("[b]-[a]-[c].png") == ["b", "a", "c"] class TestSubstituteLoopVariables: """Tests for substituting [var] with values.""" def test_single_substitution(self) -> None: result = substitute_loop_variables("image-[a].png", {"a": "1"}) assert result == "image-1.png" def test_multiple_substitutions(self) -> None: result = substitute_loop_variables( "card-[size]-[color].png", {"size": "large", "color": "red"} ) assert result == "card-large-red.png" def test_escaped_not_substituted(self) -> None: result = substitute_loop_variables(r"file-\[a].png", {"a": "1"}) assert result == "file-[a].png" def test_double_backslash_substituted(self) -> None: result = substitute_loop_variables("file-\\\\[a].png", {"a": "1"}) assert result == "file-\\1.png" def test_unknown_variable_left_as_is(self) -> None: result = substitute_loop_variables("file-[unknown].png", {"a": "1"}) assert result == "file-[unknown].png" def test_no_variables(self) -> None: result = substitute_loop_variables("plain.png", {"a": "1"}) assert result == "plain.png" class TestExpandTargets: """Tests for full target expansion.""" def test_single_variable_expansion(self) -> None: raw: dict[str, object] = {"image-[a].png": {"prompt": "Draw [a]"}} loops = {"a": ["1", "2", "3"]} result = expand_targets(raw, loops) assert len(result) == 3 assert result["image-1.png"] == {"prompt": "Draw 1"} assert result["image-2.png"] == {"prompt": "Draw 2"} assert result["image-3.png"] == {"prompt": "Draw 3"} def test_cartesian_product(self) -> None: raw: dict[str, object] = {"card-[a]-[b].png": {"prompt": "[a] [b]"}} loops = {"a": ["1", "2"], "b": ["x", "y"]} result = expand_targets(raw, loops) assert len(result) == 4 assert result["card-1-x.png"] == {"prompt": "1 x"} assert result["card-1-y.png"] == {"prompt": "1 y"} assert result["card-2-x.png"] == {"prompt": "2 x"} assert result["card-2-y.png"] == {"prompt": "2 y"} def test_partial_loop_only_referenced_vars(self) -> None: raw: dict[str, object] = {"image-[a].png": {"prompt": "Draw [a]"}} loops = {"a": ["1", "2"], "b": ["x", "y"]} result = expand_targets(raw, loops) assert len(result) == 2 assert "image-1.png" in result assert "image-2.png" in result def test_non_template_target_passed_through(self) -> None: raw: dict[str, object] = { "image-[a].png": {"prompt": "Draw [a]"}, "static.txt": {"content": "hello"}, } loops = {"a": ["1", "2"]} result = expand_targets(raw, loops) assert len(result) == 3 assert result["static.txt"] == {"content": "hello"} def test_explicit_target_overrides_expanded(self) -> None: raw: dict[str, object] = { "image-[a].png": {"prompt": "Draw [a]"}, "image-1.png": {"prompt": "Custom prompt for 1"}, } loops = {"a": ["1", "2"]} result = expand_targets(raw, loops) assert len(result) == 2 assert result["image-1.png"] == {"prompt": "Custom prompt for 1"} assert result["image-2.png"] == {"prompt": "Draw 2"} def test_substitution_in_inputs(self) -> None: raw: dict[str, object] = { "out-[a].txt": { "prompt": "Summarize [a]", "inputs": ["data-[a].txt"], } } loops = {"a": ["x", "y"]} result = expand_targets(raw, loops) assert result["out-x.txt"] == { "prompt": "Summarize x", "inputs": ["data-x.txt"], } assert result["out-y.txt"] == { "prompt": "Summarize y", "inputs": ["data-y.txt"], } def test_substitution_in_reference_images(self) -> None: raw: dict[str, object] = { "out-[a].png": { "prompt": "Enhance", "reference_images": ["ref-[a].png"], } } loops = {"a": ["1", "2"]} result = expand_targets(raw, loops) assert result["out-1.png"]["reference_images"] == ["ref-1.png"] # pyright: ignore[reportIndexIssue] assert result["out-2.png"]["reference_images"] == ["ref-2.png"] # pyright: ignore[reportIndexIssue] def test_substitution_in_content(self) -> None: raw: dict[str, object] = {"file-[a].txt": {"content": "Value is [a]"}} loops = {"a": ["x", "y"]} result = expand_targets(raw, loops) assert result["file-x.txt"] == {"content": "Value is x"} assert result["file-y.txt"] == {"content": "Value is y"} def test_substitution_in_download(self) -> None: raw: dict[str, object] = { "file-[a].png": {"download": "https://example.com/[a].png"} } loops = {"a": ["cat", "dog"]} result = expand_targets(raw, loops) assert result["file-cat.png"] == {"download": "https://example.com/cat.png"} assert result["file-dog.png"] == {"download": "https://example.com/dog.png"} def test_escaped_brackets_preserved(self) -> None: raw: dict[str, object] = {r"image-[a].png": {"prompt": r"Draw \[a] for [a]"}} loops = {"a": ["1"]} result = expand_targets(raw, loops) assert result["image-1.png"] == {"prompt": "Draw [a] for 1"} def test_undefined_variable_raises(self) -> None: raw: dict[str, object] = {"image-[missing].png": {"prompt": "x"}} loops = {"a": ["1"]} with pytest.raises(ValueError, match="undefined loop variable"): _ = expand_targets(raw, loops) def test_duplicate_from_different_templates_raises(self) -> None: raw: dict[str, object] = { "[a]-[b].png": {"prompt": "first"}, "[b]-[a].png": {"prompt": "second"}, } loops = {"a": ["x"], "b": ["x"]} with pytest.raises(ValueError, match="Duplicate expanded target"): _ = expand_targets(raw, loops) def test_empty_loops_passes_through(self) -> None: raw: dict[str, object] = {"out.txt": {"prompt": "hello"}} result = expand_targets(raw, {}) assert result == {"out.txt": {"prompt": "hello"}} def test_cross_reference_between_expanded_targets(self) -> None: raw: dict[str, object] = { "data-[id].txt": {"content": "Data for [id]"}, "summary-[id].txt": { "prompt": "Summarize", "inputs": ["data-[id].txt"], }, } loops = {"id": ["a", "b"]} result = expand_targets(raw, loops) assert len(result) == 4 assert result["summary-a.txt"]["inputs"] == ["data-a.txt"] # pyright: ignore[reportIndexIssue] assert result["summary-b.txt"]["inputs"] == ["data-b.txt"] # pyright: ignore[reportIndexIssue] def test_substitution_in_control_images(self) -> None: raw: dict[str, object] = { "out-[a].png": { "prompt": "Generate", "control_images": ["ctrl-[a].png"], } } loops = {"a": ["1"]} result = expand_targets(raw, loops) assert result["out-1.png"]["control_images"] == ["ctrl-1.png"] # pyright: ignore[reportIndexIssue]