# Source code for pytyche.experiments.manifest

"""Versioned experiment manifest builder, validator, and atomic writer.

The canonical JSON Schema document lives at
``docs/specs/experiment-manifest-schema.json``. Validation is hand-rolled
rather than via ``jsonschema`` so error messages can name the offending
field directly, and so this module imports cleanly without heavy ML deps.
"""

from __future__ import annotations

import json
import os
import platform
import subprocess
import sys
import tempfile
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

#: Version number stamped into every manifest by ``build_manifest``. Bump it
#: for breaking changes and keep it aligned with the ``$id`` of
#: ``docs/specs/experiment-manifest-schema.json`` ($id v1 → version 1).
MANIFEST_SCHEMA_VERSION: int = 1

#: Top-level fields that ``validate_manifest`` requires in every manifest.
#: Listed alphabetically for readability; a frozenset has no order.
_REQUIRED_TOP_LEVEL: frozenset[str] = frozenset(
    (
        "data_provenance",
        "env",
        "experiment_id",
        "git",
        "manifest_schema_version",
        "params",
        "timestamp_utc",
    )
)

#: The only top-level key, besides the required ones, a manifest may carry.
#: Per-capability extension content nests beneath it (for example
#: ``manifest["pytyche"]["calibration"]``); ``validate_manifest`` rejects any
#: other extra top-level key.
_RESERVED_EXTENSION_KEY: str = "pytyche"


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def build_manifest(
    *,
    experiment_id: str,
    params: dict[str, Any],
    pytyche_extensions: dict[str, Any],
    data_provenance: dict[str, Any],
) -> dict[str, Any]:
    """Assemble a manifest dict with every required top-level field filled in.

    Three fields come from the ambient environment at call time rather than
    from the arguments:

    * ``timestamp_utc`` — ISO 8601 string built from ``datetime.now(UTC)``;
    * ``git`` — ``{sha, dirty, branch}`` obtained by running ``git`` in
      ``Path.cwd()``;
    * ``env`` — ``{"python": sys.version, "platform": platform.platform()}``.

    ``pytyche_extensions`` maps capability name → content and is stored under
    the reserved top-level ``pytyche`` key. For example,
    ``pytyche_extensions={"calibration": {...}}`` ends up at
    ``manifest["pytyche"]["calibration"]``. Passing ``{}`` still yields an
    empty ``manifest["pytyche"]`` object, so consumers can rely on the key
    existing.

    ``params``, ``data_provenance`` and ``pytyche_extensions`` are
    shallow-copied with ``dict(...)``. Later changes to the top-level keys of
    the caller's dicts do not reach the manifest; nested objects are shared.

    No validation happens here. Call :func:`validate_manifest` before writing
    if you need that guarantee.

    Parameters
    ----------
    experiment_id:
        Unique identifier, conventionally ``{iso8601}_{short_sha}``.
    params:
        Free-form per-experiment hyperparameters.
    pytyche_extensions:
        Mapping from capability name (e.g. ``"calibration"``) to its nested
        content object; placed under the reserved ``pytyche`` key.
    data_provenance:
        Discriminated-union dict: either ``{"kind": "synthetic", "seed": int}``
        or ``{"kind": "external", "hashes": {name: sha256_hex}}``.

    Returns
    -------
    dict
        The new manifest. Persisting it is the caller's responsibility.

    Raises
    ------
    subprocess.CalledProcessError
        If any ``git`` command fails (e.g. the working directory is not inside
        a git work tree), because git is invoked with ``check=True``.
    """
    # Resolve environment-derived fields in the same order as the manifest
    # layout: timestamp first, then git, then interpreter/platform info.
    stamp = _now_iso_utc()
    git_info = _resolve_git_state()
    runtime_env = {"python": sys.version, "platform": platform.platform()}

    return {
        "manifest_schema_version": MANIFEST_SCHEMA_VERSION,
        "experiment_id": experiment_id,
        "timestamp_utc": stamp,
        "git": git_info,
        "env": runtime_env,
        "params": dict(params),
        "data_provenance": dict(data_provenance),
        _RESERVED_EXTENSION_KEY: dict(pytyche_extensions),
    }


def validate_manifest(manifest: object) -> None:
    """Validate a manifest; raise ``ValueError`` on any violation.

    Accepts ``object`` rather than ``dict`` because callers feed this from
    ``json.load()`` of untrusted files. The runtime value may be a list,
    ``None``, a string, etc. The isinstance narrowing below is the boundary
    check; downstream code only runs once we know we have a dict.

    Checks performed, in order. Each failure raises with an error message that
    names the offending field(s), with no silent catch-all messages:

    1. The top-level value is a dict.
    2. The required top-level fields are present (missing names are listed in
       the error).
    3. There are no foreign top-level keys, i.e. anything neither required nor
       ``pytyche``. Non-string keys can occur when the dict was built in
       Python rather than loaded from JSON. They are reported by ``repr``
       instead of breaking the error formatting.
    4. ``data_provenance`` is a dict with a valid ``kind`` discriminator:
       ``kind == "synthetic"`` needs an integer (non-bool) ``seed``, and
       ``kind == "external"`` needs a non-empty ``hashes`` dict.

    Only the structure above is checked. The values of the other required
    fields (``git``, ``env``, ``params``, ...) are not type-checked here.

    Passes silently when valid.

    Raises
    ------
    ValueError
        On the first violation found.
    """
    # (1) Boundary check for untrusted input.
    if not isinstance(manifest, dict):
        raise ValueError(
            f"manifest must be a dict, got {type(manifest).__name__}"
        )

    # (2) Missing required fields. Sorted for deterministic error messages.
    missing = sorted(_REQUIRED_TOP_LEVEL - manifest.keys())
    if missing:
        raise ValueError(
            "manifest missing required fields: " + ", ".join(missing)
        )

    # (3) Foreign top-level keys (not required, not the reserved extension
    # key). Keys are rendered to strings *before* sorting and joining.
    # Otherwise a mix of str and non-str keys would make ``sorted`` raise
    # TypeError, and ``", ".join`` rejects non-str items outright. String
    # keys render as themselves, so their message is unchanged.
    allowed = _REQUIRED_TOP_LEVEL | {_RESERVED_EXTENSION_KEY}
    foreign = sorted(
        key if isinstance(key, str) else repr(key)
        for key in manifest
        if key not in allowed
    )
    if foreign:
        raise ValueError(
            "manifest has unexpected top-level keys (foreign keys not "
            "permitted; extension content must live under the reserved "
            "'pytyche' key): " + ", ".join(foreign)
        )

    # (4) data_provenance discriminator.
    _validate_data_provenance(manifest["data_provenance"])


[docs] def write_manifest(manifest: dict[str, Any], path: str | Path) -> None: """Atomically write ``manifest`` as JSON to ``path``. Uses a temporary file in the same directory followed by ``os.replace`` so partial writes are never visible at ``path`` (POSIX-atomic rename). The temporary file is cleaned up on failure. Parameters ---------- manifest: The manifest dict. Caller is responsible for validating it first if desired — this function only persists. path: Destination path. The parent directory must already exist. """ final_path = Path(path) parent = final_path.parent # NamedTemporaryFile gives us cleanup-on-error semantics; delete=False # because we hand the file to os.replace after closing it. fd, tmp_name = tempfile.mkstemp( prefix=final_path.name + ".", suffix=".tmp", dir=str(parent), ) tmp_path = Path(tmp_name) try: with os.fdopen(fd, "w", encoding="utf-8") as handle: json.dump(manifest, handle, indent=2, sort_keys=True) handle.write("\n") handle.flush() os.fsync(handle.fileno()) os.replace(tmp_path, final_path) except BaseException: # Best-effort cleanup; suppress secondary errors so the original # exception surfaces. try: tmp_path.unlink() except FileNotFoundError: pass raise
# --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- def _now_iso_utc() -> str: """Return the current UTC time as an ISO 8601 string.""" return datetime.now(UTC).isoformat() def _resolve_git_state() -> dict[str, Any]: """Resolve ``{sha, dirty, branch}`` from git at ``Path.cwd()``. Uses ``subprocess.run`` with ``check=True`` — failures surface loudly rather than producing a partially-populated git object. The caller is expected to invoke this from inside a git working tree. """ cwd = Path.cwd() def _git(*args: str) -> str: return subprocess.run( ["git", *args], cwd=cwd, check=True, capture_output=True, text=True, ).stdout.strip() sha = _git("rev-parse", "HEAD") branch = _git("rev-parse", "--abbrev-ref", "HEAD") porcelain = _git("status", "--porcelain") return {"sha": sha, "dirty": bool(porcelain), "branch": branch} def _validate_data_provenance(provenance: Any) -> None: """Validate the ``data_provenance`` discriminated union. Raises ``ValueError`` whose message names ``data_provenance`` so the spec's "validator names the malformed field" contract is satisfied. 
""" if not isinstance(provenance, dict): raise ValueError( "data_provenance must be a dict with a 'kind' discriminator; " f"got {type(provenance).__name__}" ) kind = provenance.get("kind") if kind is None: raise ValueError( "data_provenance missing required 'kind' discriminator " "(expected one of: 'synthetic', 'external')" ) if kind == "synthetic": if "seed" not in provenance: raise ValueError( "data_provenance kind='synthetic' requires a 'seed' field" ) if not isinstance(provenance["seed"], int) or isinstance( provenance["seed"], bool ): raise ValueError( "data_provenance.seed must be an int for kind='synthetic'" ) elif kind == "external": hashes = provenance.get("hashes") if not isinstance(hashes, dict): raise ValueError( "data_provenance kind='external' requires a 'hashes' dict" ) if not hashes: raise ValueError( "data_provenance kind='external' requires a non-empty " "'hashes' map (at least one artifact hash)" ) else: raise ValueError( f"data_provenance has unknown 'kind' value {kind!r}; " "expected 'synthetic' or 'external'" )