pytyche.bcf.lml
Pure-math sufficient statistics and log-marginal likelihoods for the GPU BCF.
This module holds the closed-form pieces of the joint hurdle BCF posterior:
basis-weighted sufficient statistics aggregated per leaf, the matching
log-marginal likelihoods (Normal-Normal and Normal-Gamma conjugacy), and the
moment-matched gamma prior used to set (a_tau, b_tau).
It is “pure math” in the sense that every function here depends only on its arguments and on JAX/numpy/scipy primitives — no module-level state, no GPU device handles, no JIT-compiled kernels of its own.
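A minimal NumPy sketch of the Normal-Normal pieces, assuming the model r_i ~ N(h_i * mu, sigma2) with leaf prior mu ~ N(0, tau). Here h_i is the per-observation basis value and r_i the partial residual. The helpers leaf_sufficient_stats and log_marginal_leaves_nn are hypothetical. They do not reproduce the signatures of _hurdle_sufficient_stats or _log_marginal_leaves. np.bincount stands in for the module's _fused_scatter_add.

```python
import numpy as np


def leaf_sufficient_stats(leaf_idx, basis, resid, num_leaves):
    """Hypothetical helper: per-leaf basis-weighted sufficient statistics.

    leaf_idx[i] is the leaf observation i falls into, basis[i] is h_i and
    resid[i] is the partial residual r_i. Each output has shape (num_leaves,).
    """
    count = np.bincount(leaf_idx, minlength=num_leaves).astype(float)
    sum_h2 = np.bincount(leaf_idx, weights=basis**2, minlength=num_leaves)
    sum_hr = np.bincount(leaf_idx, weights=basis * resid, minlength=num_leaves)
    sum_r2 = np.bincount(leaf_idx, weights=resid**2, minlength=num_leaves)
    return count, sum_h2, sum_hr, sum_r2


def log_marginal_leaves_nn(count, sum_h2, sum_hr, sum_r2, sigma2, tau):
    """Per-leaf log p(r_leaf) with r_i ~ N(h_i * mu, sigma2), mu ~ N(0, tau)."""
    T = sigma2 + tau * sum_h2  # conjugate Normal-Normal denominator
    return (
        -0.5 * count * np.log(2 * np.pi * sigma2)  # Gaussian normaliser
        - 0.5 * np.log(T / sigma2)                 # log-determinant term
        - 0.5 * sum_r2 / sigma2                    # data quadratic
        + 0.5 * tau * sum_hr**2 / (sigma2 * T)     # shrinkage correction
    )
```

Empty leaves contribute exactly zero, because T reduces to sigma2 and every sum vanishes. In JAX, jax.ops.segment_sum(data, leaf_idx, num_segments=num_leaves) is the usual replacement for np.bincount.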
Import graph
This module imports _fused_scatter_add from pytyche.bcf.kernels and
nothing else from sibling GPU BCF modules. The orchestrator and the GFR
module import FROM here, never the other way around — see
pytyche.bcf.kernels for the same convention applied at the bottom of the
import graph.
Contents
_solve_gamma_prior — Linero moment-matching for (a_gamma, b_gamma).
_hurdle_sufficient_stats — basis-weighted (Normal-Normal) per-leaf stats.
_hurdle_sufficient_stats_ng — precision-weighted (Normal-Gamma) per-leaf stats.
_log_marginal_leaves — per-leaf log-marginal under Normal-Normal conjugacy.
_log_marginal_leaves_ng — per-leaf log-marginal under Normal-Gamma conjugacy.
_log_marginal_hurdle_tree — total joint log-marginal for a hurdle tree (NN severity).
_log_marginal_hurdle_tree_ng — total joint log-marginal for a hurdle tree (NG severity).
Classes

HurdlePreLf: Per-tree leaf-posterior factors for one channel.
HurdlePreLkv: Per-tree likelihood-ratio factors for one channel of the hurdle lml.
- class pytyche.bcf.lml.HurdlePreLkv(T_left, T_right, T_total, log_sqrt_term, exp_factor)
Bases: NamedTuple

Per-tree likelihood-ratio factors for one channel of the hurdle lml (log-marginal likelihood).
Mirrors bartz’s PreLkVModule as a JAX-pytree NamedTuple. All fields have shape (num_trees,) after the parallel pre-compute stage. Fields T_left/T_right/T_total are the conjugate Normal-Normal denominators sigma2 + sum_h2 * tau_prior evaluated at the move’s left, right, and parent leaves; log_sqrt_term is the log-determinant ratio 0.5 * log(sigma2 * T_total / (T_left * T_right)); exp_factor is the shared scalar tau_prior / (2 * sigma2) that multiplies the residual quadratic in the O(1) ratio.
- Parameters:
  - T_left (Array)
  - T_right (Array)
  - T_total (Array)
  - log_sqrt_term (Array)
  - exp_factor (Array)
- T_left: Array
Alias for field number 0
- T_right: Array
Alias for field number 1
- T_total: Array
Alias for field number 2
- log_sqrt_term: Array
Alias for field number 3
- exp_factor: Array
Alias for field number 4
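The fields combine into the grow/prune log-likelihood ratio in the bartz form log_sqrt_term + exp_factor * (S_left**2 / T_left + S_right**2 / T_right - S_total**2 / T_total). Here S is assumed to be the leaf's basis-weighted residual sum Σ h·r. This sketch further assumes scalar sigma2 and tau_prior. It uses a stand-in PreLkv NamedTuple rather than importing pytyche. make_prelkv and log_likelihood_ratio are hypothetical helpers.

```python
from typing import NamedTuple

import numpy as np


class PreLkv(NamedTuple):
    """Stand-in with the same fields as HurdlePreLkv (shape (num_trees,))."""
    T_left: np.ndarray
    T_right: np.ndarray
    T_total: np.ndarray
    log_sqrt_term: np.ndarray
    exp_factor: np.ndarray


def make_prelkv(sum_h2_left, sum_h2_right, sigma2, tau_prior):
    """Hypothetical: build the per-tree factors from child sum_h2 values."""
    T_left = sigma2 + sum_h2_left * tau_prior
    T_right = sigma2 + sum_h2_right * tau_prior
    # The parent leaf pools both children, so its sum_h2 is their sum.
    T_total = sigma2 + (sum_h2_left + sum_h2_right) * tau_prior
    log_sqrt_term = 0.5 * np.log(sigma2 * T_total / (T_left * T_right))
    exp_factor = np.full_like(T_left, tau_prior / (2 * sigma2))
    return PreLkv(T_left, T_right, T_total, log_sqrt_term, exp_factor)


def log_likelihood_ratio(pre, s_left, s_right):
    """log p(split) - log p(no split) from child residual sums s_left/s_right."""
    s_total = s_left + s_right
    quad = (s_left**2 / pre.T_left
            + s_right**2 / pre.T_right
            - s_total**2 / pre.T_total)
    return pre.log_sqrt_term + pre.exp_factor * quad
```

The parent's statistics are the children's sums added together, so each move needs only the two child statistics. That is why the ratio costs O(1) once those statistics are available.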
- class pytyche.bcf.lml.HurdlePreLf(mean_factor, centered_leaves)
Bases: NamedTuple

Per-tree leaf-posterior factors for one channel.
Mirrors bartz’s PreLf. For each tree’s tree_size leaves:

    new_leaf_value[t, l] = mean_factor[t, l] * sum_resid_at_leaf[t, l] + centered_leaves[t, l]

Both fields are zeroed where count_trees[t, l] == 0, so empty leaves get zero alpha (matching the existing jnp.where(count > 0, ..., 0) pattern).
- Parameters:
  - mean_factor (Array)
  - centered_leaves (Array)
- mean_factor: Array
Alias for field number 0
- centered_leaves: Array
Alias for field number 1
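A sketch of how the two fields can be built under the same Normal-Normal model. It assumes scalar sigma2 and tau_prior, pre-drawn standard-normal noise z, and sum_resid_at_leaf taken to be the basis-weighted residual sum Σ h·r. Under those assumptions a leaf's posterior is N(tau_prior * S / T, tau_prior * sigma2 / T) with T = sigma2 + sum_h2 * tau_prior. PreLf, make_prelf and sample_leaves are hypothetical names.

```python
from typing import NamedTuple

import numpy as np


class PreLf(NamedTuple):
    """Stand-in with the same fields as HurdlePreLf (shape (num_trees, tree_size))."""
    mean_factor: np.ndarray
    centered_leaves: np.ndarray


def make_prelf(count, sum_h2, sigma2, tau_prior, z):
    """Hypothetical: posterior factors per leaf; z holds N(0, 1) draws."""
    T = sigma2 + sum_h2 * tau_prior
    mean_factor = tau_prior / T                 # posterior mean = mean_factor * S
    post_sd = np.sqrt(tau_prior * sigma2 / T)   # posterior standard deviation
    nonempty = count > 0
    # Zero both fields on empty leaves, mirroring jnp.where(count > 0, ..., 0).
    return PreLf(np.where(nonempty, mean_factor, 0.0),
                 np.where(nonempty, post_sd * z, 0.0))


def sample_leaves(pre, sum_resid_at_leaf):
    """new_leaf_value = mean_factor * sum_resid_at_leaf + centered_leaves."""
    return pre.mean_factor * sum_resid_at_leaf + pre.centered_leaves
```

In this sketch centered_leaves does not depend on the residuals, so both fields can be prepared ahead of time and the per-leaf update reduces to one multiply-add.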