# Source code for pytyche.diagnostics.convergence

"""General MCMC convergence diagnostics.

Contains pure convergence-diagnostic functions that are not specific to
BCF: R-hat, bulk ESS, posterior stability, and the between-chunk
diagnostic accumulator used by the GPU BCF orchestrator.

Contents
--------
``compute_ess`` — effective sample size via autocorrelation (1D trace).
``compute_rhat`` — Gelman-Rubin R-hat for multi-chain convergence.
``compute_posterior_stability`` — split-half posterior mean comparison.
``_split_rhat`` — split-R-hat for a (num_samples, num_chains) array.
``_bulk_ess`` — per-chain bulk ESS via initial positive autocorrelation.
``_tau_hat_quantiles_from_carry`` — extract tau_hat quantiles from carry tuple.
``_between_chunk_diagnostics`` — per-chunk diagnostic accumulator.
``_compile_diagnostics`` — final diagnostics dict from chunk log.
"""

from __future__ import annotations

import numpy as np

# ---------------------------------------------------------------------------
# Simple 1D ESS (from bcf_diagnostics)
# ---------------------------------------------------------------------------


def compute_ess(samples: np.ndarray) -> float:
    """Estimate the effective sample size of a single 1D trace.

    The estimator is ``S / (1 + 2 * sum_k rho_k)``, where ``rho_k`` is the
    lag-``k`` autocorrelation (computed via FFT, normalised by ``S * var``)
    and the sum runs over the initial run of non-negative lags, stopping at
    the first negative one.

    Parameters
    ----------
    samples : np.ndarray
        One-dimensional trace of draws.

    Returns
    -------
    float
        ESS, never below 1.0. Traces with fewer than 4 draws, or with
        variance below 1e-15, return their length unchanged.

    Raises
    ------
    ValueError
        If ``samples`` has more than one dimension.
    """
    if samples.ndim > 1:
        raise ValueError("compute_ess expects a 1D trace")

    n_draws = len(samples)
    if n_draws < 4:
        return float(n_draws)

    centred = samples - samples.mean()
    variance = float(np.var(centred))
    if variance < 1e-15:
        # A (numerically) constant trace has no usable autocorrelation.
        return float(n_draws)

    # Zero-pad to 2*S so the circular FFT correlation equals the linear one.
    spectrum = np.fft.fft(centred, n=2 * n_draws)
    power = spectrum * np.conj(spectrum)
    autocorr = np.fft.ifft(power).real[:n_draws] / (variance * n_draws)

    # Accumulate the leading run of non-negative autocorrelations (lag >= 1).
    positive_sum = 0.0
    for rho in autocorr[1:]:
        if rho < 0:
            break
        positive_sum += rho

    return max(1.0, float(n_draws / (1.0 + 2.0 * positive_sum)))
# ---------------------------------------------------------------------------
# Gelman-Rubin R-hat (from bcf_diagnostics)
# ---------------------------------------------------------------------------


def compute_rhat(chains: list[np.ndarray]) -> float:
    """Gelman-Rubin R-hat for multi-chain convergence.

    Parameters
    ----------
    chains : list of 1D arrays
        Each element is a trace from one chain. Chains of unequal length are
        trimmed to the length of the shortest one.

    Returns
    -------
    float
        R-hat statistic. Values near 1.0 indicate convergence; values above
        about 1.05 suggest insufficient mixing or burn-in.

        - Returns nan with fewer than 2 chains or fewer than 4 common draws.
        - If every chain is (numerically) constant, returns 1.0 when the
          chains agree with each other.
        - Returns inf when such constant chains are stuck at different values.
    """
    m = len(chains)
    if m < 2:
        return float("nan")
    n = min(len(c) for c in chains)
    if n < 4:
        return float("nan")

    # Trim to common length
    chains_arr = np.array([c[:n] for c in chains])  # (m, n)
    chain_means = chains_arr.mean(axis=1)  # (m,)

    # Between-chain variance
    B = n * np.var(chain_means, ddof=1)
    # Within-chain variance
    W = np.mean(np.var(chains_arr, axis=1, ddof=1))

    if W < 1e-15:
        # No within-chain variation. The chains have only converged if they
        # also agree with each other. Constant chains at different values are
        # stuck, so the limit of sqrt(var_hat / W) as W -> 0 is infinite.
        return 1.0 if B < 1e-15 else float("inf")

    # Pooled variance estimate
    var_hat = (1 - 1 / n) * W + B / n
    return float(np.sqrt(var_hat / W))
# ---------------------------------------------------------------------------
# Posterior stability (from bcf_diagnostics)
# ---------------------------------------------------------------------------


def compute_posterior_stability(
    samples: np.ndarray,
    n_splits: int = 2,
) -> dict[str, float]:
    """Compare posterior means from the first and second halves of the draws.

    Parameters
    ----------
    samples : np.ndarray
        Either an ``(n, S)`` array (one row per visitor, ``S`` draws) or a
        1D trace of ``S`` draws, which is treated as a single row.
    n_splits : int
        Accepted for interface compatibility but currently unused. The
        comparison is always first half vs second half. When ``S`` is odd,
        the second half gets the extra draw.

    Returns
    -------
    dict[str, float]
        ``split_rmse`` is the RMSE between the per-row half means.
        ``split_corr`` is their Pearson correlation, or nan when there are
        two or fewer rows (the correlation is then degenerate).
    """
    draws = samples[np.newaxis, :] if samples.ndim == 1 else samples

    midpoint = draws.shape[1] // 2
    early = draws[:, :midpoint].mean(axis=1)
    late = draws[:, midpoint:].mean(axis=1)

    diff = early - late
    split_rmse = float(np.sqrt(np.mean(diff ** 2)))

    split_corr = (
        float(np.corrcoef(early, late)[0, 1]) if early.shape[0] > 2 else np.nan
    )
    return {"split_rmse": split_rmse, "split_corr": split_corr}
# --------------------------------------------------------------------------- # Vectorised split-R-hat for (num_samples, num_chains) trace (from gpu_bcf_diagnostics) # --------------------------------------------------------------------------- def _split_rhat(chains: np.ndarray) -> float: """Compute split-R-hat for a (num_samples, num_chains) array. Uses the rank-normalized split-R-hat from Vehtari et al. (2021). For simplicity, uses the basic split-R-hat (split each chain in half). """ S, C = chains.shape if S < 4 or C < 2: return float("nan") # Split each chain in half half = S // 2 split_chains = np.concatenate( [chains[:half, :], chains[half:2*half, :]], axis=1, ) # (half, 2*C) M = 2 * C # noqa: F841 — Vehtari-2021 notation; referenced in `# (M,)` below N = half chain_means = split_chains.mean(axis=0) # (M,) grand_mean = chain_means.mean() # noqa: F841 — third member of (B, W, grand_mean) B = N * np.var(chain_means, ddof=1) W = np.mean(np.var(split_chains, axis=0, ddof=1)) if W < 1e-15: return 1.0 var_hat = (N - 1) / N * W + B / N return float(np.sqrt(var_hat / W)) # --------------------------------------------------------------------------- # Per-chain bulk ESS (from gpu_bcf_diagnostics) # --------------------------------------------------------------------------- def _bulk_ess(chains: np.ndarray) -> np.ndarray: """Compute bulk ESS per chain using the initial positive autocorrelation. Returns array of shape (num_chains,). 
""" S, C = chains.shape ess = np.zeros(C) for c in range(C): x = chains[:, c] x = x - x.mean() var = np.var(x, ddof=0) if var < 1e-15: ess[c] = float(S) continue # Compute autocorrelations up to lag S//2 max_lag = S // 2 rho = np.zeros(max_lag) for lag in range(1, max_lag): rho[lag] = np.mean(x[:-lag] * x[lag:]) / var # Sum positive pairs of autocorrelations tau = 1.0 for lag in range(1, max_lag - 1, 2): pair = rho[lag] + rho[lag + 1] if pair < 0: break tau += 2 * pair ess[c] = S / tau return ess # --------------------------------------------------------------------------- # Tau-hat quantile extraction (from gpu_bcf_diagnostics) # --------------------------------------------------------------------------- def _tau_hat_quantiles_from_carry(carry) -> np.ndarray: """Extract tau_hat quantiles from carry tuple. carry[14] is tau_hat with shape (num_chains, n). Returns (5,) array: [q05, q25, q50, q75, q95] averaged across chains. """ tau_hat_arr = np.array(carry[14]) # (num_chains, n) qs = np.quantile(tau_hat_arr, [0.05, 0.25, 0.5, 0.75, 0.95], axis=1) # (5, C) return qs.mean(axis=1) # (5,) averaged across chains # --------------------------------------------------------------------------- # Between-chunk diagnostic accumulator (from gpu_bcf_diagnostics) # --------------------------------------------------------------------------- def _between_chunk_diagnostics( sigma2_trace: np.ndarray, chunk_idx: int, total_chunks: int, phase: str, history: list, tau_hat_quantiles: np.ndarray | None = None, ) -> dict: """Compute between-chunk diagnostics (pure — callers own presentation). sigma2_trace: (chunk_length, num_chains) tau_hat_quantiles: (5,) — [q05, q25, q50, q75, q95] of tau_hat product total_chunks: carried into the entry's chunk context for consumers rendering progress; not used in any computation here. 
""" num_chains = sigma2_trace.shape[1] entry = { "chunk": chunk_idx, "total_chunks": total_chunks, "phase": phase, "sigma2_mean": float(sigma2_trace.mean()), } if tau_hat_quantiles is not None: entry["tau_hat_q05"] = float(tau_hat_quantiles[0]) entry["tau_hat_q50"] = float(tau_hat_quantiles[2]) entry["tau_hat_q95"] = float(tau_hat_quantiles[4]) q05, q95 = tau_hat_quantiles[0], tau_hat_quantiles[4] # A >1000x tau_hat quantile spread marks a divergent per-leaf # precision posterior — recorded so downstream consumers (progress # rendering, diagnostics compilation) can surface it. entry["tau_hat_divergent"] = bool(q05 > 0 and q95 / q05 > 1000) if num_chains >= 2: # Accumulate sigma2 within current phase for R-hat # (reset at burnin→MCMC boundary so burnin non-stationarity # doesn't inflate MCMC-phase R-hat) same_phase = [ h.get("_sigma2_raw", np.empty((0, num_chains))) for h in history if h.get("phase") == phase ] all_s2 = np.concatenate(same_phase + [sigma2_trace], axis=0) entry["_sigma2_raw"] = sigma2_trace # store for accumulation rhat = _split_rhat(all_s2) ess = _bulk_ess(all_s2) entry["rhat_sigma2"] = rhat entry["ess_per_chain"] = ess.tolist() else: entry["_sigma2_raw"] = sigma2_trace return entry # --------------------------------------------------------------------------- # Final diagnostics compiler (from gpu_bcf_diagnostics) # --------------------------------------------------------------------------- def _compile_diagnostics(diag_log: list, s2_all: np.ndarray) -> dict: """Compile final diagnostics dict from chunk log. s2_all: (S_per_chain, num_chains) — the full MCMC sigma2 trace. Preserves the chunk-by-chunk trajectory for ESS accumulation curves and burn-in adequacy assessment. 
""" result = {} C = s2_all.shape[1] if C >= 2: result["rhat_sigma2"] = _split_rhat(s2_all) result["ess_per_chain"] = _bulk_ess(s2_all).tolist() result["sigma2_mean"] = float(s2_all.mean()) result["sigma2_std"] = float(s2_all.std()) result["num_chunks"] = len(diag_log) # Chunk-by-chunk trajectory (strip raw arrays, keep scalars) trajectory = [] for entry in diag_log: t = { "chunk": entry["chunk"], "phase": entry["phase"], "sigma2_mean": entry["sigma2_mean"], } if "rhat_sigma2" in entry: t["rhat_sigma2"] = entry["rhat_sigma2"] if "ess_per_chain" in entry: t["ess_per_chain"] = entry["ess_per_chain"] if "mu_accept_rate" in entry: t["mu_accept_rate"] = entry["mu_accept_rate"] if "tau_accept_rate" in entry: t["tau_accept_rate"] = entry["tau_accept_rate"] for k in ( "tau_hat_q05", "tau_hat_q50", "tau_hat_q95", "tau_hat_divergent", ): if k in entry: t[k] = entry[k] trajectory.append(t) result["trajectory"] = trajectory # Acceptance rate summaries (mean over MCMC chunks only) mcmc_entries = [e for e in diag_log if e["phase"] == "mcmc"] if mcmc_entries and "mu_accept_rate" in mcmc_entries[0]: result["mu_accept_rate"] = float(np.mean( [e["mu_accept_rate"] for e in mcmc_entries], )) result["tau_accept_rate"] = float(np.mean( [e["tau_accept_rate"] for e in mcmc_entries], )) return result