# Source code for pytyche.bcf.hurdle.proposals

"""Hurdle BCF grow/prune proposals and forest initialisation.

These are internal helpers used by the joint hurdle MCMC loop.
"""

from __future__ import annotations

from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
from bartz.grove import (
    is_actual_leaf,
    is_leaves_parent,
)
from bartz.grove import (
    traverse_forest as _traverse_forest,
)
from bartz.mcmcstep import OutcomeType, init, make_p_nonterminal
from jax import random


class HurdleMove(NamedTuple):
    """Per-tree GROW/PRUNE move metadata for the parallel-LML fast path.

    Mirrors bartz's ``Moves`` Module but as a JAX-pytree NamedTuple so it
    threads through ``vmap`` and ``lax.scan`` carries naturally.

    Carries the proposal direction (``grow``), the heap indices of the node
    and its two children, the structural + proposal log-ratio for
    Metropolis-Hastings, a gating bit (``allowed``), and the log of the
    uniform draw used for the accept comparison (``log_u``).
    """

    # True when the proposed move is a GROW, False when it is a PRUNE.
    grow: jax.Array
    # Heap index of the node being split (GROW) or collapsed (PRUNE).
    node: jax.Array
    # Heap index of the node's left child, ``2 * node``.
    left: jax.Array
    # Heap index of the node's right child, ``2 * node + 1``.
    right: jax.Array
    # Structural-prior plus proposal log-ratio of the move.
    log_ratio_struct: jax.Array
    # False when the move must be rejected regardless of the likelihood.
    allowed: jax.Array
    # Log of a uniform [0, 1) draw; may be -inf when the draw is exactly 0.
    log_u: jax.Array
def _heap_depth(node_idx: jax.Array) -> jax.Array:
    """Return the depth of a heap-indexed node (index 1 is the root, depth 0).

    The depth is ``bit_length(node_idx) - 1``, evaluated in integer
    arithmetic. A floating-point ``floor(log2(idx))`` is avoided on purpose:
    ``jnp.log2`` is evaluated as ``log(x) / log(2)`` in float32, which can
    land just below the exact integer for power-of-two indices and then
    floor to a depth that is one too small.

    Indices below 1 are clamped to 1 and therefore map to depth 0.
    """
    idx = jnp.maximum(node_idx, 1).astype(jnp.int32)
    # For a positive int32, bit_length == 32 - clz, hence depth == 31 - clz.
    return jnp.int32(31) - jax.lax.clz(idx)


def _pick_uniform(key: jax.Array, mask: jax.Array) -> jax.Array:
    """Draw one index uniformly among the ``True`` entries of ``mask``.

    When ``mask`` has no ``True`` entry the weights are all zero and the
    returned index is meaningless; callers gate on the candidate count.
    """
    weights = jnp.where(mask, 1.0, 0.0)
    # Guard the normalisation so an empty mask does not divide by zero.
    weights = weights / jnp.maximum(weights.sum(), 1.0)
    return random.choice(key, mask.shape[0], p=weights)


def _depth_prior_probs(
    node_idx: jax.Array,
    p_nonterminal: jax.Array,
) -> tuple[jax.Array, jax.Array]:
    """Return ``(P(split at node depth), P(split at child depth))``.

    The node-depth lookup is clamped to the last entry of ``p_nonterminal``.
    The child-depth probability is 0 when the child depth lies beyond the
    table, i.e. children at that depth can never split.
    """
    table_len = p_nonterminal.shape[0]
    depth = _heap_depth(node_idx)
    p_here = p_nonterminal[jnp.minimum(depth, table_len - 1)]
    child_depth = depth + 1
    p_child = jnp.where(
        child_depth < table_len,
        p_nonterminal[jnp.minimum(child_depth, table_len - 1)],
        0.0,
    )
    return p_here, p_child


def _log_proposal_ratio(n_reverse: jax.Array, n_forward: jax.Array) -> jax.Array:
    """Return ``log(1 / n_reverse) - log(1 / n_forward)``.

    ``n_forward`` is the number of candidates the forward move chose from and
    ``n_reverse`` the number the reverse move would choose from. Both counts
    are clamped to at least 1 so the logarithms stay finite.
    """
    reverse = jnp.maximum(n_reverse, 1).astype(jnp.float32)
    forward = jnp.maximum(n_forward, 1).astype(jnp.float32)
    return jnp.log(1.0 / reverse) - jnp.log(1.0 / forward)


def _grow_with_node(
    key: jax.Array,
    var_tree: jax.Array,
    split_tree: jax.Array,
    max_split: jax.Array,
    p_nonterminal: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    """Build a GROW proposal and also report which leaf was chosen.

    Shared by ``_hurdle_propose_grow`` and ``_propose_single_tree_with_move``
    so both consume ``key`` identically and cannot drift apart.

    Returns
    -------
    (new_var_tree, new_split_tree, log_ratio, allowed, node)
    """
    half_size = split_tree.shape[0]
    k_leaf, k_var, k_split = random.split(key, 3)

    # Candidate leaves: actual leaves within the internal-node portion.
    leaf_mask = is_actual_leaf(split_tree)
    n_leaves = jnp.sum(leaf_mask)
    node = _pick_uniform(k_leaf, leaf_mask)

    # Both children (2*node and 2*node + 1) must index below half_size.
    allowed = (node * 2 + 1 < half_size) & (n_leaves > 0)

    # Uniform variable, then a split index in [1, max(max_split[var], 2)).
    # NOTE(review): randint's upper bound is exclusive, so max_split[var]
    # itself is never drawn (and 1 is drawn even when max_split[var] < 1).
    # Confirm this is intended.
    var = random.randint(k_var, (), 0, max_split.shape[0])
    split_val = random.randint(
        k_split,
        (),
        1,
        jnp.maximum(max_split[var], 2).astype(jnp.int32),
    )

    new_var_tree = var_tree.at[node].set(var.astype(var_tree.dtype))
    new_split_tree = split_tree.at[node].set(split_val.astype(split_tree.dtype))

    # Structural prior ratio for grow:
    #   P(split at depth d) / P(no split at depth d)
    #   + 2 * log(1 - P(split at depth d + 1))
    p_split, p_split_child = _depth_prior_probs(node, p_nonterminal)
    log_struct = (
        jnp.log(p_split)
        - jnp.log(1 - p_split)
        + 2 * jnp.log(1 - p_split_child)
    )

    # Proposal ratio: grow picks 1/n_leaves, the reverse prune picks
    # 1/(number of leaf-parents after the grow).
    n_si_after = jnp.sum(is_leaves_parent(new_split_tree))
    log_ratio = log_struct + _log_proposal_ratio(n_si_after, n_leaves)
    return new_var_tree, new_split_tree, log_ratio, allowed, node


def _prune_with_node(
    key: jax.Array,
    var_tree: jax.Array,
    split_tree: jax.Array,
    p_nonterminal: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    """Build a PRUNE proposal and also report which node was chosen.

    Shared by ``_hurdle_propose_prune`` and
    ``_propose_single_tree_with_move``.

    Returns
    -------
    (new_var_tree, new_split_tree, log_ratio, allowed, node)
    """
    # Candidates: internal nodes whose two children are both leaves.
    si_mask = is_leaves_parent(split_tree)
    n_si = jnp.sum(si_mask)
    node = _pick_uniform(key, si_mask)
    allowed = n_si > 0

    # Zeroing the split turns the node back into a leaf.
    new_split_tree = split_tree.at[node].set(0)
    new_var_tree = var_tree.at[node].set(0)

    # Structural prior ratio (negative of grow's).
    p_split, p_split_child = _depth_prior_probs(node, p_nonterminal)
    log_struct = (
        jnp.log(1 - p_split)
        - jnp.log(p_split)
        - 2 * jnp.log(1 - p_split_child)
    )

    # Proposal ratio: prune picks 1/n_si, the reverse grow picks
    # 1/(leaf count after the prune).
    n_leaves_after = jnp.sum(is_actual_leaf(split_tree)) - 1
    log_ratio = log_struct + _log_proposal_ratio(n_leaves_after, n_si)
    return new_var_tree, new_split_tree, log_ratio, allowed, node


def _hurdle_propose_grow(
    key: jax.Array,
    var_tree: jax.Array,
    split_tree: jax.Array,
    X_binned: jax.Array,
    max_split: jax.Array,
    p_nonterminal: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """Propose a GROW move on a single hurdle tree.

    Picks a random leaf, selects a random variable and split point, and
    converts the leaf into an internal node. Returns the proposed tree
    structure and the structural log-ratio for MH.

    Parameters
    ----------
    key : PRNG key
    var_tree : (half_size,) current split variables
    split_tree : (half_size,) current split points (0 = leaf)
    X_binned : (p, n) binned covariate matrix (accepted but not read here)
    max_split : (p,) max split index per variable
    p_nonterminal : (max_depth-1,) depth prior probabilities

    Returns
    -------
    (new_var_tree, new_split_tree, log_ratio_struct, allowed)
    """
    new_var_tree, new_split_tree, log_ratio_struct, allowed, _node = _grow_with_node(
        key,
        var_tree,
        split_tree,
        max_split,
        p_nonterminal,
    )
    return new_var_tree, new_split_tree, log_ratio_struct, allowed


def _hurdle_propose_prune(
    key: jax.Array,
    var_tree: jax.Array,
    split_tree: jax.Array,
    p_nonterminal: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """Propose a PRUNE move on a single hurdle tree.

    Picks a random singly-internal node (both children are leaves) and
    converts it to a leaf by zeroing its split. Returns the proposed tree and
    the full structural+proposal log-ratio for MH (matching grow's return
    convention).

    Returns
    -------
    (new_var_tree, new_split_tree, log_ratio_full, allowed)
    """
    new_var_tree, new_split_tree, log_ratio_full, allowed, _node = _prune_with_node(
        key,
        var_tree,
        split_tree,
        p_nonterminal,
    )
    return new_var_tree, new_split_tree, log_ratio_full, allowed


def _propose_single_tree(
    key: jax.Array,
    var_tree: jax.Array,
    split_tree: jax.Array,
    X_binned: jax.Array,
    max_split: jax.Array,
    p_nonterminal: jax.Array,
    min_samples_leaf: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    """Propose a grow/prune move for one tree and pre-traverse.

    Returns (prop_var, prop_split, prop_leaf_indices, log_ratio, allowed,
    log_u). Use ``_propose_single_tree_with_move`` for the move-aware fast
    path that also emits a ``HurdleMove`` struct.
    """
    (prop_var, prop_split, prop_li, log_ratio, allowed, log_u, _move) = (
        _propose_single_tree_with_move(
            key,
            var_tree,
            split_tree,
            X_binned,
            max_split,
            p_nonterminal,
            min_samples_leaf,
        )
    )
    return prop_var, prop_split, prop_li, log_ratio, allowed, log_u


def _propose_single_tree_with_move(
    key: jax.Array,
    var_tree: jax.Array,
    split_tree: jax.Array,
    X_binned: jax.Array,
    max_split: jax.Array,
    p_nonterminal: jax.Array,
    min_samples_leaf: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, HurdleMove]:
    """Propose a grow/prune move for one tree and pre-traverse the proposal.

    Like ``_propose_single_tree`` but additionally returns a ``HurdleMove``
    pytree carrying the heap indices of the affected node and its children
    plus the structural log-ratio. Used by the parallel-LML fast path
    (``per_leaf_gamma=False``).

    Both a GROW and a PRUNE candidate are built on every call and one is
    selected with ``jnp.where``.

    Returns
    -------
    (prop_var, prop_split, prop_leaf_indices, log_ratio, allowed, log_u, move)
    """
    k_move, k_grow, k_prune, k_u = random.split(key, 4)

    half_size = split_tree.shape[0]
    tree_size_static = 2 * half_size
    # Narrowest unsigned dtype used to store leaf indices for this tree size.
    li_dtype = jnp.uint8 if tree_size_static <= 256 else jnp.uint16

    # With nothing to prune the move is forced to GROW; otherwise a fair coin.
    n_si = jnp.sum(is_leaves_parent(split_tree))
    force_grow = n_si == 0
    do_grow = jnp.where(force_grow, True, random.uniform(k_move) < 0.5)

    # Build both candidates; each helper also reports the node it touched so
    # the caller can locate move.node/left/right without rebuilding them.
    grow_var, grow_split, grow_lr, grow_ok, grow_node = _grow_with_node(
        k_grow,
        var_tree,
        split_tree,
        max_split,
        p_nonterminal,
    )
    prune_var, prune_split, prune_lr, prune_ok, prune_node = _prune_with_node(
        k_prune,
        var_tree,
        split_tree,
        p_nonterminal,
    )

    # Select between grow and prune.
    prop_var = jnp.where(do_grow, grow_var, prune_var)
    prop_split = jnp.where(do_grow, grow_split, prune_split)
    log_ratio = jnp.where(do_grow, grow_lr, prune_lr)
    allowed = jnp.where(do_grow, grow_ok, prune_ok)
    # random.uniform draws from [0, 1), so log_u can be -inf.
    log_u = jnp.log(random.uniform(k_u))

    move_node = jnp.where(do_grow, grow_node, prune_node).astype(jnp.int32)
    move_left = jnp.where(do_grow, grow_node * 2, prune_node * 2).astype(jnp.int32)
    move_right = jnp.where(
        do_grow,
        grow_node * 2 + 1,
        prune_node * 2 + 1,
    ).astype(jnp.int32)

    # Pre-traverse the proposed tree (as a forest of one) for leaf indices.
    prop_li = _traverse_forest(
        X_binned,
        prop_var[jnp.newaxis],
        prop_split[jnp.newaxis],
    )[0]
    prop_li = prop_li.astype(li_dtype)

    # Enforce min_samples_leaf: reject when the smallest occupied leaf has
    # too few observations.
    # NOTE(review): slots with zero observations are skipped (they cannot be
    # told apart from non-leaf slots here), so a proposal that creates an
    # empty leaf is not rejected by this check. Confirm this is intended.
    n = X_binned.shape[1]
    leaf_counts = jnp.zeros(tree_size_static, dtype=jnp.int32).at[prop_li].add(1)
    occupied = leaf_counts > 0
    min_count = jnp.min(jnp.where(occupied, leaf_counts, n))
    allowed = allowed & (min_count >= min_samples_leaf)

    move = HurdleMove(
        grow=do_grow,
        node=move_node,
        left=move_left,
        right=move_right,
        log_ratio_struct=log_ratio,
        allowed=allowed,
        log_u=log_u,
    )
    return prop_var, prop_split, prop_li, log_ratio, allowed, log_u, move


def _init_hurdle_forest(
    X_binned: jax.Array,
    max_split: jax.Array,
    y_placeholder: np.ndarray,
    n: int,
    num_trees: int,
    alpha: float,
    beta: float,
    leaf_prior_inv: float,
    max_depth: int,
):
    """Initialize a bartz forest state for hurdle trees.

    We reuse bartz's init() purely as a tree-array scaffolding factory: the
    returned State is consumed only for its ``forest.{var_tree, split_tree,
    leaf_tree, leaf_indices, p_nonterminal}`` arrays. The joint hurdle MCMC
    (``_hurdle_step_forest``) reads none of the State's outcome fields and
    calls no bartz step primitive — it home-rolls the entire inference. The
    conversion (alpha) / severity (beta) channels live in pytyche's dual-leaf
    carry, NOT in this scalar bartz State.

    Parameters
    ----------
    X_binned : binned covariate matrix, copied before being handed to bartz
    max_split : max split index per variable, copied likewise
    y_placeholder : response placeholder, converted to float32
    n : accepted for interface compatibility; not read by this function
    num_trees : number of trees in the forest
    alpha, beta : forwarded to ``make_p_nonterminal(max_depth, alpha, beta)``
    leaf_prior_inv : passed as ``leaf_prior_cov_inv`` (float32)
    max_depth : forwarded to ``make_p_nonterminal``

    Returns
    -------
    The State object returned by ``bartz.mcmcstep.init``.
    """
    # bartz `mcmcstep.init` donates its array inputs (via _remove_weak_types
    # with donate='all') and stores them as state fields with the same buffer.
    # Copy at the wrapper boundary — mirrors bartz's own _interface.py — so
    # callers can reuse X_binned / max_split after this returns.
    #
    # outcome_type=continuous is explicit (provable no-op — it is already the
    # bartz default) to document that this is a scalar-continuous scaffolding
    # factory, not an encoding of the hurdle channels — NOT a mixed-outcome init.
    return init(
        X=jnp.array(X_binned),
        y=jnp.asarray(y_placeholder, dtype=jnp.float32),
        outcome_type=OutcomeType.continuous,
        offset=jnp.float32(0.0),
        max_split=jnp.array(max_split),
        num_trees=num_trees,
        p_nonterminal=make_p_nonterminal(max_depth, alpha, beta),
        leaf_prior_cov_inv=jnp.float32(leaf_prior_inv),
        error_cov_df=jnp.float32(3.0),
        error_cov_scale=jnp.float32(3.0),
    )