"""Per-tree topology hashing and aggregate mobility metrics for hurdle BCF.
Provides a vmap-friendly per-tree topology hash and helpers for computing
aggregate metrics (distinct counts, transition rates, run lengths, internal-
node counts, proposal accept rates) over a retained MCMC trace. The
companion ``TopologyHistory`` dataclass defines the per-iter retention
contract that the joint hurdle fit (``fit_hurdle_bcf(pooling="joint")``) populates when
``GPUBCFConfig.retain_topology_history=True``.
This module is purely additive observability — it has no effect on the
sampler or on PRNG consumption when retention is off.
"""
from __future__ import annotations
import dataclasses
import jax
import jax.numpy as jnp
import numpy as np
# Multipliers ("bases") for the two polynomial-hash lanes. Both are odd
# 32-bit values held as JAX uint32 scalars; they are written in hex here with
# the decimal value alongside.
#   * P_LO -- Knuth's golden-ratio mixing constant.
#   * P_HI -- a second, different odd constant (picked as a prime according to
#     the original note) so that the two lanes are decorrelated.
# Taken together the two lanes are treated as ~64 bits of entropy, i.e. a
# collision probability of ~2^-64 per pair of distinct inputs.
P_LO = jnp.uint32(0x9E3779B1)  # 2654435761
P_HI = jnp.uint32(0x5F356495)  # 1597334677
[docs]
def topology_hash_lanes(
    var_tree: jax.Array,
    split_tree: jax.Array,
) -> tuple[jax.Array, jax.Array]:
    """Hash one tree's structure into a pair of uint32 polynomial lanes.

    The input ``(var_tree, split_tree)`` is a bartz-encoded heap. A position
    with ``split_tree[k] == 0`` is a leaf, and its ``var_tree`` entry may be
    stale, so ``var_tree`` is zeroed at those positions before hashing; the
    result is therefore invariant to whatever is stored there. Each position
    is packed into one word ``var * 256 + split`` and the word stream is fed
    to two polynomial hashes that differ only in their base (``P_LO`` and
    ``P_HI``).

    All arithmetic is done in uint32, so the function is usable under
    ``jit`` / ``vmap`` and does not need ``jax_enable_x64``. The two lanes
    together are treated as a ~64-bit hash: ~2^-64 collision probability per
    pair, and a birthday bound of roughly N^2 / 2^65 over N distinct trees.

    Parameters
    ----------
    var_tree : jax.Array
        ``(tree_size,)`` int32 array of split variables (bartz heap layout).
    split_tree : jax.Array
        ``(tree_size,)`` int32 array of split values; 0 marks a leaf.

    Returns
    -------
    tuple[jax.Array, jax.Array]
        ``(hash_lo, hash_hi)`` as scalar uint32 values (batched arrays when
        the function is vmapped).
    """
    u32 = jnp.uint32

    # Leaves contribute no variable: only internal-node vars and the split
    # values themselves are structural content.
    is_leaf = split_tree == 0
    structural_var = jnp.where(is_leaf, 0, var_tree)

    # One word per heap position: var in the upper part, split in the low
    # ~8 bits. The polynomial mixing below spreads it over the full lane.
    words = structural_var.astype(u32) * u32(256) + split_tree.astype(u32)

    # The table length follows the input, so the hash adapts to whatever
    # tree size the caller uses without any module-level hardcoding.
    length = words.shape[-1]

    def _lane(base):
        # weights[i] == base ** (i + 1), reduced mod 2^32. The uint32
        # wraparound is the intended modular arithmetic, not an overflow bug.
        weights = jnp.cumprod(jnp.full(length, base, dtype=u32))
        return jnp.sum(words * weights, dtype=u32)

    return _lane(P_LO), _lane(P_HI)
[docs]
def combine_lanes(
    hash_lo: np.uint32 | np.ndarray,
    hash_hi: np.uint32 | np.ndarray,
) -> np.uint64 | np.ndarray:
    """Fuse the two uint32 hash lanes into one uint64 value on the host.

    The result is ``(uint64(hash_hi) << 32) | uint64(hash_lo)``. Scalar
    inputs give a ``np.uint64`` scalar; array inputs give a uint64 array of
    the same shape. The work is done in NumPy because NumPy handles uint64
    natively, whereas the JAX runtime used here runs with x64 disabled.

    Parameters
    ----------
    hash_lo : np.uint32 or np.ndarray of uint32
        Low 32-bit lane.
    hash_hi : np.uint32 or np.ndarray of uint32
        High 32-bit lane.

    Returns
    -------
    np.uint64 or np.ndarray of uint64
        Combined 64-bit hash.
    """
    # Widen each lane to 64 bits first so the shift cannot drop the high lane.
    low_bits, high_bits = (
        np.asarray(lane, dtype=np.uint32).astype(np.uint64)
        for lane in (hash_lo, hash_hi)
    )
    shifted_high = high_bits << np.uint64(32)
    return shifted_high | low_bits
[docs]
@dataclasses.dataclass(frozen=True)
class TopologyHistory:
    """Retained per-iteration topology hashes and move metadata for each tree.

    Filled in by ``fit_joint_hurdle_bcf`` when
    ``GPUBCFConfig.retain_topology_history=True``.

    Every array field has shape ``(n_chains, n_iters_total, n_trees_total)``
    with ``n_iters_total = num_burnin + num_mcmc`` and
    ``n_trees_total = num_trees_mu + num_trees_tau``. Along the trees axis the
    first ``num_trees_mu`` columns are the prognostic (mu) forest and the rest
    are the treatment (tau) forest. Aggregate metrics only use iterations with
    ``iter >= burnin_iters`` so that burn-in transients are excluded.

    The instance is frozen (fields cannot be reassigned). Note that the
    dataclass-generated ``__eq__`` / ``__hash__`` act on the array fields
    directly, so instances are not meant to be compared with ``==`` or used as
    dict/set keys.

    Attributes
    ----------
    hashes : np.ndarray
        ``uint64``, shape ``(n_chains, n_iters_total, n_trees_total)``. The
        combined 64-bit topology hash for each (chain, iter, tree).
    move_type : np.ndarray
        ``uint8``, same shape as ``hashes``. ``0`` means a grow proposal and
        ``1`` a prune proposal; any other value is tolerated but ignored by
        the proposal-rate metrics.
    accepted : np.ndarray
        ``bool_``, same shape as ``hashes``. Whether the move proposed at this
        (chain, iter, tree) was accepted.
    internal_node_count : np.ndarray
        ``int32``, same shape as ``hashes``. Number of internal nodes of the
        tree after this iteration's update (used for size summaries).
    burnin_iters : int
        How many leading iterations along ``axis=1`` to drop before computing
        aggregate metrics.
    num_trees_mu : int
        How many leading columns of the trees axis belong to the mu
        (prognostic) forest; the remaining columns are the tau (treatment)
        forest.
    """

    hashes: np.ndarray
    move_type: np.ndarray
    accepted: np.ndarray
    internal_node_count: np.ndarray
    burnin_iters: int
    num_trees_mu: int
def _topology_metrics_for_slice(
    prefix: str,
    hs: np.ndarray,
    ms: np.ndarray,
    acs: np.ndarray,
    ns: np.ndarray,
) -> dict[str, float | int]:
    """Compute aggregate metrics for one (chain, iter, tree) sub-slice.

    Internal helper for :func:`compute_topology_metrics`. The four input
    arrays share shape ``(C, T, K)`` where ``T`` is the post-burn-in iter
    count and ``K`` is the tree count for this slice (full / mu-only /
    tau-only).

    Parameters
    ----------
    prefix : str
        String prepended to every metric name in the returned dict.
    hs : np.ndarray
        Topology hashes, shape ``(C, T, K)``.
    ms : np.ndarray
        Move types, same shape. ``0`` is counted as a grow proposal and ``1``
        as a prune proposal; every other value is ignored.
    acs : np.ndarray
        Acceptance flags, same shape. Interpreted by truthiness, so a
        non-bool array (e.g. 0/1 integers) is accepted as well.
    ns : np.ndarray
        Internal-node counts, same shape.

    Returns
    -------
    dict[str, float | int]
        Flat dict keyed by ``f"{prefix}{name}"`` with ten entries:
        ``n_distinct_per_chain_tree_median``, ``transition_rate_mean``,
        ``mean_run_length``, ``internal_node_count_mean``,
        ``grow_accept_rate``, ``prune_accept_rate`` (``nan`` when there are no
        proposals of that type) and the four raw ``*_count`` integers. When
        any of ``C``, ``T``, ``K`` is zero, fixed defaults are returned
        instead (1.0 distinct, 0.0 transition rate, 1.0 run length, 0.0 node
        count, ``nan`` rates, zero counts).
    """
    C, T, K = hs.shape
    if T == 0 or C == 0 or K == 0:
        # Degenerate slice -- no post-burn-in data to summarize. Emit fixed
        # defaults so downstream aggregations don't divide by zero.
        return {
            f"{prefix}n_distinct_per_chain_tree_median": 1.0,
            f"{prefix}transition_rate_mean": 0.0,
            f"{prefix}mean_run_length": 1.0,
            f"{prefix}internal_node_count_mean": 0.0,
            f"{prefix}grow_accept_rate": float("nan"),
            f"{prefix}prune_accept_rate": float("nan"),
            f"{prefix}grow_proposal_count": 0,
            f"{prefix}grow_accepted_count": 0,
            f"{prefix}prune_proposal_count": 0,
            f"{prefix}prune_accepted_count": 0,
        }

    # Distinct hashes per (chain, tree): after sorting along the iter axis,
    # each place where neighbours differ starts a new distinct value, so the
    # count is (number of such places) + 1. This replaces a Python loop of
    # C * K np.unique calls with one vectorized pass.
    ordered = np.sort(hs, axis=1)
    distinct_counts = 1 + (ordered[:, 1:, :] != ordered[:, :-1, :]).sum(axis=1)
    n_distinct_median = float(np.median(distinct_counts))

    # Hash flips between adjacent iters, shape (C, T-1, K); computed once and
    # reused for both the transition rate and the run-length metric. With
    # T == 1 the array is empty along axis 1 and the flip counts are all zero.
    flips = hs[:, 1:, :] != hs[:, :-1, :]
    n_transitions = flips.sum(axis=1)  # (C, K)
    # Per-(chain, tree) fraction of the T-1 adjacent pairs that flipped,
    # then averaged; defined as 0.0 when there are no adjacent pairs.
    transition_rate_mean = float(flips.mean(axis=1).mean()) if T >= 2 else 0.0

    # n_transitions flips split the T iters into n_transitions + 1
    # constant-hash runs, so the mean run length is T / (flips + 1).
    mean_run_length = float((T / (n_transitions + 1)).mean())

    internal_node_count_mean = float(ns.mean())

    # Coerce to bool so that `mask & accepted` is a logical AND even if the
    # caller hands over an integer or float array.
    accepted = np.asarray(acs, dtype=bool)

    # Other move_type values (e.g. 2 = "no proposal") are intentionally
    # ignored: they enter neither numerator nor denominator.
    grow_mask = ms == 0
    prune_mask = ms == 1
    grow_proposal_count = int(grow_mask.sum())
    grow_accepted_count = int((grow_mask & accepted).sum())
    prune_proposal_count = int(prune_mask.sum())
    prune_accepted_count = int((prune_mask & accepted).sum())

    grow_accept_rate = (
        float(grow_accepted_count / grow_proposal_count)
        if grow_proposal_count > 0
        else float("nan")
    )
    prune_accept_rate = (
        float(prune_accepted_count / prune_proposal_count)
        if prune_proposal_count > 0
        else float("nan")
    )

    return {
        f"{prefix}n_distinct_per_chain_tree_median": n_distinct_median,
        f"{prefix}transition_rate_mean": transition_rate_mean,
        f"{prefix}mean_run_length": mean_run_length,
        f"{prefix}internal_node_count_mean": internal_node_count_mean,
        f"{prefix}grow_accept_rate": grow_accept_rate,
        f"{prefix}prune_accept_rate": prune_accept_rate,
        f"{prefix}grow_proposal_count": grow_proposal_count,
        f"{prefix}grow_accepted_count": grow_accepted_count,
        f"{prefix}prune_proposal_count": prune_proposal_count,
        f"{prefix}prune_accepted_count": prune_accepted_count,
    }
[docs]
def compute_topology_metrics(history: TopologyHistory) -> dict[str, float | int]:
    """Summarize topology mobility and proposal acceptance for a retained trace.

    Every base metric is reported three times, once per tree subset:

    * ``topology.<name>`` -- all trees pooled,
    * ``topology.mu.<name>`` -- the first ``num_trees_mu`` columns,
    * ``topology.tau.<name>`` -- the remaining columns.

    Base metrics:

    * ``n_distinct_per_chain_tree_median`` -- median over (chain, tree) of the
      number of distinct hashes seen along the iter axis.
    * ``transition_rate_mean`` -- fraction of adjacent iters whose hash
      differs, averaged over (chain, tree).
    * ``mean_run_length`` -- mean length of constant-hash runs per
      (chain, tree), i.e. ``T / (n_transitions + 1)``.
    * ``internal_node_count_mean`` -- mean internal-node count over all
      (chain, iter, tree) entries of the subset.
    * ``grow_accept_rate`` / ``prune_accept_rate`` -- accepted / proposed;
      ``float('nan')`` if the subset has no proposals of that type.
    * ``grow_proposal_count`` / ``grow_accepted_count`` /
      ``prune_proposal_count`` / ``prune_accepted_count`` -- raw int counts.

    Only iterations with ``iter >= history.burnin_iters`` are used, so
    burn-in transients do not enter any metric.

    Parameters
    ----------
    history : TopologyHistory
        Retained per-iter topology trace produced by ``fit_joint_hurdle_bcf``.

    Returns
    -------
    dict[str, float | int]
        Flat dict with the 30 keys described above (10 base metrics x 3
        slice prefixes).
    """
    burnin = history.burnin_iters
    n_mu = history.num_trees_mu

    # Drop the burn-in iterations (axis=1) from each retained array. The
    # order matches the helper's positional arguments:
    # hashes, move types, accepted flags, internal-node counts.
    post_burnin = tuple(
        arr[:, burnin:, :]
        for arr in (
            history.hashes,
            history.move_type,
            history.accepted,
            history.internal_node_count,
        )
    )

    # Output prefix paired with the window it selects on the trees axis.
    tree_windows = (
        ("topology.", slice(None)),
        ("topology.mu.", slice(None, n_mu)),
        ("topology.tau.", slice(n_mu, None)),
    )

    metrics: dict[str, float | int] = {}
    for prefix, window in tree_windows:
        subset = [arr[:, :, window] for arr in post_burnin]
        metrics.update(_topology_metrics_for_slice(prefix, *subset))
    return metrics