hokusai/hokusai/graph.py
Konstantin Fickel 4def49350e
All checks were successful
Continuous Integration / Build Package (push) Successful in 35s
Continuous Integration / Lint, Check & Test (push) Successful in 57s
chore: rename bulkgen to hokusai
2026-02-20 17:08:12 +01:00

63 lines
2.2 KiB
Python

"""Dependency graph construction and traversal using networkx."""
from __future__ import annotations
from pathlib import Path
import networkx as nx
from hokusai.config import ProjectConfig
def build_graph(config: ProjectConfig, project_dir: Path) -> nx.DiGraph[str]:
    """Build a dependency DAG from the project configuration.

    Nodes are filenames: target names (keys in ``config.targets``) and
    external files that exist on disk. Edges point from dependency to
    dependent (``A -> B`` means *A must exist before B*).

    A target's dependencies are its ``inputs``, ``reference_images`` and
    ``control_images`` combined. Each dependency must be either another
    target or a path that exists relative to *project_dir*.

    Args:
        config: Project configuration whose ``targets`` mapping is read.
        project_dir: Directory against which non-target dependencies are
            resolved when checking that they exist.

    Returns:
        A directed acyclic graph over target names and external file names.

    Raises:
        ValueError: If a dependency is neither a defined target nor an
            existing file, or if the graph contains a cycle (the message
            names one offending cycle).
    """
    graph: nx.DiGraph[str] = nx.DiGraph()
    target_names = set(config.targets)

    for target_name, target_cfg in config.targets.items():
        graph.add_node(target_name)
        deps: list[str] = list(target_cfg.inputs)
        deps.extend(target_cfg.reference_images)
        deps.extend(target_cfg.control_images)
        for dep in deps:
            # Targets are accepted whether or not they exist on disk yet;
            # only non-target (external) dependencies must already exist.
            if dep not in target_names and not (project_dir / dep).exists():
                msg = (
                    f"Target '{target_name}' depends on '{dep}', "
                    f"which is neither a defined target nor an existing file"
                )
                raise ValueError(msg)
            _ = graph.add_edge(dep, target_name)

    if not nx.is_directed_acyclic_graph(graph):
        # Report one cycle only: enumerating every simple cycle can take
        # exponential time and memory on densely cyclic graphs.
        cycle_edges = nx.find_cycle(graph)
        cycle_nodes = [str(edge[0]) for edge in cycle_edges]
        cycle_nodes.append(cycle_nodes[0])  # close the loop: a -> b -> a
        msg = f"Dependency cycle detected: {' -> '.join(cycle_nodes)}"
        raise ValueError(msg)

    return graph
def get_build_order(graph: nx.DiGraph[str]) -> list[list[str]]:
    """Partition the nodes of *graph* into ordered build generations.

    Generation ``k`` holds the nodes whose dependencies all lie in
    generations ``0 .. k-1``. Nodes within a single generation do not
    depend on each other, so they can be built concurrently.

    Every node is included, not only targets. For a graph from
    ``build_graph``, external input files have no incoming edges and
    therefore land in the first generation.

    Expects *graph* to be acyclic, such as the one returned by
    ``build_graph``.
    """
    generations = nx.topological_generations(graph)
    return list(map(list, generations))
def get_subgraph_for_target(graph: nx.DiGraph[str], target: str) -> nx.DiGraph[str]:
    """Extract the part of *graph* that is needed to build *target*.

    The result is a new, independent graph (not a view of *graph*). It
    contains *target*, every node from which *target* is reachable (its
    transitive dependencies), and the edges among those nodes.
    """
    needed: set[str] = nx.ancestors(graph, target)  # pyright: ignore[reportUnknownMemberType]
    needed.add(target)
    closure: nx.DiGraph[str] = nx.DiGraph(graph.subgraph(needed))
    return closure