pytyche.bcf.diagnostics.topology

Per-tree topology hashing and aggregate mobility metrics for hurdle BCF.

Provides a vmap-friendly per-tree topology hash and helpers for computing aggregate metrics (distinct counts, transition rates, run lengths, internal-node counts, proposal accept rates) over a retained MCMC trace. The companion TopologyHistory dataclass defines the per-iter retention contract that the joint hurdle fit (fit_hurdle_bcf(pooling="joint")) populates when GPUBCFConfig.retain_topology_history=True.

This module is purely additive observability — it has no effect on the sampler or on PRNG consumption when retention is off.

Functions

combine_lanes(hash_lo, hash_hi)

Combine two uint32 hash lanes into a single uint64 (host-side).

compute_topology_metrics(history)

Aggregate per-tree topology mobility / proposal-rate metrics.

topology_hash_lanes(var_tree, split_tree)

Compute two uint32 polynomial-hash lanes for one tree's structure.

Classes

TopologyHistory(hashes, move_type, accepted, ...)

Per-iter retention of per-tree topology hashes + move metadata.

pytyche.bcf.diagnostics.topology.topology_hash_lanes(var_tree, split_tree)

Compute two uint32 polynomial-hash lanes for one tree’s structure.

The canonical input is (var_tree, split_tree) — a bartz-encoded heap. Leaf positions (where split_tree[k] == 0) carry stale var_tree values; the hash MUST be invariant to those, so the first step zeroes var_tree wherever split_tree == 0. Two independent polynomial hashes are then run over the canonical packed (var * 256 + split) word stream, yielding two uint32 lanes (combined ~64 bits of entropy).

JIT-friendly and vmap-friendly. Does NOT require jax_enable_x64.

The per-pair collision probability is ~2^-64, treating the two lanes as independent uniform 32-bit hashes; the birthday bound across N distinct trees is ~N^2 / 2^65 (negligible at our scales).
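To make "negligible" concrete, the bound can be evaluated directly. This is a small arithmetic sketch; the helper name birthday_collision_bound is hypothetical and not part of the module.

```python
def birthday_collision_bound(n_distinct: int) -> float:
    """Approximate chance of any collision among n_distinct 64-bit hashes.

    Uses the N^2 / 2^65 approximation quoted in the docstring above.
    """
    return n_distinct**2 / 2**65


# One million distinct topologies still gives a bound on the order of 1e-8.
bound_million = birthday_collision_bound(10**6)
print(f"{bound_million:.2e}")
```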

Parameters:
  • var_tree (Array) – (tree_size,) int32 array of split variables (bartz heap layout).

  • split_tree (Array) – (tree_size,) int32 array of split values; 0 marks a leaf.

Returns:

(hash_lo, hash_hi) — pair of scalar uint32 (vmap turns them into batched arrays).

Return type:

tuple[Array, Array]
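The canonicalise-then-hash procedure described above can be sketched in plain NumPy. This is an illustration of the idea only, not the library's implementation: the function names, multipliers, and seeds are hypothetical, and the real routine runs under JAX rather than in a Python loop.

```python
import numpy as np

MASK32 = 0xFFFFFFFF  # emulate uint32 wraparound with Python ints


def poly_hash_u32(words, multiplier, seed):
    """Polynomial hash h = h * multiplier + word (mod 2^32)."""
    h = seed
    for w in words:
        h = (h * multiplier + int(w)) & MASK32
    return np.uint32(h)


def topology_hash_lanes_sketch(var_tree, split_tree):
    var_tree = np.asarray(var_tree)
    split_tree = np.asarray(split_tree)
    # Step 1: zero stale split variables at leaf positions (split == 0).
    canon_var = np.where(split_tree == 0, 0, var_tree)
    # Step 2: pack each node into one word, var * 256 + split.
    words = canon_var.astype(np.int64) * 256 + split_tree.astype(np.int64)
    # Step 3: two polynomial hashes with different (hypothetical) constants.
    lo = poly_hash_u32(words, 0x01000193, 0x811C9DC5)
    hi = poly_hash_u32(words, 0x9E3779B1, 0x85EBCA6B)
    return lo, hi


split = [5, 2, 0, 0, 0, 0, 0]
lanes_a = topology_hash_lanes_sketch([3, 1, 7, 9, 0, 0, 0], split)
# Same structure, different stale values at the leaf positions.
lanes_b = topology_hash_lanes_sketch([3, 1, 99, 42, 8, 8, 8], split)
# One internal split value changed.
lanes_c = topology_hash_lanes_sketch([3, 1, 7, 9, 0, 0, 0], [5, 3, 0, 0, 0, 0, 0])
```

In this sketch lanes_a and lanes_b agree because the stale leaf values are zeroed before hashing, while lanes_c differs because an internal split changed.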

pytyche.bcf.diagnostics.topology.combine_lanes(hash_lo, hash_hi)

Combine two uint32 hash lanes into a single uint64 (host-side).

Returns (np.uint64(hi) << 32) | np.uint64(lo). Works on scalar uint32 inputs (returns np.uint64) and on uint32 arrays (returns an np.ndarray of uint64 with the same shape).

Runs host-side in NumPy, which supports uint64 natively, because the JAX runtime here runs with x64 disabled.

Parameters:
  • hash_lo (uint32 | ndarray) – Low 32-bit lane.

  • hash_hi (uint32 | ndarray) – High 32-bit lane.

Returns:

Combined 64-bit hash.

Return type:

uint64 | ndarray
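A minimal NumPy sketch of the documented formula, (hi << 32) | lo, is shown below. The name combine_lanes_sketch is hypothetical; the actual function may validate or convert its inputs differently.

```python
import numpy as np


def combine_lanes_sketch(hash_lo, hash_hi):
    # Widen both lanes to uint64 before shifting so no bits are lost.
    lo = np.asarray(hash_lo, dtype=np.uint32).astype(np.uint64)
    hi = np.asarray(hash_hi, dtype=np.uint32).astype(np.uint64)
    return (hi << np.uint64(32)) | lo


# Scalar lanes combine into a single 64-bit value.
scalar = combine_lanes_sketch(np.uint32(0xDEADBEEF), np.uint32(0x12345678))

# Array lanes combine elementwise and keep their shape.
batched = combine_lanes_sketch(
    np.array([1, 2], dtype=np.uint32),
    np.array([0, 1], dtype=np.uint32),
)
```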

class pytyche.bcf.diagnostics.topology.TopologyHistory(hashes, move_type, accepted, internal_node_count, burnin_iters, num_trees_mu)

Bases: object

Per-iter retention of per-tree topology hashes + move metadata.

Populated by fit_joint_hurdle_bcf when GPUBCFConfig.retain_topology_history=True. All array fields share the leading shape (n_chains, n_iters_total, n_trees_total), where n_iters_total = num_burnin + num_mcmc and n_trees_total = num_trees_mu + num_trees_tau. The first num_trees_mu columns of the trees axis are the prognostic (mu) forest; the remaining columns are the treatment (tau) forest.

Aggregate metrics condition on iter >= burnin_iters to exclude transient burn-in behaviour.

hashes

uint64, shape (n_chains, n_iters_total, n_trees_total). Combined 64-bit topology hash per (chain, iter, tree).

Type:

np.ndarray

move_type

uint8, same shape as hashes. 0 = grow proposal, 1 = prune proposal; other values are tolerated but ignored by the proposal-rate metrics.

Type:

np.ndarray

accepted

bool_, same shape as hashes. Whether the proposed move was accepted at this (chain, iter, tree).

Type:

np.ndarray

internal_node_count

int32, same shape as hashes. Internal-node count of the tree after this iter’s update (used for size summaries).

Type:

np.ndarray

burnin_iters

Number of leading iters along axis=1 to discard when computing aggregate metrics.

Type:

int

num_trees_mu

Number of leading columns along the trees axis that belong to the mu (prognostic) forest. Remaining columns are the tau (treatment) forest.

Type:

int

Parameters:
  • hashes (ndarray)

  • move_type (ndarray)

  • accepted (ndarray)

  • internal_node_count (ndarray)

  • burnin_iters (int)

  • num_trees_mu (int)
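The shape contract and the mu/tau column split can be illustrated with a stand-in dataclass. TopologyHistorySketch is a hypothetical mirror of the documented fields, filled with placeholder zeros rather than a real trace.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class TopologyHistorySketch:
    """Hypothetical stand-in with the same fields as TopologyHistory."""

    hashes: np.ndarray
    move_type: np.ndarray
    accepted: np.ndarray
    internal_node_count: np.ndarray
    burnin_iters: int
    num_trees_mu: int


n_chains, num_burnin, num_mcmc = 2, 3, 5
num_trees_mu, num_trees_tau = 4, 2
shape = (n_chains, num_burnin + num_mcmc, num_trees_mu + num_trees_tau)

history = TopologyHistorySketch(
    hashes=np.zeros(shape, dtype=np.uint64),
    move_type=np.zeros(shape, dtype=np.uint8),
    accepted=np.zeros(shape, dtype=np.bool_),
    internal_node_count=np.zeros(shape, dtype=np.int32),
    burnin_iters=num_burnin,
    num_trees_mu=num_trees_mu,
)

# Drop burn-in along axis 1, then split the trees axis into the two forests.
post = history.hashes[:, history.burnin_iters:, :]
mu_hashes = post[:, :, : history.num_trees_mu]
tau_hashes = post[:, :, history.num_trees_mu:]
```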

pytyche.bcf.diagnostics.topology.compute_topology_metrics(history)

Aggregate per-tree topology mobility / proposal-rate metrics.

For each base metric, emits three variants:

  • pooled — topology.<name> (all trees),

  • mu-only — topology.mu.<name> (first num_trees_mu columns),

  • tau-only — topology.tau.<name> (remaining columns).

Base metrics:

  • n_distinct_per_chain_tree_median — median across (chain, tree) of the distinct-hash count along the iter axis.

  • transition_rate_mean — mean fraction of adjacent iters whose hash flips, averaged across (chain, tree).

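  • mean_run_length — mean length of constant-hash runs per (chain, tree); equals T / (n_transitions + 1), where T is the number of post-burn-in iters and n_transitions is the number of adjacent-iter hash flips.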

  • internal_node_count_mean — mean internal-node count across all (chain, iter, tree) entries in the slice.

  • grow_accept_rate / prune_accept_rate — accepted / proposed. float('nan') when the slice contains no proposals of that type.

  • grow_proposal_count / grow_accepted_count / prune_proposal_count / prune_accepted_count — raw int counts.
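The three mobility metrics can be sketched for a single (chain, tree) hash trace as follows. This is an illustrative reading of the definitions above, not the library's code: mobility_sketch is a hypothetical helper, it assumes the trace has at least two iters, and it takes the transition-rate denominator to be the T - 1 adjacent pairs.

```python
import numpy as np


def mobility_sketch(trace):
    """Distinct count, transition rate, and mean run length for one trace."""
    trace = np.asarray(trace)
    n_iters = trace.shape[0]
    # A transition is any adjacent pair of iters whose hash differs.
    n_transitions = int((trace[1:] != trace[:-1]).sum())
    return {
        "n_distinct": int(np.unique(trace).size),
        "transition_rate": n_transitions / (n_iters - 1),
        "mean_run_length": n_iters / (n_transitions + 1),
    }


# Runs are [7, 7], [9, 9, 9], [7]: two transitions, lengths 2, 3, 1.
metrics = mobility_sketch(np.array([7, 7, 9, 9, 9, 7], dtype=np.uint64))
```

The aggregate metrics would then take the median or mean of these per-trace values across all (chain, tree) pairs in the slice.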

All metrics condition on iter >= history.burnin_iters to exclude burn-in transients.
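A sketch of the burn-in conditioning applied to the proposal metrics, assuming the documented encoding (0 = grow, 1 = prune) and the (n_chains, n_iters_total, n_trees_total) layout. The helper accept_rates_sketch is hypothetical.

```python
import numpy as np


def accept_rates_sketch(move_type, accepted, burnin_iters):
    # Keep only iters at or after burnin_iters along axis 1.
    moves = np.asarray(move_type)[:, burnin_iters:, :]
    acc = np.asarray(accepted, dtype=bool)[:, burnin_iters:, :]
    out = {}
    for name, code in (("grow", 0), ("prune", 1)):
        proposed = moves == code
        n_proposed = int(proposed.sum())
        n_accepted = int((proposed & acc).sum())
        out[f"{name}_proposal_count"] = n_proposed
        out[f"{name}_accepted_count"] = n_accepted
        # NaN when the slice holds no proposals of this type.
        out[f"{name}_accept_rate"] = (
            n_accepted / n_proposed if n_proposed else float("nan")
        )
    return out


# Shape (1 chain, 4 iters, 2 trees); the first iter is burn-in.
move_type = np.array([[[0, 1], [0, 0], [0, 2], [0, 0]]], dtype=np.uint8)
accepted = np.array(
    [[[True, True], [True, False], [False, True], [True, True]]]
)
rates = accept_rates_sketch(move_type, accepted, burnin_iters=1)
```

Here the only prune proposal falls inside burn-in, so the prune rate is NaN, and the move with code 2 is ignored by both rates.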

Parameters:

history (TopologyHistory) – Retained per-iter topology trace produced by fit_joint_hurdle_bcf.

Returns:

Flat dict with the 30 keys described above (10 base metrics x 3 slice prefixes).

Return type:

dict[str, float | int]
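The key layout of the returned dict can be enumerated from the prefixes and base-metric names listed above. This sketch only builds the key names and assumes they are joined with a dot, as the topology.<name> pattern suggests.

```python
PREFIXES = ("topology", "topology.mu", "topology.tau")

BASE_METRICS = (
    "n_distinct_per_chain_tree_median",
    "transition_rate_mean",
    "mean_run_length",
    "internal_node_count_mean",
    "grow_accept_rate",
    "prune_accept_rate",
    "grow_proposal_count",
    "grow_accepted_count",
    "prune_proposal_count",
    "prune_accepted_count",
)

# Every base metric appears once per slice prefix.
expected_keys = [f"{prefix}.{name}" for prefix in PREFIXES for name in BASE_METRICS]
```

A list like this can be handy for checking that a metrics dict is complete before logging it.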