diff --git a/hokusai/builder.py b/hokusai/builder.py index 0a908b3..8a0f43b 100644 --- a/hokusai/builder.py +++ b/hokusai/builder.py @@ -252,9 +252,6 @@ def _should_skip_failed_dep( return any(d in result.failed for d in _collect_all_deps(target_name, config)) -_DOWNLOAD_MODEL_SENTINEL = "__download__" - - def _is_dirty( target_name: str, config: ProjectConfig, @@ -267,10 +264,7 @@ def _is_dirty( if isinstance(target_cfg, DownloadTargetConfig): return is_target_dirty( target_name, - resolved_prompt=target_cfg.download, - model=_DOWNLOAD_MODEL_SENTINEL, - dep_files=[], - extra_params={}, + download=target_cfg.download, state=state, project_dir=project_dir, ) @@ -346,24 +340,27 @@ def _process_outcomes( target_cfg = config.targets[name] if isinstance(target_cfg, DownloadTargetConfig): - resolved_prompt = target_cfg.download - model_name = _DOWNLOAD_MODEL_SENTINEL + record_target_state( + name, + download=target_cfg.download, + state=state, + project_dir=project_dir, + ) else: model_info = resolve_model(name, target_cfg, config.defaults) resolved_prompt = resolve_prompt(target_cfg.prompt, project_dir) - model_name = model_info.name + dep_files = _collect_dep_files(name, config, project_dir) + extra = _collect_extra_params(name, config) - dep_files = _collect_dep_files(name, config, project_dir) - extra = _collect_extra_params(name, config) + record_target_state( + name, + resolved_prompt=resolved_prompt, + model=model_info.name, + dep_files=dep_files, + extra_params=extra, + state=state, + project_dir=project_dir, + ) - record_target_state( - name, - resolved_prompt=resolved_prompt, - model=model_name, - dep_files=dep_files, - extra_params=extra, - state=state, - project_dir=project_dir, - ) result.built.append(name) on_progress(BuildEvent.TARGET_OK, name, "") diff --git a/hokusai/state.py b/hokusai/state.py index d10b955..8aee172 100644 --- a/hokusai/state.py +++ b/hokusai/state.py @@ -21,10 +21,11 @@ def state_filename(project_name: str) -> str: class TargetState(BaseModel): """Recorded state of a single target from its last successful build.""" - input_hashes: dict[str, str] - prompt: str - model: str + input_hashes: dict[str, str] = {} + prompt: str | None = None + model: str | None = None extra_params: dict[str, object] = {} + download: str | None = None class BuildState(BaseModel): @@ -58,16 +59,22 @@ def save_state(state: BuildState, project_dir: Path, project_name: str) -> None: """Persist build state to disk.""" state_path = project_dir / state_filename(project_name) with state_path.open("w") as f: - yaml.dump(state.model_dump(), f, default_flow_style=False, sort_keys=False) + yaml.dump( + state.model_dump(exclude_defaults=True), + f, + default_flow_style=False, + sort_keys=False, + ) def is_target_dirty( target_name: str, *, - resolved_prompt: str, - model: str, - dep_files: list[Path], - extra_params: dict[str, object], + resolved_prompt: str | None = None, + model: str | None = None, + dep_files: list[Path] | None = None, + extra_params: dict[str, object] | None = None, + download: str | None = None, state: BuildState, project_dir: Path, ) -> bool: @@ -76,10 +83,8 @@ def is_target_dirty( A target is dirty if: - Its output file does not exist - It has never been built (not recorded in state) - - Any dependency file hash has changed - - The resolved prompt text has changed - - The model has changed - - Extra parameters (width, height, etc.) have changed + - For download targets: the download URL has changed + - For generate targets: any dependency file hash, prompt, model, or extra params changed """ output_path = project_dir / target_name if not output_path.exists(): @@ -90,16 +95,21 @@ def is_target_dirty( prev = state.targets[target_name] + # Download targets only compare the URL. + if download is not None: + return prev.download != download + + # Generate targets compare prompt, model, extra params, and input hashes. if prev.model != model: return True if prev.prompt != resolved_prompt: return True - if prev.extra_params != extra_params: + if prev.extra_params != (extra_params or {}): return True - for dep_path in dep_files: + for dep_path in dep_files or []: dep_key = str(dep_path.relative_to(project_dir)) current_hash = hash_file(dep_path) if prev.input_hashes.get(dep_key) != current_hash: @@ -111,16 +121,21 @@ def is_target_dirty( def record_target_state( target_name: str, *, - resolved_prompt: str, - model: str, - dep_files: list[Path], - extra_params: dict[str, object], + resolved_prompt: str | None = None, + model: str | None = None, + dep_files: list[Path] | None = None, + extra_params: dict[str, object] | None = None, + download: str | None = None, state: BuildState, project_dir: Path, ) -> None: """Record the state of a successfully built target.""" + if download is not None: + state.targets[target_name] = TargetState(download=download) + return + input_hashes: dict[str, str] = {} - for dep_path in dep_files: + for dep_path in dep_files or []: dep_key = str(dep_path.relative_to(project_dir)) input_hashes[dep_key] = hash_file(dep_path) @@ -128,5 +143,5 @@ def record_target_state( input_hashes=input_hashes, prompt=resolved_prompt, model=model, - extra_params=extra_params, + extra_params=extra_params or {}, )