Source code for pytyche.bcf.lml

"""Pure-math sufficient statistics and log-marginal likelihoods for the GPU BCF.

This module holds the closed-form pieces of the joint hurdle BCF posterior:
basis-weighted sufficient statistics aggregated per leaf, the matching
log-marginal likelihoods (Normal-Normal and Normal-Gamma conjugacy), and the
moment-matched gamma prior used to set ``(a_tau, b_tau)``.

It is "pure math" in the sense that every function here depends only on its
arguments and on JAX/numpy/scipy primitives — no module-level state, no GPU
device handles, no JIT-compiled kernels of its own.

Import graph
------------
This module imports ``_fused_scatter_add`` from ``pytyche.bcf.kernels`` and
nothing else from sibling GPU BCF modules. The orchestrator and the GFR
module import FROM here, never the other way around — see
``pytyche.bcf.kernels`` for the same convention applied at the bottom of the
import graph.

Contents
--------
``_solve_gamma_prior`` — Linero moment-matching for ``(a_gamma, b_gamma)``.
``_hurdle_sufficient_stats`` — basis-weighted (Normal-Normal) per-leaf stats.
``_hurdle_sufficient_stats_ng`` — precision-weighted (Normal-Gamma) per-leaf stats.
``_log_marginal_leaves`` — per-leaf log-marginal under Normal-Normal conjugacy.
``_log_marginal_leaves_ng`` — per-leaf log-marginal under Normal-Gamma conjugacy.
``_log_marginal_hurdle_tree`` — total joint log-marginal for a hurdle tree (NN severity).
``_log_marginal_hurdle_tree_ng`` — total joint log-marginal for a hurdle tree (NG severity).
"""

from __future__ import annotations

from typing import NamedTuple, cast

import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import brentq
from scipy.special import digamma, polygamma

from pytyche.bcf.kernels import _fused_scatter_add

# ---------------------------------------------------------------------------
# Per-iteration parallel-stage factors (mirrors bartz's PreLkV / PreLk / PreLf)
# ---------------------------------------------------------------------------


class HurdlePreLkv(NamedTuple):
    """Per-tree likelihood-ratio factors for one channel of the hurdle lml.

    Mirrors bartz's ``PreLkV`` Module as a JAX-pytree NamedTuple. All fields
    have shape ``(num_trees,)`` after the parallel pre-compute stage.

    Fields
    ------
    T_left, T_right, T_total
        Conjugate Normal-Normal denominators ``sigma2 + sum_h2 * tau_prior``
        evaluated at the move's left, right, and parent leaves.
    log_sqrt_term
        Log-determinant ratio
        ``0.5 * log(sigma2 * T_total / (T_left * T_right))``.
    exp_factor
        Shared scalar ``tau_prior / (2 * sigma2)`` that multiplies the
        residual quadratic in the O(1) ratio.
    """

    T_left: jax.Array
    T_right: jax.Array
    T_total: jax.Array
    log_sqrt_term: jax.Array
    exp_factor: jax.Array
class HurdlePreLf(NamedTuple):
    """Per-tree leaf-posterior factors for one channel.

    Mirrors bartz's ``PreLf``. For each tree's tree_size leaves::

        new_leaf_value[t, l] = mean_factor[t, l] * sum_resid_at_leaf[t, l]
                               + centered_leaves[t, l]

    Both fields zeroed where ``count_trees[t, l] == 0`` so empty leaves get
    zero alpha (matches the existing ``jnp.where(count > 0, ..., 0)``
    pattern).
    """

    mean_factor: jax.Array  # (num_trees, tree_size)
    centered_leaves: jax.Array  # (num_trees, tree_size)
def _precompute_lk_one_channel( count_left: jax.Array, count_right: jax.Array, sigma2: jax.Array, tau_prior: jax.Array, basis_sq: jax.Array, ) -> HurdlePreLkv: """Build a ``HurdlePreLkv`` for one channel from per-tree leaf counts. All inputs except ``sigma2``, ``tau_prior``, ``basis_sq`` are ``(num_trees,)``. ``basis_sq`` is the uniform per-observation ``basis²`` (1.0 for the mu forest, 0.25 for the tau forest's ±0.5 basis). """ sum_h2_L = count_left * basis_sq sum_h2_R = count_right * basis_sq sum_h2_T = sum_h2_L + sum_h2_R T_L = sigma2 + sum_h2_L * tau_prior T_R = sigma2 + sum_h2_R * tau_prior T_T = sigma2 + sum_h2_T * tau_prior log_sqrt_term = 0.5 * jnp.log(sigma2 * T_T / jnp.maximum(T_L * T_R, 1e-30)) exp_factor = tau_prior / (2.0 * sigma2) return HurdlePreLkv( T_left=T_L, T_right=T_R, T_total=T_T, log_sqrt_term=log_sqrt_term, exp_factor=exp_factor, ) def _precompute_leaf_one_channel( count_trees: jax.Array, sigma2: jax.Array, tau_prior: jax.Array, basis_sq: jax.Array, noise: jax.Array, ) -> HurdlePreLf: """Build a ``HurdlePreLf`` for one channel from per-tree per-leaf counts. The Normal-Normal posterior at a leaf with sufficient stat sum_hr and integrated count is:: precision_post = sum_h2 / sigma2 + 1 / tau_prior = T / (sigma2 * tau_prior) var_post = sigma2 * tau_prior / T mean_factor = var_post / sigma2 = tau_prior / T new_alpha = mean_factor * sum_hr + noise * sqrt(var_post) For uninhabited leaves (``count == 0``) we zero both factors so the sampled alpha is exactly 0, matching the existing ``jnp.where(count > 0, ..., 0.0)`` semantics. 
""" sum_h2 = count_trees * basis_sq T = sigma2 + sum_h2 * tau_prior T_safe = jnp.maximum(T, 1e-30) var_post = sigma2 * tau_prior / T_safe populated = count_trees > 0 mean_factor = jnp.where(populated, tau_prior / T_safe, 0.0) centered = jnp.where(populated, noise * jnp.sqrt(var_post), 0.0) return HurdlePreLf(mean_factor=mean_factor, centered_leaves=centered) def _precompute_likelihood_terms_hurdle( count_trees_conv: jax.Array, count_trees_sev: jax.Array, moves_left: jax.Array, moves_right: jax.Array, tau_conv: jax.Array, tau_sev: jax.Array, sigma2_sev: jax.Array, basis_sq: jax.Array, ) -> tuple[HurdlePreLkv, HurdlePreLkv]: """Pre-compute per-tree likelihood-ratio factors for both channels. Hoisted out of the sequential per-tree scan: T_left, T_right, T_total, log_sqrt_term, and exp_factor depend only on the proposal structure (counts at move.left / move.right) and the channel hyperparameters, not on residuals. Computed once per iteration in the parallel pre-compute stage; the per-tree sequential body then reads scalars per tree. Conversion channel: Normal-Normal with sigma² = 1.0. Severity channel: Normal-Normal with sigma² = 1/tau_0 (the ``per_leaf_gamma=False`` path). The Normal-Gamma severity path keeps its sequential dependency and does not consume these factors. Parameters ---------- count_trees_conv, count_trees_sev : (num_trees, tree_size) Leaf counts (channel-specific mask) under the deeper-of-old-vs-prop leaf-index assignment per tree. moves_left, moves_right : (num_trees,) int Heap indices of the two children involved in each tree's move. tau_conv, tau_sev : scalars — leaf-mean prior variances per channel. sigma2_sev : scalar — severity error variance ``1/tau_0`` for the NN path. basis_sq : scalar — uniform ``basis²`` (1.0 mu / 0.25 tau). 
""" sigma2_conv = jnp.float32(1.0) count_left_conv = jnp.take_along_axis( count_trees_conv, moves_left[:, None], axis=1, ).squeeze(1) count_right_conv = jnp.take_along_axis( count_trees_conv, moves_right[:, None], axis=1, ).squeeze(1) prelkv_conv = _precompute_lk_one_channel( count_left_conv, count_right_conv, sigma2_conv, tau_conv, basis_sq, ) count_left_sev = jnp.take_along_axis( count_trees_sev, moves_left[:, None], axis=1, ).squeeze(1) count_right_sev = jnp.take_along_axis( count_trees_sev, moves_right[:, None], axis=1, ).squeeze(1) prelkv_sev = _precompute_lk_one_channel( count_left_sev, count_right_sev, sigma2_sev, tau_sev, basis_sq, ) return prelkv_conv, prelkv_sev def _precompute_leaf_terms_hurdle( count_trees_conv_merged: jax.Array, count_trees_sev_merged: jax.Array, tau_conv: jax.Array, tau_sev: jax.Array, sigma2_sev: jax.Array, basis_sq: jax.Array, noise_alpha: jax.Array, noise_beta: jax.Array, ) -> tuple[HurdlePreLf, HurdlePreLf]: """Pre-compute per-tree per-leaf posterior factors for both channels. Operates on count_trees AFTER the per-tree merge at ``move.node``: the parent-leaf bin holds ``count_left + count_right`` so that leaf sampling at move.node uses the count_total posterior (matches the existing sequential-body behaviour at the rejected-grow / accepted-prune cases). Returns ``(prelf_alpha, prelf_beta)``. Factors at empty leaves are zeroed, so ``new_leaf = mean_factor * sum_hr + centered_leaves`` reduces to zero there, matching ``jnp.where(count > 0, ..., 0.0)`` semantics. 
""" sigma2_conv = jnp.float32(1.0) prelf_alpha = _precompute_leaf_one_channel( count_trees_conv_merged, sigma2_conv, tau_conv, basis_sq, noise_alpha, ) prelf_beta = _precompute_leaf_one_channel( count_trees_sev_merged, sigma2_sev, tau_sev, basis_sq, noise_beta, ) return prelf_alpha, prelf_beta def _hurdle_resid_sums( leaf_indices: jax.Array, resid_conv: jax.Array, resid_sev: jax.Array, conv_mask: jax.Array, basis: jax.Array, tree_size: int, ) -> tuple[jax.Array, jax.Array]: """Per-leaf basis-weighted residual sums for both channels (one tree). Single 2-channel fused scatter-add — the only residual-dependent operation left in the sequential per-tree body once ``_precompute_likelihood_terms_hurdle`` and ``_precompute_leaf_terms_hurdle`` have hoisted the structure-only factors into Stage 1. Returns ``(sum_hr_conv, sum_hr_sev)`` each of shape ``(tree_size,)``. Counts are NOT scattered here; they live in the parallel-stage ``count_trees_*`` arrays. """ mask_f = conv_mask.astype(jnp.float32) h_sev = basis * mask_f stack = jnp.stack([ basis * resid_conv, # sum_hr_conv (all visitors weighted by basis) h_sev * resid_sev, # sum_hr_sev (mask × basis × resid) ]) # (2, n) out = _fused_scatter_add(leaf_indices, stack, tree_size) return out[0], out[1] def _hurdle_resid_sums_mv( leaf_indices: jax.Array, resid_conv: jax.Array, resid_sev: jax.Array, conv_mask: jax.Array, basis: jax.Array, tree_size: int, ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: """Per-leaf first- and second-order statistics for the multivariate hurdle (one tree). Multivariate analogue of ``_hurdle_resid_sums`` for a ``(n, K-1)`` basis matrix. 
Per leaf L, accumulates four statistics over observations v in L: First-order ``(K-1,)`` per leaf: sum_hr_conv[L] = Σ_{v in L} basis_v * resid_conv_v sum_hr_sev[L] = Σ_{v in L} mask_v * basis_v * resid_sev_v Second-order ``(K-1, K-1)`` per leaf (outer products): sum_h2_conv[L] = Σ_{v in L} outer(basis_v, basis_v) sum_h2_sev[L] = Σ_{v in L} mask_v * outer(basis_v, basis_v) ``conv`` stats accumulate ALL visitors; ``sev`` stats accumulate converters only (``conv_mask == True``). ``K-1`` is derived from ``basis.shape[1]`` so the function is valid for any K ≥ 2 (including K=2, i.e. K-1=1). Parameters ---------- leaf_indices : (n,) int Leaf index for each observation. resid_conv : (n,) float32 Residuals for the conversion channel (all visitors). resid_sev : (n,) float32 Residuals for the severity channel (used only for converters). conv_mask : (n,) bool True for converting observations. basis : (n, K-1) float32 Reference-contrast basis matrix. tree_size : int Total nodes in the tree heap. Returns ------- sum_hr_conv : (tree_size, K-1) sum_hr_sev : (tree_size, K-1) sum_h2_conv : (tree_size, K-1, K-1) sum_h2_sev : (tree_size, K-1, K-1) """ mask_f = conv_mask.astype(jnp.float32) # (n,) # First-order stats: shape (tree_size, K-1) sum_hr_conv = ( jnp.zeros((tree_size, basis.shape[1]), dtype=jnp.float32) .at[leaf_indices] .add(basis * resid_conv[:, None]) ) sum_hr_sev = ( jnp.zeros((tree_size, basis.shape[1]), dtype=jnp.float32) .at[leaf_indices] .add(basis * (mask_f * resid_sev)[:, None]) ) # Second-order stats (outer products): shape (tree_size, K-1, K-1) sum_h2_conv = ( jnp.zeros((tree_size, basis.shape[1], basis.shape[1]), dtype=jnp.float32) .at[leaf_indices] .add(basis[:, :, None] * basis[:, None, :]) ) sum_h2_sev = ( jnp.zeros((tree_size, basis.shape[1], basis.shape[1]), dtype=jnp.float32) .at[leaf_indices] .add(mask_f[:, None, None] * basis[:, :, None] * basis[:, None, :]) ) return sum_hr_conv, sum_hr_sev, sum_h2_conv, sum_h2_sev def _hurdle_count_trees( leaf_indices: 
jax.Array, conv_mask: jax.Array, tree_size: int, ) -> tuple[jax.Array, jax.Array]: """Per-leaf observation counts, conv (all) and sev (converters), one tree. Structure-only — depends on leaf assignments and the converter mask but NOT on residuals. Lives in Stage 1's parallel pre-compute (vmap'd across trees) so the sequential body never recomputes it. Returns ``(count_conv, count_sev)`` each of shape ``(tree_size,)``. """ n = leaf_indices.shape[0] all_true_f = jnp.ones(n, jnp.float32) mask_f = conv_mask.astype(jnp.float32) stack = jnp.stack([all_true_f, mask_f]) # (2, n) out = _fused_scatter_add(leaf_indices, stack, tree_size) return out[0], out[1] def _log_marginal_hurdle_ratio( sum_hr_left_conv: jax.Array, sum_hr_right_conv: jax.Array, sum_hr_total_conv: jax.Array, sum_hr_left_sev: jax.Array, sum_hr_right_sev: jax.Array, sum_hr_total_sev: jax.Array, prelkv_conv: HurdlePreLkv, prelkv_sev: HurdlePreLkv, sev_weight: float = 1.0, ) -> jax.Array: """O(1) per-tree joint log-lml ratio (proposed minus current, GROW direction). Mirrors bartz's ``compute_likelihood_ratio_uv`` (``bartz-ref/src/bartz/mcmcstep/_step.py:1187-1199``) extended to the hurdle two-channel model. Operates only on the move's three involved leaves per tree per channel — no per-leaf sums over ``tree_size``. For PRUNE moves negate externally (``log_lml(deep) - log_lml(shallow)`` becomes ``-result`` for the shallow → deep direction). Inputs are scalars per tree (extracted from the sequential body's per-tree sum_hr arrays at ``move.left/right`` and their sum); prelkv values are also scalars per tree from the parallel pre-compute. 
""" log_lk_conv = prelkv_conv.log_sqrt_term + prelkv_conv.exp_factor * ( sum_hr_left_conv * sum_hr_left_conv / prelkv_conv.T_left + sum_hr_right_conv * sum_hr_right_conv / prelkv_conv.T_right - sum_hr_total_conv * sum_hr_total_conv / prelkv_conv.T_total ) log_lk_sev = prelkv_sev.log_sqrt_term + prelkv_sev.exp_factor * ( sum_hr_left_sev * sum_hr_left_sev / prelkv_sev.T_left + sum_hr_right_sev * sum_hr_right_sev / prelkv_sev.T_right - sum_hr_total_sev * sum_hr_total_sev / prelkv_sev.T_total ) return log_lk_conv + sev_weight * log_lk_sev # --------------------------------------------------------------------------- # Joint hurdle: sufficient statistics (scatter-add into leaf-indexed arrays) # --------------------------------------------------------------------------- def _hurdle_sufficient_stats( leaf_indices: jax.Array, resid: jax.Array, mask: jax.Array, basis: jax.Array, tree_size: int, ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: """Compute basis-weighted sufficient statistics per leaf node. 
For each leaf, accumulates: sum_hr = sum(h_i * r_i) — basis-weighted residual sum_h2 = sum(h_i^2) — basis precision sum_r2 = sum(r_i^2) — raw SSE (for log-marginal) count = sum(mask_i) — number of contributing observations Parameters ---------- leaf_indices : (n,) which leaf each observation falls into resid : (n,) residuals for this channel mask : (n,) bool — which observations contribute (all True for conversion, converter_mask for severity) basis : (n,) basis weights (1.0 for mu forest, b0/b1 for tau forest) tree_size : total nodes in the tree heap (2^d) Returns ------- (sum_hr, sum_h2, sum_r2, count) each of shape (tree_size,) """ mask_f = mask.astype(jnp.float32) r_masked = resid * mask_f h_masked = basis * mask_f stack = jnp.stack([ h_masked * resid, # sum_hr h_masked * basis, # sum_h2 r_masked * resid, # sum_r2 mask_f, # count ]) # (4, n) results = _fused_scatter_add(leaf_indices, stack, tree_size) # (4, tree_size) return results[0], results[1], results[2], results[3] # --------------------------------------------------------------------------- # Gamma prior moment-matching (Linero et al. 2020) # --------------------------------------------------------------------------- def _solve_gamma_prior( num_trees_total: int, var_tau: float = 0.5, ) -> tuple[float, float]: """Solve for (a_gamma, b_gamma) via Linero's moment-matching. Target: trigamma(a) = var_tau / num_trees_total Then: b = exp(digamma(a)) [ensures E[log(gamma)] = 0] Uses brentq on f(a) = trigamma(a) - target (monotone decreasing, unique root). """ target_var = var_tau / num_trees_total def f(a): return float(polygamma(1, a) - target_var) # brentq is overloaded float|tuple based on full_output; cast for pyright. 
a_gamma = float(cast(float, brentq(f, 1e-4, 1e6))) b_gamma = float(np.exp(digamma(a_gamma))) return a_gamma, b_gamma # --------------------------------------------------------------------------- # Joint hurdle: precision-weighted sufficient statistics (Normal-Gamma) # --------------------------------------------------------------------------- def _hurdle_sufficient_stats_ng( leaf_indices: jax.Array, resid: jax.Array, mask: jax.Array, basis: jax.Array, precision: jax.Array, tree_size: int, ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: """Precision-weighted sufficient statistics per leaf for Normal-Gamma model. For each leaf, accumulates: sum_vr = Σ(v_i * h_i * r_i) — precision-weighted basis×residual sum_vh2 = Σ(v_i * h_i²) — precision-weighted basis² sum_vr2 = Σ(v_i * r_i²) — precision-weighted residual² (SSE) sum_logv= Σ(log(v_i)) — log-precision (LogLT normalizer) count = Σ(mask_i) Where v_i = tau_0 * tau_hat_i (per-observation precision). """ mask_f = mask.astype(jnp.float32) v = precision * mask_f h = basis * mask_f r = resid * mask_f stack = jnp.stack([ v * h * resid, # sum_vr v * h * basis, # sum_vh2 (note: h already has mask, so v*h*basis = v*mask*h²) v * r * resid, # sum_vr2 jnp.log(jnp.maximum(precision, 1e-30)) * mask_f, # sum_logv mask_f, # count ]) # (5, n) results = _fused_scatter_add(leaf_indices, stack, tree_size) # (5, tree_size) return results[0], results[1], results[2], results[3], results[4] # --------------------------------------------------------------------------- # Joint hurdle: log-marginal likelihood # --------------------------------------------------------------------------- def _log_marginal_leaves_ng( sum_vr: jax.Array, sum_vh2: jax.Array, sum_vr2: jax.Array, sum_logv: jax.Array, count: jax.Array, a_tau: jax.Array, b_tau: jax.Array, kappa: jax.Array, ) -> jax.Array: """Per-leaf log-marginal likelihood for Normal-Gamma conjugacy. 
Model: r_i | beta, gamma ~ N(h_i * beta, 1 / (tau_0 * tau_hat_i)) Prior: beta | gamma ~ N(0, 1/(kappa * gamma)) gamma ~ Gamma(a_tau, b_tau) Returns per-leaf scalar; sum over active leaves for tree total. """ w = sum_vh2 Y_bar = sum_vr / jnp.maximum(w, 1e-10) SSE = jnp.maximum(sum_vr2 - w * Y_bar ** 2, 0.0) kappa_post = kappa + w a_post = a_tau + 0.5 * count b_post = b_tau + 0.5 * SSE + 0.5 * kappa * w * Y_bar ** 2 / jnp.maximum(kappa_post, 1e-10) b_post = jnp.maximum(b_post, 1e-10) lml = (0.5 * sum_logv - count * 0.5 * jnp.log(2 * jnp.pi) + 0.5 * jnp.log(kappa / jnp.maximum(kappa_post, 1e-10)) + jax.lax.lgamma(a_post) - a_post * jnp.log(b_post) + a_tau * jnp.log(b_tau) - jax.lax.lgamma(a_tau)) return jnp.where(count > 0, lml, 0.0) def _log_marginal_leaves( sum_hr: jax.Array, sum_h2: jax.Array, sum_r2: jax.Array, count: jax.Array, sigma2: jax.Array, tau_prior: jax.Array, ) -> jax.Array: """Per-leaf log-marginal likelihood for Normal-Normal conjugacy. Model: r_i | theta ~ N(h_i * theta, sigma2) Prior: theta ~ N(0, tau_prior) Returns a per-leaf scalar; sum over active leaves for the tree total. Uses the integrated marginal (theta analytically marginalized): precision = sum_h2 / sigma2 + 1/tau_prior lml = -0.5*log(precision*tau_prior) - sum_r2/(2*sigma2) + sum_hr^2 / (2*sigma2^2*precision) - count/2 * log(2*pi*sigma2) For leaves with count=0, returns 0 (no data, no contribution). 
""" precision = sum_h2 / sigma2 + 1.0 / tau_prior log_det = -0.5 * jnp.log(precision * tau_prior) data_fit = -sum_r2 / (2 * sigma2) + sum_hr ** 2 / (2 * sigma2 ** 2 * precision) normalizer = -count / 2 * jnp.log(2 * jnp.pi * sigma2) lml = log_det + data_fit + normalizer # Zero out leaves with no data (avoid NaN from log(0)) return jnp.where(count > 0, lml, 0.0) def _log_marginal_leaves_mv( sum_hr: jax.Array, sum_h2: jax.Array, sum_r2: jax.Array, count: jax.Array, sigma2: jax.Array, tau_prior: jax.Array, ) -> jax.Array: """Per-leaf log-marginal likelihood for multivariate Normal-Normal conjugacy. Multivariate analogue of ``_log_marginal_leaves`` for a ``(K-1)``-dimensional treatment-effect vector. The model at each leaf is: r_i | theta_ell ~ N(h_i^T theta_ell, sigma2) (scalar outcome) theta_ell ~ N(0, tau_prior * I_{K-1}) (isotropic prior) where ``h_i`` is the ``(K-1,)`` reference-contrast basis row for observation ``i`` and ``theta_ell`` is the leaf-level CATE vector. Closed-form integrated marginal (theta analytically marginalized): P = sum_h2 / sigma2 + (1/tau_prior) * I_{K-1} (..., K-1, K-1) b = sum_hr / sigma2 (..., K-1) lml = -0.5 * logdet(P) - 0.5 * (K-1) * log(tau_prior) - 0.5 * sum_r2 / sigma2 + 0.5 * b^T P^{-1} b - 0.5 * count * log(2*pi*sigma2) At K-1=1 this reduces exactly to the scalar ``_log_marginal_leaves`` formula: ``-0.5*log(P) - 0.5*log(tau_prior) = -0.5*log(P*tau_prior)``. Broadcasts over arbitrary leading batch dimensions so both single-leaf ``(K-1,)``/``(K-1, K-1)`` and batched ``(L, K-1)``/``(L, K-1, K-1)`` inputs work without vmap. For leaves with count=0, returns 0 (no data, no contribution). Parameters ---------- sum_hr : (..., K-1) Per-leaf basis-weighted residual sum vectors. sum_h2 : (..., K-1, K-1) Per-leaf sum of outer products of basis rows. sum_r2 : (...,) Per-leaf sum of squared residuals. count : (...,) Per-leaf observation count (float; zero-count guard). sigma2 : scalar Error variance (1.0 for the probit/conversion channel). 
tau_prior : scalar Isotropic leaf-mean prior variance. Returns ------- (...,) Per-leaf log-marginal; 0.0 where ``count == 0``. Reference --------- multi-arm-joint-hurdle-bcf change, batch 5a spec. """ K1 = sum_hr.shape[-1] P = sum_h2 / sigma2 + (1.0 / tau_prior) * jnp.eye(K1) # (..., K-1, K-1) b = sum_hr / sigma2 # (..., K-1) _, logdet_P = jnp.linalg.slogdet(P) # (...,) x = jnp.linalg.solve(P, b[..., None])[..., 0] # P^{-1} b, batched quad = jnp.sum(b * x, axis=-1) # b^T P^{-1} b lml = ( -0.5 * logdet_P - 0.5 * K1 * jnp.log(tau_prior) - 0.5 * sum_r2 / sigma2 + 0.5 * quad - 0.5 * count * jnp.log(2 * jnp.pi * sigma2) ) return jnp.where(count > 0, lml, 0.0) def _log_marginal_hurdle_tree( leaf_indices: jax.Array, resid_conv: jax.Array, resid_sev: jax.Array, conv_mask: jax.Array, basis: jax.Array, sigma2_sev: jax.Array, tau_conv: jax.Array, tau_sev: jax.Array, tree_size: int, n: int, ) -> jax.Array: """Total joint log-marginal likelihood for a hurdle tree. Sums lml_conv + lml_sev over all active leaves. """ all_true = jnp.ones(n, dtype=bool) # Conversion: all visitors contribute, known variance = 1.0 (probit) ss_conv = _hurdle_sufficient_stats( leaf_indices, resid_conv, all_true, basis, tree_size, ) lml_conv = _log_marginal_leaves(*ss_conv, jnp.float32(1.0), tau_conv) # Severity: only converters contribute ss_sev = _hurdle_sufficient_stats( leaf_indices, resid_sev, conv_mask, basis, tree_size, ) lml_sev = _log_marginal_leaves(*ss_sev, sigma2_sev, tau_sev) return jnp.sum(lml_conv) + jnp.sum(lml_sev) def _log_marginal_hurdle_tree_ng( leaf_indices: jax.Array, resid_conv: jax.Array, resid_sev: jax.Array, conv_mask: jax.Array, basis: jax.Array, precision: jax.Array, tau_conv: jax.Array, a_tau: jax.Array, b_tau: jax.Array, kappa: jax.Array, tree_size: int, n: int, sev_weight: float = 1.0, ) -> jax.Array: """Joint log-marginal for a hurdle tree with Normal-Gamma severity. Conversion channel: unchanged (Normal-Normal, probit sigma²=1). 
Severity channel: Normal-Gamma with per-observation precision. sev_weight: upweight severity LML in split decisions (1.0 = current). """ all_true = jnp.ones(n, dtype=bool) # Conversion: all visitors, known variance = 1.0 (probit) ss_conv = _hurdle_sufficient_stats( leaf_indices, resid_conv, all_true, basis, tree_size, ) lml_conv = _log_marginal_leaves(*ss_conv, jnp.float32(1.0), tau_conv) # Severity: Normal-Gamma, converters only, precision-weighted ss_sev = _hurdle_sufficient_stats_ng( leaf_indices, resid_sev, conv_mask, basis, precision, tree_size, ) lml_sev = _log_marginal_leaves_ng(*ss_sev, a_tau, b_tau, kappa) return jnp.sum(lml_conv) + sev_weight * jnp.sum(lml_sev) def _hurdle_sufficient_stats_joint( leaf_indices: jax.Array, resid_conv: jax.Array, resid_sev: jax.Array, conv_mask: jax.Array, basis: jax.Array, precision: jax.Array, tree_size: int, ) -> jax.Array: """Fused conv+sev sufficient stats per leaf with a single scatter-add. Fuses ``_hurdle_sufficient_stats`` (4 conv stats) and ``_hurdle_sufficient_stats_ng`` (5 sev stats) into one stack of nine channels scattered with a single ``_fused_scatter_add`` call. Returns the raw `(9, tree_size)` array so callers can index by stat: Index 0..3 — conversion (Normal-Normal, all visitors): 0: sum_hr_c = Σ h_i r_conv_i 1: sum_h2_c = Σ h_i² 2: sum_r2_c = Σ r_conv_i² 3: count_c = n (all visitors contribute) Index 4..8 — severity (Normal-Gamma, converters only, precision-weighted): 4: sum_vr_s = Σ v_i h_i r_sev_i (v_i = precision_i * mask_i) 5: sum_vh2_s = Σ v_i h_i² 6: sum_vr2_s = Σ v_i r_sev_i² 7: sum_logv_s = Σ log(v_i) (effective: log(precision_i)*mask_i) 8: count_s = Σ mask_i Halves the scatter-add launch count per tree per move evaluation relative to the dual ``_hurdle_sufficient_stats*`` calls. 
""" n = leaf_indices.shape[0] all_true_f = jnp.ones(n, jnp.float32) mask_f = conv_mask.astype(jnp.float32) v = precision * mask_f # zero on non-converters by mask h_sev = basis * mask_f r_sev_masked = resid_sev * mask_f log_v = jnp.log(jnp.maximum(precision, 1e-30)) * mask_f stack = jnp.stack([ # --- conversion channel (mask = all-true) --- basis * resid_conv, # sum_hr_c basis * basis, # sum_h2_c resid_conv * resid_conv, # sum_r2_c all_true_f, # count_c # --- severity channel (mask = converter_mask) --- v * h_sev * resid_sev, # sum_vr_s v * h_sev * basis, # sum_vh2_s v * r_sev_masked * resid_sev, # sum_vr2_s log_v, # sum_logv_s mask_f, # count_s ]) # (9, n) return _fused_scatter_add(leaf_indices, stack, tree_size) def _hurdle_sufficient_stats_joint_nn( leaf_indices: jax.Array, resid_conv: jax.Array, resid_sev: jax.Array, conv_mask: jax.Array, basis: jax.Array, tree_size: int, ) -> jax.Array: """Fused conv+sev Normal-Normal sufficient stats per leaf. Used by the ``per_leaf_gamma=False`` path: severity is Normal-Normal with sigma² absorbed into the global ``tau_0`` (no per-obs precision weighting), so the conv (4 stats) and sev (4 stats) sufficient stats can share an 8-channel fused scatter-add. 
Index 0..3 — conversion (all visitors): 0: sum_hr_c = Σ h_i r_conv_i 1: sum_h2_c = Σ h_i² 2: sum_r2_c = Σ r_conv_i² 3: count_c = n Index 4..7 — severity (converters only): 4: sum_hr_s = Σ_{i: conv_i} h_i r_sev_i 5: sum_h2_s = Σ_{i: conv_i} h_i² 6: sum_r2_s = Σ_{i: conv_i} r_sev_i² 7: count_s = Σ conv_mask_i """ n = leaf_indices.shape[0] all_true_f = jnp.ones(n, jnp.float32) mask_f = conv_mask.astype(jnp.float32) h_sev = basis * mask_f r_sev_masked = resid_sev * mask_f stack = jnp.stack([ # --- conversion channel (all visitors) --- basis * resid_conv, # sum_hr_c basis * basis, # sum_h2_c resid_conv * resid_conv, # sum_r2_c all_true_f, # count_c # --- severity channel (converters only) --- h_sev * resid_sev, # sum_hr_s h_sev * basis, # sum_h2_s r_sev_masked * resid_sev, # sum_r2_s mask_f, # count_s ]) # (8, n) return _fused_scatter_add(leaf_indices, stack, tree_size) def _log_marginal_hurdle_ratio_nn( old_li: jax.Array, prop_li: jax.Array, resid_conv: jax.Array, resid_sev: jax.Array, conv_mask: jax.Array, basis: jax.Array, sigma2_sev: jax.Array, tau_conv: jax.Array, tau_sev: jax.Array, tree_size: int, sev_weight: float = 1.0, ) -> jax.Array: """Log marginal-likelihood ratio with Normal-Normal severity. Fast-path used by ``per_leaf_gamma=False``: severity is Normal-Normal at fixed scale ``sigma2_sev = 1/tau_0`` (no per-leaf gamma multiplier and no per-observation precision weighting). Returns ``lml(prop) - lml(old)`` using one fused 8-channel scatter-add per leaf assignment. Mathematically equivalent to:: _log_marginal_hurdle_tree(prop_li, ...) - _log_marginal_hurdle_tree(old_li, ...) 
""" ss_old = _hurdle_sufficient_stats_joint_nn( old_li, resid_conv, resid_sev, conv_mask, basis, tree_size, ) ss_prop = _hurdle_sufficient_stats_joint_nn( prop_li, resid_conv, resid_sev, conv_mask, basis, tree_size, ) sigma2_one = jnp.float32(1.0) lml_conv_old = jnp.sum(_log_marginal_leaves( ss_old[0], ss_old[1], ss_old[2], ss_old[3], sigma2_one, tau_conv, )) lml_conv_prop = jnp.sum(_log_marginal_leaves( ss_prop[0], ss_prop[1], ss_prop[2], ss_prop[3], sigma2_one, tau_conv, )) lml_sev_old = jnp.sum(_log_marginal_leaves( ss_old[4], ss_old[5], ss_old[6], ss_old[7], sigma2_sev, tau_sev, )) lml_sev_prop = jnp.sum(_log_marginal_leaves( ss_prop[4], ss_prop[5], ss_prop[6], ss_prop[7], sigma2_sev, tau_sev, )) return ( (lml_conv_prop - lml_conv_old) + sev_weight * (lml_sev_prop - lml_sev_old) ) def _hurdle_move_resids_and_counts( leaf_indices_deep: jax.Array, resid_conv: jax.Array, resid_sev: jax.Array, conv_mask: jax.Array, basis: jax.Array, tree_size: int, ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: """Per-leaf basis-weighted resid sums + leaf counts for one tree. Single fused 4-channel scatter-add into ``(tree_size,)`` bins, keyed by the *deeper* leaf-indices for the tree (post-grow / pre-prune). The bartz-style accept step extracts only ``move.{left, right}`` from these arrays — O(1) per tree per channel — and forms ``total = left + right``. Returns ------- sum_hr_conv, count_conv : (tree_size,) — basis-weighted residual sum and observation count for the conversion channel (all visitors). sum_hr_sev, count_sev : (tree_size,) — same but only converters contribute (severity channel, mask = ``conv_mask``). 
""" n = leaf_indices_deep.shape[0] all_true_f = jnp.ones(n, jnp.float32) mask_f = conv_mask.astype(jnp.float32) h_sev = basis * mask_f stack = jnp.stack([ basis * resid_conv, # sum_hr_conv all_true_f, # count_conv h_sev * resid_sev, # sum_hr_sev (mask absorbed in h_sev) mask_f, # count_sev ]) # (4, n) out = _fused_scatter_add(leaf_indices_deep, stack, tree_size) return out[0], out[1], out[2], out[3] def _hurdle_move_log_lk_ratio_grow( sum_hr_left: jax.Array, sum_hr_right: jax.Array, sum_hr_total: jax.Array, count_left: jax.Array, count_right: jax.Array, count_total: jax.Array, sigma2: jax.Array, tau_prior: jax.Array, basis_sq: jax.Array, ) -> jax.Array: """O(1) Normal-Normal log-lml ratio for one channel of one tree's GROW move. Returns ``log p(r | tree.split) - log p(r | tree.unsplit)``. For PRUNE moves negate this externally (bartz pattern: same formula computes ``log_lml(deeper) - log_lml(shallower)``; PRUNE = shallower minus deeper). Closed-form Normal-Normal conjugate ratio derived from ``_log_marginal_leaves`` summed at the three involved leaves; the sum_r2 / count*log(2π σ²) terms drop out because ``leaf_total_sum_r2 == sum_r2_left + sum_r2_right``. Parameters ---------- sum_hr_{left,right,total}: basis-weighted residual sums in the affected leaves. count_{left,right,total}: number of contributing observations. sigma2: error variance (1.0 for the conversion / probit channel, ``1/tau_0`` for severity on the per-leaf-gamma=False path). tau_prior: leaf-mean prior variance. basis_sq: ``basis²`` per observation, treated here as a constant. For ±0.5 (tau forest) and 1.0 (mu forest) the basis is uniform over contributing observations, so ``Σ h_i² = count × basis_sq``. 
""" sum_h2_left = count_left * basis_sq sum_h2_right = count_right * basis_sq sum_h2_total = count_total * basis_sq T_left = sigma2 + sum_h2_left * tau_prior T_right = sigma2 + sum_h2_right * tau_prior T_total = sigma2 + sum_h2_total * tau_prior log_det_diff = 0.5 * jnp.log( sigma2 * T_total / jnp.maximum(T_left * T_right, 1e-30) ) data_fit_diff = (tau_prior / (2.0 * sigma2)) * ( sum_hr_left * sum_hr_left / jnp.maximum(T_left, 1e-30) + sum_hr_right * sum_hr_right / jnp.maximum(T_right, 1e-30) - sum_hr_total * sum_hr_total / jnp.maximum(T_total, 1e-30) ) return log_det_diff + data_fit_diff def _mv_leaf_draw( sum_hr: jax.Array, sum_h2: jax.Array, sigma2: jax.Array, tau_prior: jax.Array, noise: jax.Array, ) -> jax.Array: """Conjugate posterior mean + correlated MV-normal draw for one leaf. Per-leaf Bayesian update for the multivariate Normal-Normal model: r_i | theta ~ N(h_i^T theta, sigma2) theta ~ N(0, tau_prior * I_{K-1}) The posterior is N(mean, P^{-1}) where: P = sum_h2 / sigma2 + (1/tau_prior) * I_{K-1} mean = P^{-1} @ (sum_hr / sigma2) A sample is generated as: draw = mean + M @ noise where ``M = chol(P^{-1})`` (lower-triangular Cholesky factor of the posterior covariance), so ``cov(draw) = M M^T = P^{-1}``. At K-1=1 the draw reduces to the scalar ``mean + noise / sqrt(P)``. Broadcasts over arbitrary leading batch dimensions (batched ``jnp.linalg.cholesky`` and ``jnp.linalg.solve`` handle leading dims). The caller is responsible for zeroing draws for empty leaves. Parameters ---------- sum_hr : (..., K-1) Per-leaf basis-weighted residual sum vectors. sum_h2 : (..., K-1, K-1) Per-leaf sum of outer products of basis rows. sigma2 : scalar Error variance (1.0 for the probit/conversion channel). tau_prior : scalar Isotropic leaf-mean prior variance. noise : (..., K-1) Standard normal noise for each leaf (same shape as ``sum_hr``). Returns ------- (..., K-1) Posterior draw for each leaf. Reference --------- multi-arm-joint-hurdle-bcf change, batch 5a spec. 
""" K1 = sum_hr.shape[-1] P = sum_h2 / sigma2 + (1.0 / tau_prior) * jnp.eye(K1) b = sum_hr / sigma2 mean = jnp.linalg.solve(P, b[..., None])[..., 0] # P^{-1} (sum_hr/sigma2) M = jnp.linalg.cholesky(jnp.linalg.inv(P)) # lower; M M^T = P^{-1} draw = mean + (M @ noise[..., None])[..., 0] # cov(draw) = M M^T = P^{-1} return draw def _mv_log_lk_ratio_grow( sum_hr_left: jax.Array, sum_h2_left: jax.Array, sum_hr_right: jax.Array, sum_h2_right: jax.Array, sigma2: jax.Array, tau_prior: jax.Array, ) -> jax.Array: """O(1) multivariate Normal-Normal log-lml ratio for one channel's GROW move. Multivariate analogue of ``_hurdle_move_log_lk_ratio_grow`` for a ``(K-1)``-dimensional treatment-effect vector. Returns: log p(r | split) - log p(r | unsplit) computed in O(1) from the left- and right-child sufficient statistics without re-evaluating all leaves. For PRUNE moves negate externally (same pattern as the scalar analogue). The derivation uses: M(h2) = tau_prior * h2 + sigma2 * I_{K-1} with the sum_r2 and count*log(2π σ²) terms cancelling in the split difference (left SSE + right SSE = total SSE identically): logdet_diff = 0.5 * (logdet(M_T) - logdet(M_L) - logdet(M_R)) + 0.5 * (K-1) * log(sigma2) datafit_diff = 0.5 * (tau_prior/sigma2) * ( hr_L^T M_L^{-1} hr_L + hr_R^T M_R^{-1} hr_R - hr_T^T M_T^{-1} hr_T ) where ``hr_T = hr_L + hr_R``, ``h2_T = h2_L + h2_R``. At K-1=1 this is algebraically identical to the scalar ``_hurdle_move_log_lk_ratio_grow`` with ``basis_sq = sum_h2 / count``. Broadcasts over leading batch dimensions; works for both single-leaf ``(K-1,)``/``(K-1, K-1)`` and batched ``(L, K-1)``/``(L, K-1, K-1)`` inputs without vmap. Parameters ---------- sum_hr_left : (..., K-1) Basis-weighted residual sum for the left child leaf. sum_h2_left : (..., K-1, K-1) Sum of outer products for the left child leaf. sum_hr_right : (..., K-1) Basis-weighted residual sum for the right child leaf. sum_h2_right : (..., K-1, K-1) Sum of outer products for the right child leaf. 
sigma2 : scalar Error variance (1.0 for the probit/conversion channel). tau_prior : scalar Isotropic leaf-mean prior variance. Returns ------- (...,) Log-likelihood ratio for the grow move. Reference --------- multi-arm-joint-hurdle-bcf change, batch 5a spec. """ K1 = sum_hr_left.shape[-1] sum_hr_total = sum_hr_left + sum_hr_right sum_h2_total = sum_h2_left + sum_h2_right def _M(h2: jax.Array) -> jax.Array: return tau_prior * h2 + sigma2 * jnp.eye(K1) def _logdet(h2: jax.Array) -> jax.Array: _, ld = jnp.linalg.slogdet(_M(h2)) return ld def _quad(hr: jax.Array, h2: jax.Array) -> jax.Array: x = jnp.linalg.solve(_M(h2), hr[..., None])[..., 0] return jnp.sum(hr * x, axis=-1) logdet_diff = ( 0.5 * (_logdet(sum_h2_total) - _logdet(sum_h2_left) - _logdet(sum_h2_right)) + 0.5 * K1 * jnp.log(sigma2) ) datafit_diff = 0.5 * (tau_prior / sigma2) * ( _quad(sum_hr_left, sum_h2_left) + _quad(sum_hr_right, sum_h2_right) - _quad(sum_hr_total, sum_h2_total) ) return logdet_diff + datafit_diff def _log_marginal_hurdle_ratio_ng( old_li: jax.Array, prop_li: jax.Array, resid_conv: jax.Array, resid_sev: jax.Array, conv_mask: jax.Array, basis: jax.Array, precision: jax.Array, tau_conv: jax.Array, a_tau: jax.Array, b_tau: jax.Array, kappa: jax.Array, tree_size: int, sev_weight: float = 1.0, ) -> jax.Array: """Log marginal-likelihood ratio between proposed and current tree structure. Mathematically equivalent to:: _log_marginal_hurdle_tree_ng(prop_li, ...) - _log_marginal_hurdle_tree_ng(old_li, ...) but each tree-evaluation uses a single fused scatter-add over both conversion and severity channels (9 stats), halving the per-tree scatter-add launch count vs the dual ``_log_marginal_hurdle_tree_ng`` invocation. Returns a scalar log-ratio suitable for an MH accept step. 
""" ss_old = _hurdle_sufficient_stats_joint( old_li, resid_conv, resid_sev, conv_mask, basis, precision, tree_size, ) ss_prop = _hurdle_sufficient_stats_joint( prop_li, resid_conv, resid_sev, conv_mask, basis, precision, tree_size, ) sigma2_one = jnp.float32(1.0) lml_conv_old = jnp.sum(_log_marginal_leaves( ss_old[0], ss_old[1], ss_old[2], ss_old[3], sigma2_one, tau_conv, )) lml_conv_prop = jnp.sum(_log_marginal_leaves( ss_prop[0], ss_prop[1], ss_prop[2], ss_prop[3], sigma2_one, tau_conv, )) lml_sev_old = jnp.sum(_log_marginal_leaves_ng( ss_old[4], ss_old[5], ss_old[6], ss_old[7], ss_old[8], a_tau, b_tau, kappa, )) lml_sev_prop = jnp.sum(_log_marginal_leaves_ng( ss_prop[4], ss_prop[5], ss_prop[6], ss_prop[7], ss_prop[8], a_tau, b_tau, kappa, )) return ( (lml_conv_prop - lml_conv_old) + sev_weight * (lml_sev_prop - lml_sev_old) )