"""Pure-math sufficient statistics and log-marginal likelihoods for the GPU BCF.
This module holds the closed-form pieces of the joint hurdle BCF posterior:
basis-weighted sufficient statistics aggregated per leaf, the matching
log-marginal likelihoods (Normal-Normal and Normal-Gamma conjugacy), and the
moment-matched gamma prior used to set ``(a_tau, b_tau)``.
It is "pure math" in the sense that every function here depends only on its
arguments and on JAX/numpy/scipy primitives — no module-level state, no GPU
device handles, no JIT-compiled kernels of its own.
Import graph
------------
This module imports ``_fused_scatter_add`` from ``pytyche.bcf.kernels`` and
nothing else from sibling GPU BCF modules. The orchestrator and the GFR
module import FROM here, never the other way around — see
``pytyche.bcf.kernels`` for the same convention applied at the bottom of the
import graph.
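Contents
--------
``_solve_gamma_prior`` — Linero moment-matching for ``(a_gamma, b_gamma)``.
``_hurdle_sufficient_stats`` — basis-weighted (Normal-Normal) per-leaf stats.
``_hurdle_sufficient_stats_ng`` — precision-weighted (Normal-Gamma) per-leaf stats.
``_log_marginal_leaves`` — per-leaf log-marginal under Normal-Normal conjugacy.
``_log_marginal_leaves_ng`` — per-leaf log-marginal under Normal-Gamma conjugacy.
``_log_marginal_hurdle_tree`` — total joint log-marginal for a hurdle tree (NN severity).
``_log_marginal_hurdle_tree_ng`` — total joint log-marginal for a hurdle tree (NG severity).
The module also defines:

- the parallel-stage factor containers and builders (``HurdlePreLkv``, ``HurdlePreLf``, ``_precompute_*``);
- fused two-channel statistics helpers (``_hurdle_resid_sums``, ``_hurdle_count_trees``, ``_hurdle_sufficient_stats_joint*``, ``_hurdle_move_resids_and_counts``);
- O(1) GROW-move ratio helpers (``_log_marginal_hurdle_ratio*``, ``_hurdle_move_log_lk_ratio_grow``);
- multivariate (``K-1``-dimensional) variants (``_hurdle_resid_sums_mv``, ``_log_marginal_leaves_mv``, ``_mv_leaf_draw``, ``_mv_log_lk_ratio_grow``).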
"""
from __future__ import annotations
from typing import NamedTuple, cast
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import brentq
from scipy.special import digamma, polygamma
from pytyche.bcf.kernels import _fused_scatter_add
# ---------------------------------------------------------------------------
# Per-iteration parallel-stage factors (mirrors bartz's PreLkV / PreLk / PreLf)
# ---------------------------------------------------------------------------
[docs]
class HurdlePreLkv(NamedTuple):
    """Per-tree likelihood-ratio factors for one channel of the hurdle lml.

    Mirrors bartz's ``PreLkV`` Module as a JAX-pytree NamedTuple. Instances
    are built by ``_precompute_lk_one_channel`` in the parallel pre-compute
    stage and read by ``_log_marginal_hurdle_ratio``.

    ``T_left`` / ``T_right`` / ``T_total`` are the conjugate Normal-Normal
    denominators ``sigma2 + sum_h2 * tau_prior`` evaluated at the move's
    left, right, and parent leaves. They have shape ``(num_trees,)`` when
    built from per-tree counts.

    ``log_sqrt_term`` is the log-determinant ratio
    ``0.5 * log(sigma2 * T_total / (T_left * T_right))``, also one value per
    tree.

    ``exp_factor`` is ``tau_prior / (2 * sigma2)``, the factor multiplying
    the residual quadratic in the O(1) ratio. It depends only on the channel
    hyperparameters, so it is shared across trees: a scalar when ``sigma2``
    and ``tau_prior`` are scalars, not a ``(num_trees,)`` array.
    """
    T_left: jax.Array  # sigma2 + sum_h2_left * tau_prior
    T_right: jax.Array  # sigma2 + sum_h2_right * tau_prior
    T_total: jax.Array  # sigma2 + (sum_h2_left + sum_h2_right) * tau_prior
    log_sqrt_term: jax.Array  # 0.5 * log(sigma2 * T_total / (T_left * T_right))
    exp_factor: jax.Array  # tau_prior / (2 * sigma2), shared across trees
[docs]
class HurdlePreLf(NamedTuple):
    """Per-tree leaf-posterior factors for one channel.

    Mirrors bartz's ``PreLf``. Built by ``_precompute_leaf_one_channel``.
    For each tree's ``tree_size`` leaves the sampled leaf value is::

        new_leaf_value[t, l]
            = mean_factor[t, l] * sum_resid_at_leaf[t, l] + centered_leaves[t, l]

    ``mean_factor`` is the Normal-Normal posterior shrinkage ``tau_prior / T``.
    ``centered_leaves`` is the supplied standard-normal noise scaled by the
    posterior standard deviation ``sqrt(sigma2 * tau_prior / T)``.

    Both fields are zeroed where ``count_trees[t, l] == 0``, so empty leaves
    get zero alpha. This matches the existing
    ``jnp.where(count > 0, ..., 0)`` pattern.
    """
    mean_factor: jax.Array # (num_trees, tree_size); tau_prior / T, 0 on empty leaves
    centered_leaves: jax.Array # (num_trees, tree_size); noise * posterior sd, 0 on empty leaves
def _precompute_lk_one_channel(
    count_left: jax.Array,
    count_right: jax.Array,
    sigma2: jax.Array,
    tau_prior: jax.Array,
    basis_sq: jax.Array,
) -> HurdlePreLkv:
    """Assemble the structure-only GROW-ratio factors for one channel.

    ``count_left`` / ``count_right`` hold, per tree, the number of
    contributing observations in the two children of the proposed split
    (shape ``(num_trees,)``). ``sigma2``, ``tau_prior`` and ``basis_sq`` are
    channel-level values. ``basis_sq`` is the uniform per-observation
    ``basis**2`` (1.0 for the mu forest, 0.25 for the tau forest's ±0.5
    basis), so ``sum_h2 = count * basis_sq``.

    Returns a ``HurdlePreLkv`` whose ``T_*`` fields are
    ``sigma2 + sum_h2 * tau_prior`` at the left child, the right child and
    the merged parent.
    """
    def _denominator(sum_h2):
        return sigma2 + sum_h2 * tau_prior

    h2_left = count_left * basis_sq
    h2_right = count_right * basis_sq
    t_left = _denominator(h2_left)
    t_right = _denominator(h2_right)
    t_total = _denominator(h2_left + h2_right)

    # Floor the product so a zero denominator cannot produce a division by zero.
    det_ratio = sigma2 * t_total / jnp.maximum(t_left * t_right, 1e-30)
    return HurdlePreLkv(
        T_left=t_left,
        T_right=t_right,
        T_total=t_total,
        log_sqrt_term=0.5 * jnp.log(det_ratio),
        exp_factor=tau_prior / (2.0 * sigma2),
    )
def _precompute_leaf_one_channel(
    count_trees: jax.Array,
    sigma2: jax.Array,
    tau_prior: jax.Array,
    basis_sq: jax.Array,
    noise: jax.Array,
) -> HurdlePreLf:
    """Build the per-leaf Normal-Normal posterior factors for one channel.

    With ``sum_h2 = count * basis_sq`` and ``T = sigma2 + sum_h2 * tau_prior``,
    the leaf posterior is::

        precision_post = sum_h2 / sigma2 + 1 / tau_prior = T / (sigma2 * tau_prior)
        var_post       = sigma2 * tau_prior / T
        mean_factor    = var_post / sigma2 = tau_prior / T
        new_alpha      = mean_factor * sum_hr + noise * sqrt(var_post)

    Leaves with ``count_trees == 0`` get both factors set to zero, so their
    sampled alpha is exactly 0. This matches the existing
    ``jnp.where(count > 0, ..., 0.0)`` semantics.
    """
    # Floor T so the division below stays finite.
    denom = jnp.maximum(sigma2 + count_trees * basis_sq * tau_prior, 1e-30)
    post_sd = jnp.sqrt(sigma2 * tau_prior / denom)
    has_data = count_trees > 0
    return HurdlePreLf(
        mean_factor=jnp.where(has_data, tau_prior / denom, 0.0),
        centered_leaves=jnp.where(has_data, noise * post_sd, 0.0),
    )
def _precompute_likelihood_terms_hurdle(
    count_trees_conv: jax.Array,
    count_trees_sev: jax.Array,
    moves_left: jax.Array,
    moves_right: jax.Array,
    tau_conv: jax.Array,
    tau_sev: jax.Array,
    sigma2_sev: jax.Array,
    basis_sq: jax.Array,
) -> tuple[HurdlePreLkv, HurdlePreLkv]:
    """Pre-compute the per-tree GROW-ratio factors for both hurdle channels.

    These factors depend only on the proposal structure (the leaf counts at
    each tree's ``move.left`` / ``move.right``) and on the channel
    hyperparameters, never on residuals. They are therefore computed once per
    iteration in the parallel stage, and the sequential per-tree body reads
    them as per-tree scalars.

    The conversion channel is Normal-Normal with unit error variance. The
    severity channel is Normal-Normal with ``sigma2_sev = 1/tau_0`` (the
    ``per_leaf_gamma=False`` path). The Normal-Gamma severity path keeps its
    sequential dependency and does not consume these factors.

    Parameters
    ----------
    count_trees_conv, count_trees_sev : (num_trees, tree_size)
        Leaf counts (channel-specific mask) under the deeper-of-old-vs-prop
        leaf-index assignment per tree.
    moves_left, moves_right : (num_trees,) int
        Heap indices of the two children involved in each tree's move.
    tau_conv, tau_sev : scalars
        Leaf-mean prior variances per channel.
    sigma2_sev : scalar
        Severity error variance ``1/tau_0`` for the NN path.
    basis_sq : scalar
        Uniform ``basis²`` (1.0 mu / 0.25 tau).

    Returns
    -------
    (prelkv_conv, prelkv_sev)
    """
    def _counts_at_children(count_trees):
        # For every tree t, pick count_trees[t, moves_left[t]] and
        # count_trees[t, moves_right[t]].
        at_left = jnp.take_along_axis(count_trees, moves_left[:, None], axis=1)
        at_right = jnp.take_along_axis(count_trees, moves_right[:, None], axis=1)
        return at_left.squeeze(1), at_right.squeeze(1)

    conv_left, conv_right = _counts_at_children(count_trees_conv)
    sev_left, sev_right = _counts_at_children(count_trees_sev)

    unit_variance = jnp.float32(1.0)  # probit conversion channel
    prelkv_conv = _precompute_lk_one_channel(
        conv_left, conv_right, unit_variance, tau_conv, basis_sq,
    )
    prelkv_sev = _precompute_lk_one_channel(
        sev_left, sev_right, sigma2_sev, tau_sev, basis_sq,
    )
    return prelkv_conv, prelkv_sev
def _precompute_leaf_terms_hurdle(
    count_trees_conv_merged: jax.Array,
    count_trees_sev_merged: jax.Array,
    tau_conv: jax.Array,
    tau_sev: jax.Array,
    sigma2_sev: jax.Array,
    basis_sq: jax.Array,
    noise_alpha: jax.Array,
    noise_beta: jax.Array,
) -> tuple[HurdlePreLf, HurdlePreLf]:
    """Pre-compute per-tree, per-leaf posterior factors for both channels.

    The counts are taken AFTER the per-tree merge at ``move.node``: the
    parent-leaf bin holds ``count_left + count_right``. Leaf sampling at
    ``move.node`` therefore uses the count_total posterior, which matches the
    sequential body's behaviour in the rejected-grow / accepted-prune cases.

    Returns ``(prelf_alpha, prelf_beta)`` for the conversion (unit variance)
    and severity (``sigma2_sev``) channels. Factors at empty leaves are zero,
    so ``mean_factor * sum_hr + centered_leaves`` is zero there.
    """
    unit_variance = jnp.float32(1.0)  # probit conversion channel
    prelf_alpha = _precompute_leaf_one_channel(
        count_trees_conv_merged, unit_variance, tau_conv, basis_sq, noise_alpha,
    )
    prelf_beta = _precompute_leaf_one_channel(
        count_trees_sev_merged, sigma2_sev, tau_sev, basis_sq, noise_beta,
    )
    return prelf_alpha, prelf_beta
def _hurdle_resid_sums(
    leaf_indices: jax.Array,
    resid_conv: jax.Array,
    resid_sev: jax.Array,
    conv_mask: jax.Array,
    basis: jax.Array,
    tree_size: int,
) -> tuple[jax.Array, jax.Array]:
    """Per-leaf basis-weighted residual sums for both channels of one tree.

    This is the residual-dependent step of the sequential per-tree body; the
    structure-only factors come from ``_precompute_likelihood_terms_hurdle``
    and ``_precompute_leaf_terms_hurdle``. Both channels share a single
    2-channel ``_fused_scatter_add`` call:

    - conversion: ``basis * resid_conv`` for every observation;
    - severity: ``conv_mask * basis * resid_sev`` (converters only).

    Counts are NOT scattered here; they live in the parallel-stage
    ``count_trees_*`` arrays.

    Returns ``(sum_hr_conv, sum_hr_sev)``, each of shape ``(tree_size,)``.
    """
    masked_basis = basis * conv_mask.astype(jnp.float32)
    per_obs = jnp.stack([
        basis * resid_conv,
        masked_basis * resid_sev,
    ])  # (2, n)
    sums = _fused_scatter_add(leaf_indices, per_obs, tree_size)
    sum_hr_conv = sums[0]
    sum_hr_sev = sums[1]
    return sum_hr_conv, sum_hr_sev
def _hurdle_resid_sums_mv(
    leaf_indices: jax.Array,
    resid_conv: jax.Array,
    resid_sev: jax.Array,
    conv_mask: jax.Array,
    basis: jax.Array,
    tree_size: int,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """Per-leaf first- and second-order statistics for the multivariate hurdle (one tree).

    Multivariate analogue of ``_hurdle_resid_sums`` for an ``(n, K-1)`` basis
    matrix. For every leaf L it accumulates, over observations v in L:

        sum_hr_conv[L] = Σ basis_v * resid_conv_v                 (K-1,)
        sum_hr_sev[L]  = Σ mask_v * basis_v * resid_sev_v         (K-1,)
        sum_h2_conv[L] = Σ outer(basis_v, basis_v)                (K-1, K-1)
        sum_h2_sev[L]  = Σ mask_v * outer(basis_v, basis_v)       (K-1, K-1)

    ``conv`` statistics include all visitors; ``sev`` statistics include only
    converters (``conv_mask == True``). ``K-1`` is read from
    ``basis.shape[1]``, so any K ≥ 2 works (including K-1 = 1).

    Parameters
    ----------
    leaf_indices : (n,) int
        Leaf index for each observation.
    resid_conv : (n,) float32
        Conversion-channel residuals (all visitors).
    resid_sev : (n,) float32
        Severity-channel residuals (used only for converters).
    conv_mask : (n,) bool
        True for converting observations.
    basis : (n, K-1) float32
        Reference-contrast basis matrix.
    tree_size : int
        Total nodes in the tree heap.

    Returns
    -------
    sum_hr_conv : (tree_size, K-1)
    sum_hr_sev : (tree_size, K-1)
    sum_h2_conv : (tree_size, K-1, K-1)
    sum_h2_sev : (tree_size, K-1, K-1)
    """
    k1 = basis.shape[1]
    converter = conv_mask.astype(jnp.float32)  # 1.0 for converters, else 0.0

    def _scatter(values, trailing):
        # Sum per-observation `values` into zero-initialised leaf bins.
        bins = jnp.zeros((tree_size, *trailing), dtype=jnp.float32)
        return bins.at[leaf_indices].add(values)

    outer = basis[:, :, None] * basis[:, None, :]  # (n, K-1, K-1)

    sum_hr_conv = _scatter(basis * resid_conv[:, None], (k1,))
    sum_hr_sev = _scatter(basis * (converter * resid_sev)[:, None], (k1,))
    sum_h2_conv = _scatter(outer, (k1, k1))
    sum_h2_sev = _scatter(converter[:, None, None] * outer, (k1, k1))
    return sum_hr_conv, sum_hr_sev, sum_h2_conv, sum_h2_sev
def _hurdle_count_trees(
    leaf_indices: jax.Array,
    conv_mask: jax.Array,
    tree_size: int,
) -> tuple[jax.Array, jax.Array]:
    """Per-leaf observation counts for one tree: all visitors and converters.

    Depends only on the leaf assignment and the converter mask, not on
    residuals. That is why it can run in the parallel pre-compute stage
    rather than in the sequential per-tree body.

    Returns ``(count_conv, count_sev)``, each of shape ``(tree_size,)``.
    """
    every_obs = jnp.ones(leaf_indices.shape[0], jnp.float32)
    converters = conv_mask.astype(jnp.float32)
    counts = _fused_scatter_add(
        leaf_indices, jnp.stack([every_obs, converters]), tree_size,
    )  # (2, tree_size)
    return counts[0], counts[1]
def _log_marginal_hurdle_ratio(
    sum_hr_left_conv: jax.Array,
    sum_hr_right_conv: jax.Array,
    sum_hr_total_conv: jax.Array,
    sum_hr_left_sev: jax.Array,
    sum_hr_right_sev: jax.Array,
    sum_hr_total_sev: jax.Array,
    prelkv_conv: HurdlePreLkv,
    prelkv_sev: HurdlePreLkv,
    sev_weight: float = 1.0,
) -> jax.Array:
    """O(1) per-tree joint log-lml ratio for a GROW move (split minus unsplit).

    Two-channel hurdle extension of bartz's ``compute_likelihood_ratio_uv``.
    Only the move's three leaves (left child, right child, merged parent) are
    touched, so the cost does not scale with ``tree_size``. Per channel::

        log_sqrt_term + exp_factor * (hr_L²/T_L + hr_R²/T_R - hr_T²/T_T)

    The result is ``conv + sev_weight * sev``. For a PRUNE move (deep →
    shallow) the caller negates the result.

    The ``sum_hr_*`` inputs are per-tree values taken from the sequential
    body's residual sums at ``move.left`` / ``move.right`` and their sum. The
    ``T_*``, ``log_sqrt_term`` and ``exp_factor`` values come from the
    parallel-stage ``HurdlePreLkv`` factors.
    """
    def _channel(pre, hr_left, hr_right, hr_total):
        quad = (
            hr_left * hr_left / pre.T_left
            + hr_right * hr_right / pre.T_right
            - hr_total * hr_total / pre.T_total
        )
        return pre.log_sqrt_term + pre.exp_factor * quad

    conv_term = _channel(
        prelkv_conv, sum_hr_left_conv, sum_hr_right_conv, sum_hr_total_conv,
    )
    sev_term = _channel(
        prelkv_sev, sum_hr_left_sev, sum_hr_right_sev, sum_hr_total_sev,
    )
    return conv_term + sev_weight * sev_term
# ---------------------------------------------------------------------------
# Joint hurdle: sufficient statistics (scatter-add into leaf-indexed arrays)
# ---------------------------------------------------------------------------
def _hurdle_sufficient_stats(
    leaf_indices: jax.Array,
    resid: jax.Array,
    mask: jax.Array,
    basis: jax.Array,
    tree_size: int,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """Basis-weighted Normal-Normal sufficient statistics for every leaf.

    Only observations with ``mask`` set contribute. Per leaf:

        sum_hr = Σ h_i * r_i   (basis-weighted residual)
        sum_h2 = Σ h_i²        (basis precision)
        sum_r2 = Σ r_i²        (raw SSE, for the log-marginal)
        count  = Σ mask_i      (number of contributing observations)

    Parameters
    ----------
    leaf_indices : (n,) leaf each observation falls into
    resid : (n,) residuals for this channel
    mask : (n,) bool — contributing observations (all True for conversion,
        converter_mask for severity)
    basis : (n,) basis weights (1.0 for the mu forest, b0/b1 for the tau forest)
    tree_size : total nodes in the tree heap

    Returns
    -------
    (sum_hr, sum_h2, sum_r2, count), each of shape (tree_size,)
    """
    weight = mask.astype(jnp.float32)
    basis_w = basis * weight
    resid_w = resid * weight
    per_obs = jnp.stack([
        basis_w * resid,
        basis_w * basis,
        resid_w * resid,
        weight,
    ])  # (4, n)
    leaf_sums = _fused_scatter_add(leaf_indices, per_obs, tree_size)  # (4, tree_size)
    sum_hr = leaf_sums[0]
    sum_h2 = leaf_sums[1]
    sum_r2 = leaf_sums[2]
    count = leaf_sums[3]
    return sum_hr, sum_h2, sum_r2, count
# ---------------------------------------------------------------------------
# Gamma prior moment-matching (Linero et al. 2020)
# ---------------------------------------------------------------------------
def _solve_gamma_prior(
    num_trees_total: int,
    var_tau: float = 0.5,
) -> tuple[float, float]:
    """Solve for ``(a_gamma, b_gamma)`` via Linero's moment-matching.

    The per-tree ``Gamma(a, b)`` multiplier (shape ``a``, rate ``b``) is
    matched so that::

        trigamma(a) = var_tau / num_trees_total   (Var[log gamma] per tree)
        b = exp(digamma(a))                       (E[log gamma] = 0)

    trigamma is strictly decreasing on ``(0, inf)`` with range
    ``(0, inf)``, so a unique root exists for every positive, finite target.
    The root is found with ``brentq`` on ``f(a) = trigamma(a) - target``
    over a bracket derived from the target itself (see below). A fixed
    bracket would only cover a limited range of targets.

    Parameters
    ----------
    num_trees_total : int
        Total number of trees sharing the variance budget; must be positive.
    var_tau : float, default 0.5
        Variance budget, split evenly across the trees.

    Returns
    -------
    (a_gamma, b_gamma) : tuple of float

    Raises
    ------
    ValueError
        If ``num_trees_total`` is not positive, or if
        ``var_tau / num_trees_total`` is not positive and finite.
    """
    if not num_trees_total > 0:
        raise ValueError(
            f"num_trees_total must be positive, got {num_trees_total!r}"
        )
    target_var = var_tau / num_trees_total
    if not (np.isfinite(target_var) and target_var > 0):
        raise ValueError(
            "var_tau / num_trees_total must be positive and finite, got "
            f"var_tau={var_tau!r}, num_trees_total={num_trees_total!r}"
        )

    def f(a):
        return float(polygamma(1, a) - target_var)

    # For a > 0: 1/a + 1/(2 a^2) < trigamma(a) < 1/a + 1/a^2.
    # * At a_lo = 0.5/target: trigamma(a_lo) > 1/a_lo = 2*target, so f(a_lo) > 0.
    # * At a_hi = twice the positive root of 1/a + 1/a^2 = target:
    #   trigamma(a_hi) < target, so f(a_hi) < 0.
    # The old fixed bracket [1e-4, 1e6] only covered targets of roughly
    # [1e-6, 1e8], e.g. it failed for very large forests.
    a_lo = 0.5 / target_var
    a_hi = (1.0 + np.sqrt(1.0 + 4.0 * target_var)) / target_var
    # brentq is overloaded float|tuple based on full_output; cast for pyright.
    a_gamma = float(cast(float, brentq(f, a_lo, a_hi)))
    b_gamma = float(np.exp(digamma(a_gamma)))
    return a_gamma, b_gamma
# ---------------------------------------------------------------------------
# Joint hurdle: precision-weighted sufficient statistics (Normal-Gamma)
# ---------------------------------------------------------------------------
def _hurdle_sufficient_stats_ng(
    leaf_indices: jax.Array,
    resid: jax.Array,
    mask: jax.Array,
    basis: jax.Array,
    precision: jax.Array,
    tree_size: int,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    """Precision-weighted Normal-Gamma sufficient statistics for every leaf.

    With ``v_i = tau_0 * tau_hat_i`` the per-observation precision (the
    ``precision`` argument), and only ``mask``-selected observations
    contributing, each leaf gets:

        sum_vr   = Σ v_i * h_i * r_i   precision-weighted basis × residual
        sum_vh2  = Σ v_i * h_i²        precision-weighted basis²
        sum_vr2  = Σ v_i * r_i²        precision-weighted residual² (SSE)
        sum_logv = Σ log(v_i)          log-precision normaliser
        count    = Σ mask_i

    ``precision`` is floored at 1e-30 before the log. All five statistics
    come from one 5-channel ``_fused_scatter_add`` call.
    """
    weight = mask.astype(jnp.float32)
    v_w = precision * weight
    basis_w = basis * weight
    resid_w = resid * weight
    log_precision = jnp.log(jnp.maximum(precision, 1e-30))
    per_obs = jnp.stack([
        v_w * basis_w * resid,      # sum_vr
        v_w * basis_w * basis,      # sum_vh2 (mask already carried by basis_w)
        v_w * resid_w * resid,      # sum_vr2
        log_precision * weight,     # sum_logv
        weight,                     # count
    ])  # (5, n)
    leaf_sums = _fused_scatter_add(leaf_indices, per_obs, tree_size)  # (5, tree_size)
    return leaf_sums[0], leaf_sums[1], leaf_sums[2], leaf_sums[3], leaf_sums[4]
# ---------------------------------------------------------------------------
# Joint hurdle: log-marginal likelihood
# ---------------------------------------------------------------------------
def _log_marginal_leaves_ng(
    sum_vr: jax.Array,
    sum_vh2: jax.Array,
    sum_vr2: jax.Array,
    sum_logv: jax.Array,
    count: jax.Array,
    a_tau: jax.Array,
    b_tau: jax.Array,
    kappa: jax.Array,
) -> jax.Array:
    """Per-leaf log-marginal likelihood under Normal-Gamma conjugacy.

    Model at one leaf, where ``v_i = tau_0 * tau_hat_i`` is the known
    per-observation precision multiplier and ``gamma`` the leaf's unknown
    precision scale::

        r_i | beta, gamma ~ N(h_i * beta, 1 / (gamma * v_i))
        beta | gamma      ~ N(0, 1 / (kappa * gamma))
        gamma             ~ Gamma(a_tau, b_tau)          (shape, rate)

    ``gamma`` scales the likelihood precision as well as the prior; that is
    what makes ``a_post`` grow with ``count``.

    Integrating out ``beta`` and ``gamma`` gives, with ``w = sum_vh2``::

        kappa_post = kappa + w
        a_post     = a_tau + count / 2
        b_post     = b_tau + (sum_vr2 - sum_vr**2 / kappa_post) / 2
        lml        = sum_logv / 2 - count / 2 * log(2 pi)
                     + log(kappa / kappa_post) / 2
                     + lgamma(a_post) - a_post * log(b_post)
                     + a_tau * log(b_tau) - lgamma(a_tau)

    ``b_post`` is evaluated in the algebraically equivalent form
    ``b_tau + SSE / 2 + kappa * w * Y_bar**2 / (2 * kappa_post)``, with
    ``Y_bar = sum_vr / w`` and ``SSE = sum_vr2 - w * Y_bar**2``.

    Parameters
    ----------
    sum_vr, sum_vh2, sum_vr2, sum_logv, count
        Per-leaf statistics as produced by ``_hurdle_sufficient_stats_ng``.
    a_tau, b_tau
        Shape and rate of the Gamma prior on ``gamma``.
    kappa
        Prior precision ratio for ``beta``.

    Returns
    -------
    Per-leaf log-marginal (same shape as ``count``), 0.0 where
    ``count == 0``. Sum over active leaves for the tree total.
    """
    w = sum_vh2
    # Floors keep the divisions and logs finite for empty / degenerate leaves.
    Y_bar = sum_vr / jnp.maximum(w, 1e-10)
    SSE = jnp.maximum(sum_vr2 - w * Y_bar ** 2, 0.0)
    kappa_post = kappa + w
    kappa_post_safe = jnp.maximum(kappa_post, 1e-10)
    a_post = a_tau + 0.5 * count
    b_post = b_tau + 0.5 * SSE + 0.5 * kappa * w * Y_bar ** 2 / kappa_post_safe
    b_post = jnp.maximum(b_post, 1e-10)
    lml = (0.5 * sum_logv
           - count * 0.5 * jnp.log(2 * jnp.pi)
           + 0.5 * jnp.log(kappa / kappa_post_safe)
           + jax.lax.lgamma(a_post)
           - a_post * jnp.log(b_post)
           + a_tau * jnp.log(b_tau) - jax.lax.lgamma(a_tau))
    return jnp.where(count > 0, lml, 0.0)
def _log_marginal_leaves(
    sum_hr: jax.Array,
    sum_h2: jax.Array,
    sum_r2: jax.Array,
    count: jax.Array,
    sigma2: jax.Array,
    tau_prior: jax.Array,
) -> jax.Array:
    """Per-leaf log-marginal likelihood under Normal-Normal conjugacy.

    Model: ``r_i | theta ~ N(h_i * theta, sigma2)`` with prior
    ``theta ~ N(0, tau_prior)``. Integrating ``theta`` out gives::

        precision = sum_h2 / sigma2 + 1 / tau_prior
        lml = -0.5 * log(precision * tau_prior)
              - sum_r2 / (2 * sigma2)
              + sum_hr**2 / (2 * sigma2**2 * precision)
              - count / 2 * log(2 * pi * sigma2)

    Returns one value per leaf; sum over active leaves for the tree total.
    Leaves with ``count == 0`` contribute exactly 0.
    """
    post_precision = sum_h2 / sigma2 + 1.0 / tau_prior
    lml = (
        -0.5 * jnp.log(post_precision * tau_prior)
        + (-sum_r2 / (2 * sigma2)
           + sum_hr ** 2 / (2 * sigma2 ** 2 * post_precision))
        + (-count / 2 * jnp.log(2 * jnp.pi * sigma2))
    )
    # Empty leaves carry no data, so they contribute nothing.
    has_data = count > 0
    return jnp.where(has_data, lml, 0.0)
def _log_marginal_leaves_mv(
    sum_hr: jax.Array,
    sum_h2: jax.Array,
    sum_r2: jax.Array,
    count: jax.Array,
    sigma2: jax.Array,
    tau_prior: jax.Array,
) -> jax.Array:
    """Per-leaf log-marginal likelihood for multivariate Normal-Normal conjugacy.

    ``(K-1)``-dimensional analogue of ``_log_marginal_leaves``. At each leaf::

        r_i | theta ~ N(h_i^T theta, sigma2)       (scalar outcome)
        theta       ~ N(0, tau_prior * I_{K-1})    (isotropic prior)

    where ``h_i`` is the ``(K-1,)`` reference-contrast basis row of
    observation ``i`` and ``theta`` is the leaf-level CATE vector.
    Integrating ``theta`` out gives::

        P   = sum_h2 / sigma2 + I / tau_prior
        b   = sum_hr / sigma2
        lml = -0.5 * logdet(P) - 0.5 * (K-1) * log(tau_prior)
              - 0.5 * sum_r2 / sigma2 + 0.5 * b^T P^{-1} b
              - 0.5 * count * log(2 * pi * sigma2)

    For K-1 = 1 this is exactly the scalar ``_log_marginal_leaves`` formula.
    Arbitrary leading batch dimensions broadcast through the batched
    ``slogdet`` / ``solve``, so single-leaf and ``(L, ...)`` inputs both work
    without vmap.

    Parameters
    ----------
    sum_hr : (..., K-1)
        Per-leaf basis-weighted residual sum vectors.
    sum_h2 : (..., K-1, K-1)
        Per-leaf sums of outer products of basis rows.
    sum_r2 : (...,)
        Per-leaf sums of squared residuals.
    count : (...,)
        Per-leaf observation counts (float; used for the empty-leaf guard).
    sigma2 : scalar
        Error variance (1.0 for the probit/conversion channel).
    tau_prior : scalar
        Isotropic leaf-mean prior variance.

    Returns
    -------
    (...,)
        Per-leaf log-marginal; 0.0 where ``count == 0``.

    Reference
    ---------
    multi-arm-joint-hurdle-bcf change, batch 5a spec.
    """
    k1 = sum_hr.shape[-1]
    precision = sum_h2 / sigma2 + (1.0 / tau_prior) * jnp.eye(k1)
    scaled_hr = sum_hr / sigma2
    _, log_det = jnp.linalg.slogdet(precision)
    solved = jnp.linalg.solve(precision, scaled_hr[..., None])[..., 0]
    quad_form = jnp.sum(scaled_hr * solved, axis=-1)
    lml = (
        -0.5 * log_det
        - 0.5 * k1 * jnp.log(tau_prior)
        - 0.5 * sum_r2 / sigma2
        + 0.5 * quad_form
        - 0.5 * count * jnp.log(2 * jnp.pi * sigma2)
    )
    return jnp.where(count > 0, lml, 0.0)
def _log_marginal_hurdle_tree(
    leaf_indices: jax.Array,
    resid_conv: jax.Array,
    resid_sev: jax.Array,
    conv_mask: jax.Array,
    basis: jax.Array,
    sigma2_sev: jax.Array,
    tau_conv: jax.Array,
    tau_sev: jax.Array,
    tree_size: int,
    n: int,
    sev_weight: float = 1.0,
) -> jax.Array:
    """Total joint log-marginal likelihood for a hurdle tree (NN severity).

    Returns ``sum(lml_conv) + sev_weight * sum(lml_sev)`` over the tree's
    leaves, where:

    - conversion is Normal-Normal over all visitors with unit error variance
      (probit) and leaf prior variance ``tau_conv``;
    - severity is Normal-Normal over converters only (``conv_mask``), with
      error variance ``sigma2_sev`` and leaf prior variance ``tau_sev``.

    The default ``sev_weight=1.0`` gives the unweighted joint log-marginal.
    Other values apply the same severity weighting as
    ``_log_marginal_hurdle_tree_ng`` and ``_log_marginal_hurdle_ratio_nn``.

    Parameters
    ----------
    n : int
        Number of observations; used to build the all-visitors mask, so it
        should equal ``leaf_indices.shape[0]``.
    sev_weight : float, default 1.0
        Multiplier applied to the severity-channel log-marginal.
    """
    all_true = jnp.ones(n, dtype=bool)
    # Conversion: all visitors contribute, known variance = 1.0 (probit).
    ss_conv = _hurdle_sufficient_stats(
        leaf_indices, resid_conv, all_true, basis, tree_size,
    )
    lml_conv = _log_marginal_leaves(*ss_conv, jnp.float32(1.0), tau_conv)
    # Severity: only converters contribute.
    ss_sev = _hurdle_sufficient_stats(
        leaf_indices, resid_sev, conv_mask, basis, tree_size,
    )
    lml_sev = _log_marginal_leaves(*ss_sev, sigma2_sev, tau_sev)
    return jnp.sum(lml_conv) + sev_weight * jnp.sum(lml_sev)
def _log_marginal_hurdle_tree_ng(
    leaf_indices: jax.Array,
    resid_conv: jax.Array,
    resid_sev: jax.Array,
    conv_mask: jax.Array,
    basis: jax.Array,
    precision: jax.Array,
    tau_conv: jax.Array,
    a_tau: jax.Array,
    b_tau: jax.Array,
    kappa: jax.Array,
    tree_size: int,
    n: int,
    sev_weight: float = 1.0,
) -> jax.Array:
    """Joint log-marginal for a hurdle tree with Normal-Gamma severity.

    - Conversion channel: Normal-Normal over all ``n`` visitors with
      probit unit variance.
    - Severity channel: Normal-Gamma over converters, weighted by the
      per-observation ``precision``.

    Returns ``sum(lml_conv) + sev_weight * sum(lml_sev)``. ``sev_weight``
    up-weights the severity term in split decisions (1.0 = unweighted).
    """
    # Conversion: every visitor contributes; probit => unit error variance.
    conv_stats = _hurdle_sufficient_stats(
        leaf_indices, resid_conv, jnp.ones(n, dtype=bool), basis, tree_size,
    )
    conv_total = jnp.sum(
        _log_marginal_leaves(*conv_stats, jnp.float32(1.0), tau_conv)
    )
    # Severity: converters only, precision-weighted Normal-Gamma.
    sev_stats = _hurdle_sufficient_stats_ng(
        leaf_indices, resid_sev, conv_mask, basis, precision, tree_size,
    )
    sev_total = jnp.sum(_log_marginal_leaves_ng(*sev_stats, a_tau, b_tau, kappa))
    return conv_total + sev_weight * sev_total
def _hurdle_sufficient_stats_joint(
    leaf_indices: jax.Array,
    resid_conv: jax.Array,
    resid_sev: jax.Array,
    conv_mask: jax.Array,
    basis: jax.Array,
    precision: jax.Array,
    tree_size: int,
) -> jax.Array:
    """Conversion + severity sufficient statistics per leaf in one scatter-add.

    Combines the 4 conversion statistics of ``_hurdle_sufficient_stats`` and
    the 5 severity statistics of ``_hurdle_sufficient_stats_ng`` into nine
    channels. All nine are scattered by a single ``_fused_scatter_add`` call,
    halving the launches compared with calling the two helpers separately.
    Returns the raw ``(9, tree_size)`` array; rows are:

    Conversion (Normal-Normal, all visitors):
        0: sum_hr_c   = Σ h_i r_conv_i
        1: sum_h2_c   = Σ h_i²
        2: sum_r2_c   = Σ r_conv_i²
        3: count_c    = number of observations in the leaf
    Severity (Normal-Gamma, converters only, v_i = precision_i * mask_i):
        4: sum_vr_s   = Σ v_i h_i r_sev_i
        5: sum_vh2_s  = Σ v_i h_i²
        6: sum_vr2_s  = Σ v_i r_sev_i²
        7: sum_logv_s = Σ mask_i * log(max(precision_i, 1e-30))
        8: count_s    = Σ mask_i
    """
    converter = conv_mask.astype(jnp.float32)
    weight = precision * converter            # zero for non-converters
    basis_sev = basis * converter
    resid_sev_w = resid_sev * converter
    channels = [
        # conversion channel (every visitor)
        basis * resid_conv,
        basis * basis,
        resid_conv * resid_conv,
        jnp.ones(leaf_indices.shape[0], jnp.float32),
        # severity channel (converters only)
        weight * basis_sev * resid_sev,
        weight * basis_sev * basis,
        weight * resid_sev_w * resid_sev,
        jnp.log(jnp.maximum(precision, 1e-30)) * converter,
        converter,
    ]
    return _fused_scatter_add(leaf_indices, jnp.stack(channels), tree_size)
def _hurdle_sufficient_stats_joint_nn(
    leaf_indices: jax.Array,
    resid_conv: jax.Array,
    resid_sev: jax.Array,
    conv_mask: jax.Array,
    basis: jax.Array,
    tree_size: int,
) -> jax.Array:
    """Conversion + severity Normal-Normal statistics per leaf in one scatter-add.

    Used on the ``per_leaf_gamma=False`` path. There the severity channel is
    Normal-Normal with sigma² absorbed into the global ``tau_0`` and has no
    per-observation precision weighting, so both channels' four statistics
    share one 8-channel ``_fused_scatter_add``. Returns ``(8, tree_size)``:

    Conversion (all visitors):
        0: sum_hr_c = Σ h_i r_conv_i
        1: sum_h2_c = Σ h_i²
        2: sum_r2_c = Σ r_conv_i²
        3: count_c  = number of observations in the leaf
    Severity (converters only):
        4: sum_hr_s = Σ_{i: conv_i} h_i r_sev_i
        5: sum_h2_s = Σ_{i: conv_i} h_i²
        6: sum_r2_s = Σ_{i: conv_i} r_sev_i²
        7: count_s  = Σ conv_mask_i
    """
    converter = conv_mask.astype(jnp.float32)
    basis_sev = basis * converter
    resid_sev_w = resid_sev * converter
    channels = [
        # conversion channel (every visitor)
        basis * resid_conv,
        basis * basis,
        resid_conv * resid_conv,
        jnp.ones(leaf_indices.shape[0], jnp.float32),
        # severity channel (converters only)
        basis_sev * resid_sev,
        basis_sev * basis,
        resid_sev_w * resid_sev,
        converter,
    ]
    return _fused_scatter_add(leaf_indices, jnp.stack(channels), tree_size)
def _log_marginal_hurdle_ratio_nn(
    old_li: jax.Array,
    prop_li: jax.Array,
    resid_conv: jax.Array,
    resid_sev: jax.Array,
    conv_mask: jax.Array,
    basis: jax.Array,
    sigma2_sev: jax.Array,
    tau_conv: jax.Array,
    tau_sev: jax.Array,
    tree_size: int,
    sev_weight: float = 1.0,
) -> jax.Array:
    """Log marginal-likelihood ratio ``lml(prop) - lml(old)`` with NN severity.

    Fast path for ``per_leaf_gamma=False``. Severity is Normal-Normal at the
    fixed scale ``sigma2_sev = 1/tau_0``, with no per-leaf gamma multiplier
    and no per-observation precision weighting. Each leaf assignment needs
    one fused 8-channel scatter-add.

    The result is ``Δconv + sev_weight * Δsev``. With ``sev_weight == 1.0``
    it equals::

        _log_marginal_hurdle_tree(prop_li, ...)
            - _log_marginal_hurdle_tree(old_li, ...)
    """
    stats_old = _hurdle_sufficient_stats_joint_nn(
        old_li, resid_conv, resid_sev, conv_mask, basis, tree_size,
    )
    stats_prop = _hurdle_sufficient_stats_joint_nn(
        prop_li, resid_conv, resid_sev, conv_mask, basis, tree_size,
    )

    def _channel_total(stats, first, sigma2, tau_prior):
        # Rows first..first+3 are (sum_hr, sum_h2, sum_r2, count) for a channel.
        return jnp.sum(_log_marginal_leaves(
            stats[first], stats[first + 1], stats[first + 2], stats[first + 3],
            sigma2, tau_prior,
        ))

    unit_variance = jnp.float32(1.0)  # probit conversion channel
    delta_conv = (
        _channel_total(stats_prop, 0, unit_variance, tau_conv)
        - _channel_total(stats_old, 0, unit_variance, tau_conv)
    )
    delta_sev = (
        _channel_total(stats_prop, 4, sigma2_sev, tau_sev)
        - _channel_total(stats_old, 4, sigma2_sev, tau_sev)
    )
    return delta_conv + sev_weight * delta_sev
def _hurdle_move_resids_and_counts(
    leaf_indices_deep: jax.Array,
    resid_conv: jax.Array,
    resid_sev: jax.Array,
    conv_mask: jax.Array,
    basis: jax.Array,
    tree_size: int,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """Per-leaf basis-weighted residual sums and counts for one tree.

    One fused 4-channel scatter-add into ``(tree_size,)`` bins, keyed by the
    *deeper* leaf indices of the tree (post-grow / pre-prune). The bartz-style
    accept step reads only ``move.left`` / ``move.right`` from these arrays
    (O(1) per tree per channel) and forms ``total = left + right``.

    Returns
    -------
    sum_hr_conv, count_conv, sum_hr_sev, count_sev : (tree_size,) each
        Conversion-channel sums and counts cover all visitors. Severity-channel
        sums and counts cover converters only (mask = ``conv_mask``).
    """
    converter = conv_mask.astype(jnp.float32)
    basis_sev = basis * converter
    channels = [
        basis * resid_conv,                                 # sum_hr_conv
        jnp.ones(leaf_indices_deep.shape[0], jnp.float32),  # count_conv
        basis_sev * resid_sev,                              # sum_hr_sev
        converter,                                          # count_sev
    ]
    sums = _fused_scatter_add(leaf_indices_deep, jnp.stack(channels), tree_size)
    return sums[0], sums[1], sums[2], sums[3]
def _hurdle_move_log_lk_ratio_grow(
    sum_hr_left: jax.Array,
    sum_hr_right: jax.Array,
    sum_hr_total: jax.Array,
    count_left: jax.Array,
    count_right: jax.Array,
    count_total: jax.Array,
    sigma2: jax.Array,
    tau_prior: jax.Array,
    basis_sq: jax.Array,
) -> jax.Array:
    """O(1) Normal-Normal log-lml ratio for one channel of one tree's GROW move.

    Returns ``log p(r | split) - log p(r | unsplit)``. For a PRUNE move the
    caller negates the result (bartz pattern).

    Obtained by summing ``_log_marginal_leaves`` over the three involved
    leaves. The ``sum_r2`` and ``count * log(2 pi sigma2)`` terms cancel,
    because the parent's totals equal the sum of the children's. With
    ``T_x = sigma2 + count_x * basis_sq * tau_prior``::

        0.5 * log(sigma2 * T_total / (T_left * T_right))
        + tau_prior / (2 sigma2)
          * (hr_left²/T_left + hr_right²/T_right - hr_total²/T_total)

    Parameters
    ----------
    sum_hr_{left,right,total} : basis-weighted residual sums of the affected
        leaves.
    count_{left,right,total} : number of contributing observations.
    sigma2 : error variance (1.0 for the conversion / probit channel,
        ``1/tau_0`` for severity on the per_leaf_gamma=False path).
    tau_prior : leaf-mean prior variance.
    basis_sq : per-observation ``basis²``, assumed uniform over contributing
        observations (true for ±0.5 tau / 1.0 mu bases), so
        ``Σ h_i² = count × basis_sq``.
    """
    def _denom(count):
        return sigma2 + count * basis_sq * tau_prior

    t_left = _denom(count_left)
    t_right = _denom(count_right)
    t_total = _denom(count_total)
    floor = 1e-30  # keep logs / divisions finite

    det_term = 0.5 * jnp.log(sigma2 * t_total / jnp.maximum(t_left * t_right, floor))
    quad = (
        sum_hr_left * sum_hr_left / jnp.maximum(t_left, floor)
        + sum_hr_right * sum_hr_right / jnp.maximum(t_right, floor)
        - sum_hr_total * sum_hr_total / jnp.maximum(t_total, floor)
    )
    return det_term + (tau_prior / (2.0 * sigma2)) * quad
def _mv_leaf_draw(
sum_hr: jax.Array,
sum_h2: jax.Array,
sigma2: jax.Array,
tau_prior: jax.Array,
noise: jax.Array,
) -> jax.Array:
"""Conjugate posterior mean + correlated MV-normal draw for one leaf.
Per-leaf Bayesian update for the multivariate Normal-Normal model:
r_i | theta ~ N(h_i^T theta, sigma2)
theta ~ N(0, tau_prior * I_{K-1})
The posterior is N(mean, P^{-1}) where:
P = sum_h2 / sigma2 + (1/tau_prior) * I_{K-1}
mean = P^{-1} @ (sum_hr / sigma2)
A sample is generated as:
draw = mean + M @ noise
where ``M = chol(P^{-1})`` (lower-triangular Cholesky factor of the
posterior covariance), so ``cov(draw) = M M^T = P^{-1}``.
At K-1=1 the draw reduces to the scalar ``mean + noise / sqrt(P)``.
Broadcasts over arbitrary leading batch dimensions (batched
``jnp.linalg.cholesky`` and ``jnp.linalg.solve`` handle leading dims).
The caller is responsible for zeroing draws for empty leaves.
Parameters
----------
sum_hr : (..., K-1)
Per-leaf basis-weighted residual sum vectors.
sum_h2 : (..., K-1, K-1)
Per-leaf sum of outer products of basis rows.
sigma2 : scalar
Error variance (1.0 for the probit/conversion channel).
tau_prior : scalar
Isotropic leaf-mean prior variance.
noise : (..., K-1)
Standard normal noise for each leaf (same shape as ``sum_hr``).
Returns
-------
(..., K-1)
Posterior draw for each leaf.
Reference
---------
multi-arm-joint-hurdle-bcf change, batch 5a spec.
"""
K1 = sum_hr.shape[-1]
P = sum_h2 / sigma2 + (1.0 / tau_prior) * jnp.eye(K1)
b = sum_hr / sigma2
mean = jnp.linalg.solve(P, b[..., None])[..., 0] # P^{-1} (sum_hr/sigma2)
M = jnp.linalg.cholesky(jnp.linalg.inv(P)) # lower; M M^T = P^{-1}
draw = mean + (M @ noise[..., None])[..., 0] # cov(draw) = M M^T = P^{-1}
return draw
def _mv_log_lk_ratio_grow(
sum_hr_left: jax.Array,
sum_h2_left: jax.Array,
sum_hr_right: jax.Array,
sum_h2_right: jax.Array,
sigma2: jax.Array,
tau_prior: jax.Array,
) -> jax.Array:
"""O(1) multivariate Normal-Normal log-lml ratio for one channel's GROW move.
Multivariate analogue of ``_hurdle_move_log_lk_ratio_grow`` for a
``(K-1)``-dimensional treatment-effect vector. Returns:
log p(r | split) - log p(r | unsplit)
computed in O(1) from the left- and right-child sufficient statistics
without re-evaluating all leaves. For PRUNE moves negate externally
(same pattern as the scalar analogue).
The derivation uses:
M(h2) = tau_prior * h2 + sigma2 * I_{K-1}
with the sum_r2 and count*log(2π σ²) terms cancelling in the split
difference (left SSE + right SSE = total SSE identically):
logdet_diff = 0.5 * (logdet(M_T) - logdet(M_L) - logdet(M_R))
+ 0.5 * (K-1) * log(sigma2)
datafit_diff = 0.5 * (tau_prior/sigma2) * (
hr_L^T M_L^{-1} hr_L
+ hr_R^T M_R^{-1} hr_R
- hr_T^T M_T^{-1} hr_T
)
where ``hr_T = hr_L + hr_R``, ``h2_T = h2_L + h2_R``.
At K-1=1 this is algebraically identical to the scalar
``_hurdle_move_log_lk_ratio_grow`` with ``basis_sq = sum_h2 / count``.
Broadcasts over leading batch dimensions; works for both single-leaf
``(K-1,)``/``(K-1, K-1)`` and batched ``(L, K-1)``/``(L, K-1, K-1)``
inputs without vmap.
Parameters
----------
sum_hr_left : (..., K-1)
Basis-weighted residual sum for the left child leaf.
sum_h2_left : (..., K-1, K-1)
Sum of outer products for the left child leaf.
sum_hr_right : (..., K-1)
Basis-weighted residual sum for the right child leaf.
sum_h2_right : (..., K-1, K-1)
Sum of outer products for the right child leaf.
sigma2 : scalar
Error variance (1.0 for the probit/conversion channel).
tau_prior : scalar
Isotropic leaf-mean prior variance.
Returns
-------
(...,)
Log-likelihood ratio for the grow move.
Reference
---------
multi-arm-joint-hurdle-bcf change, batch 5a spec.
"""
K1 = sum_hr_left.shape[-1]
sum_hr_total = sum_hr_left + sum_hr_right
sum_h2_total = sum_h2_left + sum_h2_right
def _M(h2: jax.Array) -> jax.Array:
return tau_prior * h2 + sigma2 * jnp.eye(K1)
def _logdet(h2: jax.Array) -> jax.Array:
_, ld = jnp.linalg.slogdet(_M(h2))
return ld
def _quad(hr: jax.Array, h2: jax.Array) -> jax.Array:
x = jnp.linalg.solve(_M(h2), hr[..., None])[..., 0]
return jnp.sum(hr * x, axis=-1)
logdet_diff = (
0.5 * (_logdet(sum_h2_total) - _logdet(sum_h2_left) - _logdet(sum_h2_right))
+ 0.5 * K1 * jnp.log(sigma2)
)
datafit_diff = 0.5 * (tau_prior / sigma2) * (
_quad(sum_hr_left, sum_h2_left)
+ _quad(sum_hr_right, sum_h2_right)
- _quad(sum_hr_total, sum_h2_total)
)
return logdet_diff + datafit_diff
def _log_marginal_hurdle_ratio_ng(
    old_li: jax.Array,
    prop_li: jax.Array,
    resid_conv: jax.Array,
    resid_sev: jax.Array,
    conv_mask: jax.Array,
    basis: jax.Array,
    precision: jax.Array,
    tau_conv: jax.Array,
    a_tau: jax.Array,
    b_tau: jax.Array,
    kappa: jax.Array,
    tree_size: int,
    sev_weight: float = 1.0,
) -> jax.Array:
    """Log marginal-likelihood ratio between proposed and current tree structure.

    Mathematically equivalent to::

        _log_marginal_hurdle_tree_ng(prop_li, ...)
        - _log_marginal_hurdle_tree_ng(old_li, ...)

    but each tree is evaluated with one fused scatter-add
    (``_hurdle_sufficient_stats_joint``) covering both the conversion and
    severity channels (9 stats). That halves the per-tree scatter-add launch
    count compared with calling ``_log_marginal_hurdle_tree_ng`` twice.

    Per tree:

    * Conversion channel: stats 0-3 go to ``_log_marginal_leaves`` with unit
      error variance and prior variance ``tau_conv``.
    * Severity channel: stats 4-8 go to ``_log_marginal_leaves_ng`` with
      ``(a_tau, b_tau, kappa)``.

    Both are summed over leaves. The severity difference is scaled by
    ``sev_weight``.

    Returns a scalar log-ratio suitable for an MH accept step.
    """
    unit_sigma2 = jnp.float32(1.0)  # conversion channel is probit-scaled

    def _channel_lmls(leaf_idx: jax.Array) -> tuple[jax.Array, jax.Array]:
        # Total (conversion, severity) log-marginals for one tree structure.
        stats = _hurdle_sufficient_stats_joint(
            leaf_idx, resid_conv, resid_sev, conv_mask, basis, precision, tree_size,
        )
        conv = jnp.sum(_log_marginal_leaves(*stats[:4], unit_sigma2, tau_conv))
        sev = jnp.sum(_log_marginal_leaves_ng(*stats[4:9], a_tau, b_tau, kappa))
        return conv, sev

    conv_old, sev_old = _channel_lmls(old_li)
    conv_prop, sev_prop = _channel_lmls(prop_li)

    conv_delta = conv_prop - conv_old
    sev_delta = sev_prop - sev_old
    return conv_delta + sev_weight * sev_delta