pytyche.diagnostics.convergence

General MCMC convergence diagnostics.

Contains pure convergence-diagnostic functions that are not specific to BCF: R-hat, bulk ESS, posterior stability, and the between-chunk diagnostic accumulator used by the GPU BCF orchestrator.

Contents

  • compute_ess: effective sample size via autocorrelation (1D trace).

  • compute_rhat: Gelman-Rubin R-hat for multi-chain convergence.

  • compute_posterior_stability: split-half posterior mean comparison.

  • _split_rhat: split-R-hat for a (num_samples, num_chains) array.

  • _bulk_ess: per-chain bulk ESS via initial positive autocorrelation.

  • _tau_hat_quantiles_from_carry: extract tau_hat quantiles from carry tuple.

  • _between_chunk_diagnostics: per-chunk diagnostic accumulator.

  • _compile_diagnostics: final diagnostics dict from chunk log.

Functions

compute_ess(samples)

Effective sample size for a 1D trace using autocorrelation.

compute_posterior_stability(samples[, n_splits])

Compare first-half vs second-half posterior means.

compute_rhat(chains)

Gelman-Rubin R-hat for multi-chain convergence.

pytyche.diagnostics.convergence.compute_ess(samples)[source]

Effective sample size for a 1D trace using autocorrelation.

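Simple estimator: ESS = S / (1 + 2 * sum(rho_k)), where S is the number of samples in the trace and rho_k is the autocorrelation at lag k. The sum runs over lags k = 1, 2, ... and stops at the first negative autocorrelation.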

Parameters:

samples (ndarray)

Return type:

float
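The estimator above can be sketched in a few lines of NumPy. This is an illustrative re-implementation under stated assumptions, not the library's source: the name ess_sketch is hypothetical, the autocorrelations use the biased (divide by S) normalisation, and a constant trace returns NaN because its autocorrelation is undefined.

```python
import numpy as np


def ess_sketch(samples):
    """Hypothetical sketch of ESS = S / (1 + 2 * sum(rho_k))."""
    x = np.asarray(samples, dtype=float)
    s = x.size
    centered = x - x.mean()
    var = np.dot(centered, centered) / s  # lag-0 autocovariance
    if s < 2 or var == 0.0:
        # Assumption of this sketch: autocorrelation is undefined here.
        return float("nan")
    rho_sum = 0.0
    for k in range(1, s):
        # Biased lag-k autocorrelation estimate.
        rho_k = np.dot(centered[:-k], centered[k:]) / (s * var)
        if rho_k < 0.0:
            break  # stop at the first negative autocorrelation
        rho_sum += rho_k
    return s / (1.0 + 2.0 * rho_sum)


# Usage: a strongly autocorrelated AR(1) trace should score far below
# an independent trace of the same length.
rng = np.random.default_rng(0)
iid_trace = rng.normal(size=2000)
ar_trace = np.empty(2000)
ar_trace[0] = 0.0
for t in range(1, 2000):
    ar_trace[t] = 0.9 * ar_trace[t - 1] + rng.normal()
print(ess_sketch(iid_trace), ess_sketch(ar_trace))
```

Because the sum only ever adds non-negative terms, this sketch never reports an ESS above S. Estimators that allow negative terms can report values above S for anticorrelated traces.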

pytyche.diagnostics.convergence.compute_rhat(chains)[source]

Gelman-Rubin R-hat for multi-chain convergence.

Parameters:

chains (list[ndarray]) – Each element is a trace from one chain (same length preferred).

Return type:

float

Returns:

R-hat statistic. Values near 1.0 indicate convergence; values above 1.05 suggest insufficient mixing or burn-in.
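For orientation, the classic Gelman-Rubin statistic can be sketched as follows. This is a minimal illustration, not the library's implementation: the name rhat_sketch is hypothetical, and truncating all chains to the shortest length is an assumption made here to handle unequal lengths. How compute_rhat itself treats them is not documented on this page.

```python
import numpy as np


def rhat_sketch(chains):
    """Hypothetical sketch of the classic (non-split) Gelman-Rubin R-hat."""
    # Assumption: unequal chains are truncated to the shortest length.
    n = min(len(c) for c in chains)
    x = np.stack([np.asarray(c, dtype=float)[:n] for c in chains])  # (m, n)
    within = x.var(axis=1, ddof=1).mean()        # W: mean within-chain variance
    between = n * x.mean(axis=1).var(ddof=1)     # B: n * variance of chain means
    var_hat = (n - 1) / n * within + between / n # pooled variance estimate
    return float(np.sqrt(var_hat / within))


# Usage: chains drawn from one distribution versus chains stuck in
# different regions of parameter space.
rng = np.random.default_rng(1)
mixed = [rng.normal(size=1000) for _ in range(4)]
stuck = [rng.normal(loc=0.0, size=1000), rng.normal(loc=3.0, size=1000)]
print(rhat_sketch(mixed), rhat_sketch(stuck))
```

The first value should sit close to 1.0 and the second well above the 1.05 threshold quoted above, since the chain means disagree by several standard deviations.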

pytyche.diagnostics.convergence.compute_posterior_stability(samples, n_splits=2)[source]

Compare first-half vs second-half posterior means.

For an (n, S) array (n visitors by S posterior samples), computes each visitor's posterior mean in the first and second half of the samples, then returns the RMSE and the correlation between the two sets of means.

Parameters:
  • samples (ndarray)

  • n_splits (int)

Return type:

dict[str, float]
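The split-half comparison can be sketched as below. This is a hedged illustration only: the name posterior_stability_sketch and the dictionary keys "rmse" and "correlation" are hypothetical, the split is assumed to run along the sample axis, and only the two-half case is covered because the effect of n_splits is not described on this page.

```python
import numpy as np


def posterior_stability_sketch(samples):
    """Hypothetical sketch: compare per-row means of the two sample halves."""
    x = np.asarray(samples, dtype=float)          # shape (n, S)
    half = x.shape[1] // 2
    first = x[:, :half].mean(axis=1)              # per-visitor mean, first half
    second = x[:, half:2 * half].mean(axis=1)     # per-visitor mean, second half
    rmse = float(np.sqrt(np.mean((first - second) ** 2)))
    corr = float(np.corrcoef(first, second)[0, 1])
    # Key names are assumptions of this sketch, not the library's.
    return {"rmse": rmse, "correlation": corr}


# Usage: every visitor's trace shifts by +2 between halves, so the halves
# disagree in level (RMSE of 2) but agree perfectly in ranking.
base = np.arange(5.0)[:, None]                    # 5 visitors
drifting = np.hstack([np.tile(base, (1, 10)), np.tile(base + 2.0, (1, 10))])
print(posterior_stability_sketch(drifting))
```

A low RMSE together with a correlation near 1 suggests the posterior means have stabilised. A high correlation with a large RMSE, as in the usage example, points to a drift shared by all visitors rather than noise.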