"""Hurdle BCF grow/prune proposals and forest initialisation.
These are internal helpers used by the joint hurdle MCMC loop.
"""
from __future__ import annotations
from typing import NamedTuple
import jax
import jax.numpy as jnp
import numpy as np
from bartz.grove import (
is_actual_leaf,
is_leaves_parent,
)
from bartz.grove import (
traverse_forest as _traverse_forest,
)
from bartz.mcmcstep import OutcomeType, init, make_p_nonterminal
from jax import random
class HurdleMove(NamedTuple):
    """Per-tree GROW/PRUNE move metadata for the parallel-LML fast path.

    Mirrors bartz's ``Moves`` Module but as a JAX-pytree NamedTuple so it
    threads through ``vmap`` and ``lax.scan`` carries naturally.

    Carries the proposal direction (``grow``), the heap indices of the
    node and its two children, the Metropolis-Hastings log-ratio for the
    proposal, a gating bit (``allowed``), and the log of the uniform draw
    used for the accept comparison.
    """

    # True for a GROW proposal, False for a PRUNE proposal.
    grow: jax.Array
    # Heap index of the node being grown (leaf -> internal) or pruned.
    node: jax.Array
    # Heap index of the node's left child (2 * node).
    left: jax.Array
    # Heap index of the node's right child (2 * node + 1).
    right: jax.Array
    # Structural + proposal log-ratio used in the MH acceptance test.
    log_ratio_struct: jax.Array
    # False when the proposal must be rejected regardless of the MH ratio.
    allowed: jax.Array
    # Log of the uniform draw compared against the MH log-ratio.
    log_u: jax.Array
def _hurdle_propose_grow(
    key: jax.Array,
    var_tree: jax.Array,
    split_tree: jax.Array,
    X_binned: jax.Array,
    max_split: jax.Array,
    p_nonterminal: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """Propose a GROW move on a single hurdle tree.

    Picks a random leaf, selects a random variable and split point,
    and converts the leaf into an internal node. Returns the proposed
    tree structure and the structural + proposal log-ratio for MH.

    The split-rule choice (1/p for the variable, 1/max_split[var] for the
    split value) does not appear in the ratio: it is assumed to cancel
    against a uniform split-rule prior. The caller's GROW/PRUNE
    move-selection probabilities are not included either.

    Parameters
    ----------
    key : PRNG key
    var_tree : (half_size,) current split variables
    split_tree : (half_size,) current split points (0 = leaf)
    X_binned : (p, n) binned covariate matrix (not read here; kept for
        signature parity with the other proposal helpers)
    max_split : (p,) max split index per variable; the proposed split
        value for variable ``v`` is drawn from ``1 .. max_split[v]``
        (inclusive)
    p_nonterminal : (max_depth-1,) depth prior probabilities

    Returns
    -------
    (new_var_tree, new_split_tree, log_ratio_struct, allowed)
        ``allowed`` is False when the tree has no leaf or when the chosen
        leaf's children (``2*i``, ``2*i + 1``) would not both fall inside
        the ``half_size`` split array.
    """
    half_size = split_tree.shape[0]
    k1, k2, k3 = random.split(key, 3)

    # Find growable leaves (actual leaves in the internal-node portion)
    leaf_mask = is_actual_leaf(split_tree)  # (half_size,)
    n_leaves = jnp.sum(leaf_mask)

    # Pick a random leaf uniformly
    leaf_probs = jnp.where(leaf_mask, 1.0, 0.0)
    leaf_probs = leaf_probs / jnp.maximum(leaf_probs.sum(), 1.0)
    leaf_idx = random.choice(k1, half_size, p=leaf_probs)

    # Heap depth of the selected leaf: floor(log2(leaf_idx)), root = depth 0
    depth = jnp.floor(
        jnp.log2(jnp.maximum(leaf_idx, 1).astype(jnp.float32)),
    ).astype(jnp.int32)

    # Check: leaf's children must fit in the half_size array
    left_child = leaf_idx * 2
    allowed = (left_child + 1 < half_size) & (n_leaves > 0)

    # Pick random variable and split point. Valid split values are
    # 1..max_split[var] inclusive and randint's upper bound is exclusive,
    # hence the +1. Cast first so a narrow unsigned max_split dtype cannot
    # wrap around. (If max_split[var] == 0, randint returns its minval, 1.)
    p = max_split.shape[0]
    var = random.randint(k2, (), 0, p)
    split_hi = max_split[var].astype(jnp.int32) + 1
    split_val = random.randint(k3, (), 1, split_hi)

    # Construct proposed tree
    new_var_tree = var_tree.at[leaf_idx].set(var.astype(var_tree.dtype))
    new_split_tree = split_tree.at[leaf_idx].set(split_val.astype(split_tree.dtype))

    # Structural prior ratio for grow:
    #   log P(split at d) - log P(no split at d) + 2 * log(1 - P(split at d+1))
    n_levels = p_nonterminal.shape[0]
    p_split = p_nonterminal[jnp.minimum(depth, n_levels - 1)]
    child_depth = depth + 1
    p_split_child = jnp.where(
        child_depth < n_levels,
        p_nonterminal[jnp.minimum(child_depth, n_levels - 1)],
        0.0,
    )
    log_struct = (
        jnp.log(p_split) - jnp.log(1 - p_split)
        + 2 * jnp.log(1 - p_split_child)
    )

    # Proposal ratio: grow picks 1/n_leaves, reverse (prune) picks 1/n_SI_after
    n_si_after = jnp.sum(is_leaves_parent(new_split_tree))
    log_proposal = (
        jnp.log(1.0 / jnp.maximum(n_si_after, 1).astype(jnp.float32))
        - jnp.log(1.0 / jnp.maximum(n_leaves, 1).astype(jnp.float32))
    )
    log_ratio_struct = log_struct + log_proposal
    return new_var_tree, new_split_tree, log_ratio_struct, allowed
def _hurdle_propose_prune(
    key: jax.Array,
    var_tree: jax.Array,
    split_tree: jax.Array,
    p_nonterminal: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """Propose a PRUNE move on a single hurdle tree.

    A node whose two children are both leaves is drawn uniformly and
    collapsed back into a leaf by zeroing its split point and split
    variable.

    Parameters
    ----------
    key : PRNG key
    var_tree : (half_size,) current split variables
    split_tree : (half_size,) current split points (0 = leaf)
    p_nonterminal : depth prior probabilities, indexed by heap depth

    Returns
    -------
    (new_var_tree, new_split_tree, log_ratio_full, allowed)
        ``log_ratio_full`` is the tree-prior ratio plus the proposal ratio
        (same convention as the GROW helper); GROW/PRUNE move-selection
        probabilities are not included. ``allowed`` is False when the tree
        has no prunable node.
    """
    n_nodes = split_tree.shape[0]
    n_levels = p_nonterminal.shape[0]

    # Prunable candidates: internal nodes whose two children are both leaves.
    prunable = is_leaves_parent(split_tree)
    n_prunable = jnp.sum(prunable)
    weights = jnp.where(prunable, 1.0, 0.0)
    weights = weights / jnp.maximum(weights.sum(), 1.0)
    node = random.choice(key, n_nodes, p=weights)

    # Collapse the chosen node into a leaf.
    pruned_split = split_tree.at[node].set(0)
    pruned_var = var_tree.at[node].set(0)

    # Tree-prior ratio: the exact negation of the GROW ratio at this node.
    level = jnp.floor(
        jnp.log2(jnp.maximum(node, 1).astype(jnp.float32))
    ).astype(jnp.int32)
    p_here = p_nonterminal[jnp.minimum(level, n_levels - 1)]
    next_level = level + 1
    p_below = jnp.where(
        next_level < n_levels,
        p_nonterminal[jnp.minimum(next_level, n_levels - 1)],
        0.0,
    )
    log_prior = (
        jnp.log(1 - p_here) - jnp.log(p_here)
        - 2 * jnp.log(1 - p_below)
    )

    # Proposal ratio: the reverse GROW picks one of the leaves left after the
    # prune (two leaves merge into one); this PRUNE picks one of n_prunable.
    leaves_after = jnp.sum(is_actual_leaf(split_tree)) - 1
    log_q = (
        jnp.log(1.0 / jnp.maximum(leaves_after, 1).astype(jnp.float32))
        - jnp.log(1.0 / jnp.maximum(n_prunable, 1).astype(jnp.float32))
    )

    return pruned_var, pruned_split, log_prior + log_q, n_prunable > 0
def _propose_single_tree(
    key: jax.Array,
    var_tree: jax.Array,
    split_tree: jax.Array,
    X_binned: jax.Array,
    max_split: jax.Array,
    p_nonterminal: jax.Array,
    min_samples_leaf: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    """Propose a grow/prune move for one tree and pre-traverse it.

    Delegates to ``_propose_single_tree_with_move`` and discards the
    trailing ``HurdleMove`` struct; use that function directly on the
    move-aware fast path.

    Returns
    -------
    (prop_var, prop_split, prop_leaf_indices, log_ratio, allowed, log_u)
    """
    outputs = _propose_single_tree_with_move(
        key,
        var_tree,
        split_tree,
        X_binned,
        max_split,
        p_nonterminal,
        min_samples_leaf,
    )
    # Element 7 is the HurdleMove pytree, which this path does not need.
    return outputs[:6]
def _grow_log_prior_ratio(node: jax.Array, p_nonterminal: jax.Array) -> jax.Array:
    """Return the log tree-prior ratio for splitting leaf ``node``.

    With ``d`` the heap depth of ``node`` (index 1 = root = depth 0) and
    ``p_k = p_nonterminal[k]``, this is
    ``log p_d - log(1 - p_d) + 2 * log(1 - p_{d+1})``: the node becomes
    nonterminal and gains two terminal children. A child depth beyond the
    end of ``p_nonterminal`` gets ``p_{d+1} = 0``. A PRUNE of ``node`` uses
    the negation of this value.
    """
    n_levels = p_nonterminal.shape[0]
    depth = jnp.floor(
        jnp.log2(jnp.maximum(node, 1).astype(jnp.float32)),
    ).astype(jnp.int32)
    p_split = p_nonterminal[jnp.minimum(depth, n_levels - 1)]
    child_depth = depth + 1
    p_split_child = jnp.where(
        child_depth < n_levels,
        p_nonterminal[jnp.minimum(child_depth, n_levels - 1)],
        0.0,
    )
    return (
        jnp.log(p_split) - jnp.log(1 - p_split)
        + 2 * jnp.log(1 - p_split_child)
    )


def _propose_single_tree_with_move(
    key: jax.Array,
    var_tree: jax.Array,
    split_tree: jax.Array,
    X_binned: jax.Array,
    max_split: jax.Array,
    p_nonterminal: jax.Array,
    min_samples_leaf: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, HurdleMove]:
    """Propose a grow/prune move for one tree and pre-traverse the proposal.

    Like ``_propose_single_tree`` but additionally returns a ``HurdleMove``
    pytree carrying the heap indices of the affected node and its children
    plus the MH log-ratio. Used by the parallel-LML fast path
    (``per_leaf_gamma=False``).

    Both a GROW and a PRUNE candidate are always built and one is selected
    with ``jnp.where``, so the function is branch-free under ``vmap``.

    Parameters
    ----------
    key : PRNG key
    var_tree : (half_size,) current split variables
    split_tree : (half_size,) current split points (0 = leaf)
    X_binned : (p, n) binned covariate matrix
    max_split : (p,) max split index per variable; a GROW split on variable
        ``v`` takes a value in ``1 .. max_split[v]`` (inclusive)
    p_nonterminal : depth prior probabilities, indexed by heap depth
    min_samples_leaf : minimum number of observations every leaf of the
        proposed tree must receive (empty leaves count as 0)

    Returns
    -------
    (prop_var, prop_split, prop_leaf_indices, log_ratio, allowed, log_u, move)
        ``log_ratio`` = tree-prior ratio + node-choice proposal ratio +
        GROW/PRUNE move-selection ratio. ``allowed`` gates the proposal.
        ``log_u`` is ``log(U)`` with ``U`` uniform on (0, 1].
    """
    k_move, k_grow, k_prune, k_u = random.split(key, 4)
    half_size = split_tree.shape[0]
    tree_size_static = 2 * half_size
    # Smallest unsigned dtype that can hold every heap index < tree_size_static.
    if tree_size_static <= 2**8:
        li_dtype = jnp.uint8
    elif tree_size_static <= 2**16:
        li_dtype = jnp.uint16
    else:
        li_dtype = jnp.uint32

    leaf_mask = is_actual_leaf(split_tree)
    si_mask = is_leaves_parent(split_tree)
    n_leaves = jnp.sum(leaf_mask)
    n_si = jnp.sum(si_mask)

    # A root-only tree has nothing to prune, so GROW is forced (P(GROW) = 1);
    # otherwise GROW and PRUNE are each chosen with probability 1/2.
    force_grow = n_si == 0
    do_grow = jnp.where(force_grow, True, random.uniform(k_move) < 0.5)

    # ---- GROW candidate -------------------------------------------------
    # Heap indices are kept so the caller can locate move.node/left/right
    # without rebuilding them later.
    k1_grow, k2_grow, k3_grow = random.split(k_grow, 3)
    leaf_probs = jnp.where(leaf_mask, 1.0, 0.0)
    leaf_probs = leaf_probs / jnp.maximum(leaf_probs.sum(), 1.0)
    grow_node = random.choice(k1_grow, half_size, p=leaf_probs)
    grow_left = grow_node * 2
    grow_right = grow_node * 2 + 1
    p = max_split.shape[0]
    grow_var_idx = random.randint(k2_grow, (), 0, p)
    # randint's upper bound is exclusive; +1 makes max_split[var] reachable.
    # Cast first so a narrow unsigned max_split dtype cannot wrap around.
    grow_split_val = random.randint(
        k3_grow, (), 1, max_split[grow_var_idx].astype(jnp.int32) + 1,
    )
    new_var_tree_grow = var_tree.at[grow_node].set(grow_var_idx.astype(var_tree.dtype))
    new_split_tree_grow = split_tree.at[grow_node].set(grow_split_val.astype(split_tree.dtype))

    n_si_after_grow = jnp.sum(is_leaves_parent(new_split_tree_grow))
    grow_log_proposal = (
        jnp.log(1.0 / jnp.maximum(n_si_after_grow, 1).astype(jnp.float32))
        - jnp.log(1.0 / jnp.maximum(n_leaves, 1).astype(jnp.float32))
    )
    # Move-selection ratio P(PRUNE | T') / P(GROW | T). After a grow T' always
    # has a prunable node, so P(PRUNE | T') = 1/2. P(GROW | T) is 1 when
    # forced, else 1/2.
    grow_log_move = jnp.where(force_grow, -jnp.log(2.0), 0.0)
    grow_lr = (
        _grow_log_prior_ratio(grow_node, p_nonterminal)
        + grow_log_proposal
        + grow_log_move
    )
    grow_ok = (grow_left + 1 < half_size) & (n_leaves > 0)

    # ---- PRUNE candidate ------------------------------------------------
    si_probs = jnp.where(si_mask, 1.0, 0.0)
    si_probs = si_probs / jnp.maximum(si_probs.sum(), 1.0)
    prune_node = random.choice(k_prune, half_size, p=si_probs)
    prune_left = prune_node * 2
    prune_right = prune_node * 2 + 1
    new_split_tree_prune = split_tree.at[prune_node].set(0)
    new_var_tree_prune = var_tree.at[prune_node].set(0)

    # Two leaves merge into one, so the reverse GROW chooses among n_leaves - 1.
    n_leaves_after_prune = n_leaves - 1
    prune_log_proposal = (
        jnp.log(1.0 / jnp.maximum(n_leaves_after_prune, 1).astype(jnp.float32))
        - jnp.log(1.0 / jnp.maximum(n_si, 1).astype(jnp.float32))
    )
    # Move-selection ratio P(GROW | T') / P(PRUNE | T). P(PRUNE | T) = 1/2
    # whenever pruning is possible; P(GROW | T') = 1 if only the root remains.
    n_si_after_prune = jnp.sum(is_leaves_parent(new_split_tree_prune))
    prune_log_move = jnp.where(n_si_after_prune == 0, jnp.log(2.0), 0.0)
    prune_lr = (
        -_grow_log_prior_ratio(prune_node, p_nonterminal)
        + prune_log_proposal
        + prune_log_move
    )
    prune_ok = n_si > 0

    # ---- Select between grow and prune ---------------------------------
    prop_var = jnp.where(do_grow, new_var_tree_grow, new_var_tree_prune)
    prop_split = jnp.where(do_grow, new_split_tree_grow, new_split_tree_prune)
    log_ratio = jnp.where(do_grow, grow_lr, prune_lr)
    allowed = jnp.where(do_grow, grow_ok, prune_ok)
    # log(1 - u) with u in [0, 1) is the log of a U(0, 1] draw: never -inf.
    log_u = jnp.log1p(-random.uniform(k_u))
    move_node = jnp.where(do_grow, grow_node, prune_node).astype(jnp.int32)
    move_left = jnp.where(do_grow, grow_left, prune_left).astype(jnp.int32)
    move_right = jnp.where(do_grow, grow_right, prune_right).astype(jnp.int32)

    # Pre-traverse proposed tree.
    prop_li = _traverse_forest(
        X_binned, prop_var[jnp.newaxis], prop_split[jnp.newaxis],
    )[0]
    prop_li = prop_li.astype(li_dtype)

    # Enforce min_samples_leaf over every actual leaf of the proposed tree
    # (including bottom-level slots), so an empty leaf counts as 0 obs
    # instead of being skipped.
    n = X_binned.shape[1]
    leaf_counts = jnp.zeros(tree_size_static, dtype=jnp.int32).at[prop_li].add(1)
    prop_leaf_mask = is_actual_leaf(prop_split, add_bottom_level=True)
    min_count = jnp.min(jnp.where(prop_leaf_mask, leaf_counts, n))
    allowed = allowed & (min_count >= min_samples_leaf)

    move = HurdleMove(
        grow=do_grow,
        node=move_node,
        left=move_left,
        right=move_right,
        log_ratio_struct=log_ratio,
        allowed=allowed,
        log_u=log_u,
    )
    return prop_var, prop_split, prop_li, log_ratio, allowed, log_u, move
def _init_hurdle_forest(
    X_binned: jax.Array,
    max_split: jax.Array,
    y_placeholder: np.ndarray,
    n: int,
    num_trees: int,
    alpha: float,
    beta: float,
    leaf_prior_inv: float,
    max_depth: int,
):
    """Initialize a bartz forest state for hurdle trees.

    We reuse bartz's init() purely as a tree-array scaffolding factory: the
    returned State is consumed only for its ``forest.{var_tree, split_tree,
    leaf_tree, leaf_indices, p_nonterminal}`` arrays. The joint hurdle MCMC
    (``_hurdle_step_forest``) reads none of the State's outcome fields and
    calls no bartz step primitive — it home-rolls the entire inference. The
    conversion (alpha) / severity (beta) channels live in pytyche's dual-leaf
    carry, NOT in this scalar bartz State.

    Parameters
    ----------
    X_binned : binned covariate matrix, passed to ``init`` as ``X``
    max_split : per-variable maximum split index, passed as ``max_split``
    y_placeholder : outcome vector required by ``init``'s signature; per the
        note above its contents are not used by the hurdle MCMC
    n : not read by this function; kept for call-site compatibility
    num_trees : number of trees in the forest
    alpha, beta : tree-depth prior parameters for ``make_p_nonterminal``
    leaf_prior_inv : passed to ``init`` as ``leaf_prior_cov_inv``
    max_depth : maximum tree depth for ``make_p_nonterminal``

    Returns
    -------
    The State object returned by bartz ``init``.
    """
    # bartz `mcmcstep.init` donates its array inputs (via _remove_weak_types
    # with donate='all') and stores them as state fields with the same buffer.
    # Copy every array at the wrapper boundary — mirrors bartz's own
    # _interface.py — so callers can reuse X_binned / max_split / y_placeholder
    # after this returns. `jnp.array` always copies; `jnp.asarray` would hand
    # through an existing float32 jax array unchanged and let it be donated.
    #
    # outcome_type=continuous is explicit (provable no-op — it is already the
    # bartz default) to document that this is a scalar-continuous scaffolding
    # factory, not an encoding of the hurdle channels — NOT a mixed-outcome init.
    return init(
        X=jnp.array(X_binned),
        y=jnp.array(y_placeholder, dtype=jnp.float32),
        outcome_type=OutcomeType.continuous,
        offset=jnp.float32(0.0),
        max_split=jnp.array(max_split),
        num_trees=num_trees,
        p_nonterminal=make_p_nonterminal(max_depth, alpha, beta),
        leaf_prior_cov_inv=jnp.float32(leaf_prior_inv),
        error_cov_df=jnp.float32(3.0),
        error_cov_scale=jnp.float32(3.0),
    )