pytyche.bcf.lml

Pure-math sufficient statistics and log-marginal likelihoods for the GPU BCF.

This module holds the closed-form pieces of the joint hurdle BCF posterior: basis-weighted sufficient statistics aggregated per leaf, the matching log-marginal likelihoods (Normal-Normal and Normal-Gamma conjugacy), and the moment-matched gamma prior used to set (a_tau, b_tau).

It is “pure math” in the sense that every function here depends only on its arguments and on JAX/numpy/scipy primitives — no module-level state, no GPU device handles, no JIT-compiled kernels of its own.
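As a rough illustration of what "basis-weighted sufficient statistics aggregated per leaf" can look like, the sketch below accumulates per-leaf sums with a scatter-add. It is not the module's code: the function name `leaf_sufficient_stats`, the `LeafStatsSketch` container, and the choice of statistics (sum of h squared, sum of h times residual, and counts) are assumptions made for the example, and NumPy stands in for JAX arrays.

```python
from typing import NamedTuple

import numpy as np


class LeafStatsSketch(NamedTuple):
    """Hypothetical container; the real module's return type is not shown here."""
    sum_h2: np.ndarray  # sum of squared basis weights per leaf
    sum_hr: np.ndarray  # sum of basis weight times residual per leaf
    count: np.ndarray   # number of observations per leaf


def leaf_sufficient_stats(leaf_idx, h, resid, num_leaves):
    """Aggregate basis-weighted statistics per leaf (assumed form).

    leaf_idx : integer leaf index of each observation, shape (n,)
    h        : basis weight of each observation, shape (n,)
    resid    : residual of each observation, shape (n,)
    """
    leaf_idx = np.asarray(leaf_idx)
    h = np.asarray(h, dtype=float)
    resid = np.asarray(resid, dtype=float)
    # np.bincount with weights is a scatter-add keyed by leaf index.
    sum_h2 = np.bincount(leaf_idx, weights=h * h, minlength=num_leaves)
    sum_hr = np.bincount(leaf_idx, weights=h * resid, minlength=num_leaves)
    count = np.bincount(leaf_idx, minlength=num_leaves)
    return LeafStatsSketch(sum_h2, sum_hr, count)


stats = leaf_sufficient_stats(
    leaf_idx=[0, 0, 2], h=[1.0, 2.0, 3.0], resid=[1.0, 1.0, 2.0], num_leaves=3
)
```

The function reads nothing but its arguments, which is the sense of "pure math" used above. Leaf 1 receives no observations in this toy input, so all three of its statistics stay at zero.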

Import graph

This module imports _fused_scatter_add from pytyche.bcf.kernels and nothing else from sibling GPU BCF modules. The orchestrator and the GFR module import FROM here, never the other way around — see pytyche.bcf.kernels for the same convention applied at the bottom of the import graph.

Contents

  • _solve_gamma_prior: Linero moment-matching for (a_gamma, b_gamma).
  • _hurdle_sufficient_stats: basis-weighted (Normal-Normal) per-leaf stats.
  • _hurdle_sufficient_stats_ng: precision-weighted (Normal-Gamma) per-leaf stats.
  • _log_marginal_leaves: per-leaf log-marginal under Normal-Normal conjugacy.
  • _log_marginal_leaves_ng: per-leaf log-marginal under Normal-Gamma conjugacy.
  • _log_marginal_hurdle_tree: total joint log-marginal for a hurdle tree (NN severity).
  • _log_marginal_hurdle_tree_ng: total joint log-marginal for a hurdle tree (NG severity).

Classes

HurdlePreLf(mean_factor, centered_leaves)

Per-tree leaf-posterior factors for one channel.

HurdlePreLkv(T_left, T_right, T_total, ...)

Per-tree likelihood-ratio factors for one channel of the hurdle lml.

class pytyche.bcf.lml.HurdlePreLkv(T_left, T_right, T_total, log_sqrt_term, exp_factor)

Bases: NamedTuple

Per-tree likelihood-ratio factors for one channel of the hurdle lml.

Mirrors bartz’s PreLkV Module as a JAX-pytree NamedTuple. All fields have shape (num_trees,) after the parallel pre-compute stage.

Fields:
  • T_left, T_right, T_total: the conjugate Normal-Normal denominators sigma2 + sum_h2 * tau_prior, evaluated at the move’s left, right, and parent leaves.
  • log_sqrt_term: the log-determinant ratio 0.5 * log(sigma2 * T_total / (T_left * T_right)).
  • exp_factor: the shared scalar tau_prior / (2 * sigma2) that multiplies the residual quadratic in the O(1) ratio.
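The field definitions above can be sketched in a few lines. This is a minimal illustration, not the module's implementation: `PreLkvSketch`, `precompute_lkv`, and `log_lik_ratio` are hypothetical names, NumPy stands in for JAX, the parent's sum_h2 is assumed to be the sum of its children's, and the residual quadratic is assumed to take the standard Normal-Normal form S_left^2 / T_left + S_right^2 / T_right - S_total^2 / T_total, where S is the basis-weighted residual sum in a leaf.

```python
from typing import NamedTuple

import numpy as np


class PreLkvSketch(NamedTuple):
    """Hypothetical stand-in with the same five fields as HurdlePreLkv."""
    T_left: np.ndarray
    T_right: np.ndarray
    T_total: np.ndarray
    log_sqrt_term: np.ndarray
    exp_factor: np.ndarray


def precompute_lkv(sigma2, tau_prior, sum_h2_left, sum_h2_right):
    """Build the per-tree factors from per-leaf sums of squared basis weights."""
    sum_h2_left = np.asarray(sum_h2_left, dtype=float)
    sum_h2_right = np.asarray(sum_h2_right, dtype=float)
    T_left = sigma2 + sum_h2_left * tau_prior
    T_right = sigma2 + sum_h2_right * tau_prior
    # Assumption: the parent leaf holds exactly the union of both children.
    T_total = sigma2 + (sum_h2_left + sum_h2_right) * tau_prior
    log_sqrt_term = 0.5 * np.log(sigma2 * T_total / (T_left * T_right))
    # Broadcast the shared scalar so every field has shape (num_trees,).
    exp_factor = np.full_like(T_total, tau_prior / (2.0 * sigma2))
    return PreLkvSketch(T_left, T_right, T_total, log_sqrt_term, exp_factor)


def log_lik_ratio(pre, sum_hr_left, sum_hr_right):
    """Log marginal-likelihood ratio of the split versus the unsplit parent."""
    s_left = np.asarray(sum_hr_left, dtype=float)
    s_right = np.asarray(sum_hr_right, dtype=float)
    quad = (
        s_left**2 / pre.T_left
        + s_right**2 / pre.T_right
        - (s_left + s_right) ** 2 / pre.T_total
    )
    return pre.log_sqrt_term + pre.exp_factor * quad


pre = precompute_lkv(sigma2=0.7, tau_prior=1.3, sum_h2_left=[1.25], sum_h2_right=[4.0])
ratio = log_lik_ratio(pre, sum_hr_left=[0.2], sum_hr_right=[2.2])
```

Once the factors are precomputed, the ratio needs only the two residual sums per tree, which is what makes the per-move evaluation O(1). Under these assumptions the result agrees with the difference of the three leaf log-marginals evaluated directly from the multivariate normal density.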

Parameters:
  • T_left (Array)

  • T_right (Array)

  • T_total (Array)

  • log_sqrt_term (Array)

  • exp_factor (Array)

T_left: Array

Alias for field number 0

T_right: Array

Alias for field number 1

T_total: Array

Alias for field number 2

log_sqrt_term: Array

Alias for field number 3

exp_factor: Array

Alias for field number 4

class pytyche.bcf.lml.HurdlePreLf(mean_factor, centered_leaves)

Bases: NamedTuple

Per-tree leaf-posterior factors for one channel.

Mirrors bartz’s PreLf. For each tree’s tree_size leaves:
    new_leaf_value[t, l] = mean_factor[t, l] * sum_resid_at_leaf[t, l] + centered_leaves[t, l]

Both fields are zeroed where count_trees[t, l] == 0, so that empty leaves get zero alpha (this matches the existing jnp.where(count > 0, ..., 0) pattern).
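The leaf update and the empty-leaf masking described above might be sketched as follows. The names `PreLfSketch`, `precompute_lf`, and `sample_leaves` are hypothetical, NumPy stands in for JAX, and the formulas for the two fields are assumptions taken from standard Normal-Normal conjugacy (posterior mean factor tau_prior / T and posterior variance sigma2 * tau_prior / T, with T = sigma2 + sum_h2 * tau_prior). The standard-normal draws `z` are passed in as an argument so the function stays free of hidden state.

```python
from typing import NamedTuple

import numpy as np


class PreLfSketch(NamedTuple):
    """Hypothetical stand-in with the same two fields as HurdlePreLf."""
    mean_factor: np.ndarray      # shape (num_trees, tree_size)
    centered_leaves: np.ndarray  # shape (num_trees, tree_size)


def precompute_lf(sigma2, tau_prior, sum_h2, count_trees, z):
    """Per-leaf posterior factors, zeroed at empty leaves (assumed formulas)."""
    sum_h2 = np.asarray(sum_h2, dtype=float)
    count_trees = np.asarray(count_trees)
    z = np.asarray(z, dtype=float)
    T = sigma2 + sum_h2 * tau_prior
    mean_factor = tau_prior / T                 # scales the residual sum
    post_sd = np.sqrt(sigma2 * tau_prior / T)   # posterior standard deviation
    occupied = count_trees > 0
    return PreLfSketch(
        mean_factor=np.where(occupied, mean_factor, 0.0),
        centered_leaves=np.where(occupied, post_sd * z, 0.0),
    )


def sample_leaves(pre, sum_resid_at_leaf):
    """Apply new_leaf_value = mean_factor * sum_resid_at_leaf + centered_leaves."""
    return pre.mean_factor * np.asarray(sum_resid_at_leaf, dtype=float) + pre.centered_leaves


# One tree with two leaf slots; the second slot is empty.
pre_lf = precompute_lf(
    sigma2=1.0, tau_prior=1.0,
    sum_h2=[[3.0, 0.0]], count_trees=[[3, 0]], z=[[1.0, 1.0]],
)
leaves = sample_leaves(pre_lf, sum_resid_at_leaf=[[4.0, 5.0]])
```

Because both fields are masked, the empty slot yields a leaf value of exactly zero regardless of what residual sum or noise draw is supplied for it.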

Parameters:
  • mean_factor (Array)

  • centered_leaves (Array)

mean_factor: Array

Alias for field number 0

centered_leaves: Array

Alias for field number 1