# Source code for pytyche.bcf.diagnostics.topology

"""Per-tree topology hashing and aggregate mobility metrics for hurdle BCF.

Provides a vmap-friendly per-tree topology hash and helpers for computing
aggregate metrics (distinct counts, transition rates, run lengths, internal-
node counts, proposal accept rates) over a retained MCMC trace. The
companion ``TopologyHistory`` dataclass defines the per-iter retention
contract that the joint hurdle fit (``fit_hurdle_bcf(pooling="joint")``) populates when
``GPUBCFConfig.retain_topology_history=True``.

This module is purely additive observability — it has no effect on the
sampler or on PRNG consumption when retention is off.
"""

from __future__ import annotations

import dataclasses

import jax
import jax.numpy as jnp
import numpy as np

# Two odd 32-bit constants used as bases for the two polynomial hash lanes
# in ``topology_hash_lanes``. P_LO = 2654435761 (0x9E3779B1) is Knuth's
# multiplicative (golden-ratio) hashing constant; P_HI = 1597334677
# (0x5F356495) is a second, different odd constant so the two lanes use
# different bases. Odd bases keep every power invertible mod 2^32.
#
# Caveat: both constants are congruent to 1 (mod 4), so every power of
# either is also 1 (mod 4) and the two lanes always agree in their lowest
# two bits. The combined uint64 key therefore carries at most ~62
# independent bits, and because each lane is linear mod 2^32, structured
# inputs can collide more often than an idealized random-hash estimate.
P_LO = jnp.uint32(2654435761)
P_HI = jnp.uint32(1597334677)


def topology_hash_lanes(
    var_tree: jax.Array,
    split_tree: jax.Array,
) -> tuple[jax.Array, jax.Array]:
    """Compute two uint32 polynomial-hash lanes for one tree's structure.

    The canonical input is ``(var_tree, split_tree)`` — a bartz-encoded
    heap. Leaf positions (where ``split_tree[k] == 0``) carry stale
    ``var_tree`` values; the hash MUST be invariant to those, so the first
    step zeroes ``var_tree`` wherever ``split_tree == 0``. Each position is
    then packed into one word ``var * 256 + split``, and two polynomial
    hashes ``sum_k word[k] * P**(k + 1)`` (mod 2^32) are taken with bases
    ``P_LO`` and ``P_HI``, yielding two uint32 lanes.

    JIT-friendly and vmap-friendly. Does NOT require ``jax_enable_x64``.
    The hash is taken over the last axis, so leading batch dimensions are
    also handled directly: a ``(num_trees, tree_size)`` input yields
    ``(num_trees,)`` lanes without ``vmap``, and a 1-D input yields scalars.

    Collision notes
    ---------------
    * The packing is injective only while ``0 <= split < 256`` and ``var``
      fits in 24 bits; otherwise distinct nodes alias (e.g. ``(v, s + 256)``
      and ``(v + 1, s)`` pack to the same word). Assumes splits fit in
      8 bits — TODO confirm against the configured max cutpoint count.
    * Both bases are congruent to 1 (mod 4), so each lane's low two bits
      equal ``sum(words) mod 4``; the two lanes are not independent and the
      combined key carries at most ~62 independent bits.
    * Each lane is linear mod 2^32, so specially structured differences
      between trees can collide more often than a random-hash estimate
      suggests. Treat the hash as a fingerprint for mobility diagnostics,
      not as a proof of structural identity.

    Parameters
    ----------
    var_tree : jax.Array
        ``(..., tree_size)`` integer array of split variables (bartz heap
        layout).
    split_tree : jax.Array
        ``(..., tree_size)`` integer array of split values; 0 marks a leaf.

    Returns
    -------
    tuple[jax.Array, jax.Array]
        ``(hash_lo, hash_hi)`` — uint32 arrays with the leading (batch)
        shape of the inputs; scalars for 1-D inputs.
    """
    # Zero out stale var entries at leaf positions so the hash depends only
    # on the canonical structural content (internal-node vars + splits).
    canonical_var = jnp.where(split_tree != 0, var_tree, 0)

    # Pack var and split into a single uint32 word per heap position.
    combined = (
        canonical_var.astype(jnp.uint32) * jnp.uint32(256)
        + split_tree.astype(jnp.uint32)
    )

    # Power table P**1 .. P**n over the last axis, built from the static
    # trace-time length so it adapts to any tree_size. dtype is pinned to
    # uint32 so the table stays uint32 regardless of integer-promotion
    # rules; uint32 products wrap mod 2^32, which is the intended
    # polynomial-hash arithmetic, not overflow.
    n = combined.shape[-1]
    powers_lo = jnp.cumprod(
        jnp.full((n,), P_LO, dtype=jnp.uint32), dtype=jnp.uint32
    )
    powers_hi = jnp.cumprod(
        jnp.full((n,), P_HI, dtype=jnp.uint32), dtype=jnp.uint32
    )

    # Reduce over the last axis only, so batched inputs give per-tree lanes.
    hash_lo = jnp.sum(combined * powers_lo, axis=-1, dtype=jnp.uint32)
    hash_hi = jnp.sum(combined * powers_hi, axis=-1, dtype=jnp.uint32)
    return hash_lo, hash_hi


[docs] def combine_lanes( hash_lo: np.uint32 | np.ndarray, hash_hi: np.uint32 | np.ndarray, ) -> np.uint64 | np.ndarray: """Combine two uint32 hash lanes into a single uint64 (host-side). Returns ``(np.uint64(hi) << 32) | np.uint64(lo)``. Works on scalar uint32 inputs (returns ``np.uint64``) and on uint32 arrays (returns a ``np.ndarray`` of uint64 with the same shape). Operates on NumPy host-side; NumPy supports uint64 freely, while the JAX runtime here runs with x64 disabled. Parameters ---------- hash_lo : np.uint32 or np.ndarray of uint32 Low 32-bit lane. hash_hi : np.uint32 or np.ndarray of uint32 High 32-bit lane. Returns ------- np.uint64 or np.ndarray of uint64 Combined 64-bit hash. """ lo = np.asarray(hash_lo, dtype=np.uint32).astype(np.uint64) hi = np.asarray(hash_hi, dtype=np.uint32).astype(np.uint64) return (hi << np.uint64(32)) | lo
@dataclasses.dataclass(frozen=True)
class TopologyHistory:
    """Retained per-iteration topology trace for a joint hurdle BCF fit.

    Filled in by ``fit_joint_hurdle_bcf`` when
    ``GPUBCFConfig.retain_topology_history=True``. The four array fields
    are indexed ``[chain, iter, tree]`` and share the shape
    ``(n_chains, n_iters_total, n_trees_total)``, with
    ``n_iters_total = num_burnin + num_mcmc`` and
    ``n_trees_total = num_trees_mu + num_trees_tau``. Along the trees axis
    the mu (prognostic) forest comes first, in columns
    ``[0, num_trees_mu)``, and the tau (treatment) forest fills the rest.

    Aggregate metrics only consider ``iter >= burnin_iters``, so burn-in
    transients do not leak into the summaries.

    Instances are immutable (``frozen=True``). The fields hold NumPy arrays,
    so the dataclass-generated ``__hash__`` raises ``TypeError`` (arrays are
    unhashable). Comparing instances whose arrays have more than one element
    may raise ``ValueError`` (ambiguous array truth value).

    Attributes
    ----------
    hashes : np.ndarray
        ``uint64`` combined 64-bit topology hash per (chain, iter, tree),
        shape ``(n_chains, n_iters_total, n_trees_total)``.
    move_type : np.ndarray
        ``uint8``, same shape as ``hashes``. ``0`` marks a grow proposal
        and ``1`` a prune proposal. Any other code is tolerated but skipped
        by the proposal-rate metrics.
    accepted : np.ndarray
        ``bool_``, same shape as ``hashes``: whether the proposed move was
        accepted at that (chain, iter, tree).
    internal_node_count : np.ndarray
        ``int32``, same shape as ``hashes``: the tree's internal-node count
        after that iteration's update (feeds the size summaries).
    burnin_iters : int
        How many leading iterations along ``axis=1`` aggregate metrics
        discard.
    num_trees_mu : int
        How many leading tree columns belong to the mu (prognostic)
        forest; the remaining columns are the tau (treatment) forest.
    """

    # uint64 topology hash, indexed [chain, iter, tree].
    hashes: np.ndarray
    # uint8 proposal kind: 0 = grow, 1 = prune, other codes ignored by rates.
    move_type: np.ndarray
    # bool acceptance flag for the proposed move.
    accepted: np.ndarray
    # int32 internal-node count after the update.
    internal_node_count: np.ndarray
    # Leading iterations excluded from aggregate metrics.
    burnin_iters: int
    # Leading tree columns that form the mu forest.
    num_trees_mu: int


def _topology_metrics_for_slice( prefix: str, hs: np.ndarray, ms: np.ndarray, acs: np.ndarray, ns: np.ndarray, ) -> dict[str, float | int]: """Compute aggregate metrics for one (chain, iter, tree) sub-slice. Internal helper for :func:`compute_topology_metrics`. The four input arrays share shape ``(C, T, K)`` where ``T`` is the post-burn-in iter count and ``K`` is the tree count for this slice (full / mu-only / tau-only). Returns a flat dict keyed by ``f"{prefix}{name}"``. """ C, T, K = hs.shape if T == 0 or C == 0 or K == 0: # Degenerate slice — no post-burn-in data to summarize. Emit # sensible defaults so downstream aggregations don't divide by zero. return { f"{prefix}n_distinct_per_chain_tree_median": 1.0, f"{prefix}transition_rate_mean": 0.0, f"{prefix}mean_run_length": 1.0, f"{prefix}internal_node_count_mean": 0.0, f"{prefix}grow_accept_rate": float("nan"), f"{prefix}prune_accept_rate": float("nan"), f"{prefix}grow_proposal_count": 0, f"{prefix}grow_accepted_count": 0, f"{prefix}prune_proposal_count": 0, f"{prefix}prune_accepted_count": 0, } # n_distinct_per_chain_tree_median: median across (chain, tree) of the # number of distinct uint64 hashes seen along the iter axis. Aggregation # is once per fit, not in a hot loop — a Python double-loop is fine. distinct_counts = np.empty((C, K), dtype=np.int64) for c in range(C): for k in range(K): distinct_counts[c, k] = np.unique(hs[c, :, k]).size n_distinct_median = float(np.median(distinct_counts)) # transition_rate_mean: per (chain, tree) fraction of adjacent iters # where the hash flips. With T iters there are T-1 adjacent pairs. 
if T >= 2: diffs = (hs[:, 1:, :] != hs[:, :-1, :]).astype(np.float64) per_pair_rate = diffs.mean(axis=1) # (C, K) transition_rate_mean = float(per_pair_rate.mean()) n_transitions = (hs[:, 1:, :] != hs[:, :-1, :]).sum(axis=1) # (C, K) else: transition_rate_mean = 0.0 n_transitions = np.zeros((C, K), dtype=np.int64) # mean_run_length: with n_transitions flips along the iter axis there # are n_transitions + 1 constant-hash runs whose total length is T, # so the mean run length per (chain, tree) is exactly T / (flips + 1). n_runs = n_transitions + 1 mean_run_length = float((T / n_runs).mean()) internal_node_count_mean = float(ns.mean()) # Proposal counts + accept rates split by move_type. Other move_type # values (e.g., 2 = "no proposal") are intentionally ignored — they # contribute to neither numerator nor denominator. grow_mask = ms == 0 prune_mask = ms == 1 grow_proposal_count = int(grow_mask.sum()) grow_accepted_count = int((grow_mask & acs).sum()) prune_proposal_count = int(prune_mask.sum()) prune_accepted_count = int((prune_mask & acs).sum()) grow_accept_rate = ( float(grow_accepted_count / grow_proposal_count) if grow_proposal_count > 0 else float("nan") ) prune_accept_rate = ( float(prune_accepted_count / prune_proposal_count) if prune_proposal_count > 0 else float("nan") ) return { f"{prefix}n_distinct_per_chain_tree_median": n_distinct_median, f"{prefix}transition_rate_mean": transition_rate_mean, f"{prefix}mean_run_length": mean_run_length, f"{prefix}internal_node_count_mean": internal_node_count_mean, f"{prefix}grow_accept_rate": grow_accept_rate, f"{prefix}prune_accept_rate": prune_accept_rate, f"{prefix}grow_proposal_count": grow_proposal_count, f"{prefix}grow_accepted_count": grow_accepted_count, f"{prefix}prune_proposal_count": prune_proposal_count, f"{prefix}prune_accepted_count": prune_accepted_count, }
def compute_topology_metrics(history: TopologyHistory) -> dict[str, float | int]:
    """Aggregate per-tree topology mobility / proposal-rate metrics.

    For each base metric, emits three variants:

    * pooled — ``topology.<name>`` (all trees),
    * mu-only — ``topology.mu.<name>`` (first ``num_trees_mu`` columns),
    * tau-only — ``topology.tau.<name>`` (remaining columns).

    Base metrics:

    * ``n_distinct_per_chain_tree_median`` — median across (chain, tree) of
      the distinct-hash count along the iter axis.
    * ``transition_rate_mean`` — mean fraction of adjacent iters whose hash
      flips, averaged across (chain, tree).
    * ``mean_run_length`` — mean length of constant-hash runs per
      (chain, tree); equals ``T / (n_transitions + 1)``.
    * ``internal_node_count_mean`` — mean internal-node count across all
      (chain, iter, tree) entries in the slice.
    * ``grow_accept_rate`` / ``prune_accept_rate`` — accepted / proposed.
      ``float('nan')`` when the slice contains no proposals of that type.
    * ``grow_proposal_count`` / ``grow_accepted_count`` /
      ``prune_proposal_count`` / ``prune_accepted_count`` — raw int counts.

    All metrics condition on ``iter >= history.burnin_iters`` to exclude
    burn-in transients. A slice left empty (e.g. ``burnin_iters`` covering
    every iter, or ``num_trees_mu`` equal to the total tree count for the
    tau slice) receives the per-slice defaults.

    Parameters
    ----------
    history : TopologyHistory
        Retained per-iter topology trace produced by
        ``fit_joint_hurdle_bcf``.

    Returns
    -------
    dict[str, float | int]
        Flat dict with the 30 keys described above (10 base metrics x 3
        slice prefixes), ordered pooled, mu, tau.

    Raises
    ------
    ValueError
        If ``history.burnin_iters`` or ``history.num_trees_mu`` is
        negative. NumPy would otherwise interpret a negative bound as
        "counted from the end of the axis" and silently return metrics
        for the wrong iterations or trees.
    """
    burnin = history.burnin_iters
    mu = history.num_trees_mu
    if burnin < 0:
        raise ValueError(f"burnin_iters must be non-negative, got {burnin}")
    if mu < 0:
        raise ValueError(f"num_trees_mu must be non-negative, got {mu}")

    # Strip burn-in along the iter axis (axis=1).
    h = history.hashes[:, burnin:, :]
    m = history.move_type[:, burnin:, :]
    a = history.accepted[:, burnin:, :]
    n = history.internal_node_count[:, burnin:, :]

    # (key prefix, tree-axis selection) for the pooled / mu / tau variants.
    tree_slices = (
        ("topology.", slice(None)),
        ("topology.mu.", slice(None, mu)),
        ("topology.tau.", slice(mu, None)),
    )

    out: dict[str, float | int] = {}
    for prefix, trees in tree_slices:
        out.update(
            _topology_metrics_for_slice(
                prefix,
                h[:, :, trees],
                m[:, :, trees],
                a[:, :, trees],
                n[:, :, trees],
            )
        )
    return out