"""General MCMC convergence diagnostics.
Contains pure convergence-diagnostic functions that are not specific to
BCF: R-hat, bulk ESS, posterior stability, and the between-chunk
diagnostic accumulator used by the GPU BCF orchestrator.
Contents
--------
``compute_ess`` — effective sample size via autocorrelation (1D trace).
``compute_rhat`` — Gelman-Rubin R-hat for multi-chain convergence.
``compute_posterior_stability`` — split-half posterior mean comparison.
``_split_rhat`` — split-R-hat for a (num_samples, num_chains) array.
``_bulk_ess`` — per-chain bulk ESS via initial positive autocorrelation.
``_tau_hat_quantiles_from_carry`` — extract tau_hat quantiles from carry tuple.
``_between_chunk_diagnostics`` — per-chunk diagnostic accumulator.
``_compile_diagnostics`` — final diagnostics dict from chunk log.
"""
from __future__ import annotations
import numpy as np
# ---------------------------------------------------------------------------
# Simple 1D ESS (from bcf_diagnostics)
# ---------------------------------------------------------------------------
[docs]
def compute_ess(samples: np.ndarray) -> float:
    """Effective sample size for a 1D trace using autocorrelation.

    Simple estimator: ``ESS = S / (1 + 2 * sum(rho_k))`` where ``rho_k`` are
    the lagged autocorrelations, summed from lag 1 up to (but excluding) the
    first lag whose autocorrelation is negative.

    Parameters
    ----------
    samples : array-like, 1D
        Trace of a single scalar quantity. Lists and tuples are accepted and
        converted to a float array.

    Returns
    -------
    float
        The ESS estimate, floored at 1.0. Traces with fewer than 4 draws, or
        with variance below 1e-15, return their own length.

    Raises
    ------
    ValueError
        If ``samples`` is not one-dimensional.
    """
    trace = np.asarray(samples, dtype=float)
    if trace.ndim != 1:
        raise ValueError("compute_ess expects a 1D trace")
    S = trace.shape[0]
    if S < 4:
        return float(S)
    centered = trace - trace.mean()
    var = float(np.var(centered))
    if var < 1e-15:
        return float(S)
    # Autocorrelation via FFT. Zero-padding to 2*S makes the circular
    # correlation equal to the linear one for lags 0..S-1; the input is real,
    # so the real-input transforms are sufficient.
    spectrum = np.fft.rfft(centered, n=2 * S)
    acf = np.fft.irfft(spectrum * np.conj(spectrum), n=2 * S)[:S] / (var * S)
    # Truncate at the first negative autocorrelation (lag 0 is excluded).
    negative_lags = np.flatnonzero(acf[1:] < 0)
    stop = int(negative_lags[0]) + 1 if negative_lags.size else S
    rho_sum = float(acf[1:stop].sum())
    ess = S / (1.0 + 2.0 * rho_sum)
    return max(1.0, float(ess))
# ---------------------------------------------------------------------------
# Gelman-Rubin R-hat (from bcf_diagnostics)
# ---------------------------------------------------------------------------
[docs]
def compute_rhat(chains: list[np.ndarray]) -> float:
    """Gelman-Rubin R-hat for multi-chain convergence.

    Parameters
    ----------
    chains : list of 1D arrays
        Each element is a trace from one chain. Chains of unequal length are
        trimmed to the shortest one (leading draws are kept).

    Returns
    -------
    float
        R-hat statistic. Values near 1.0 indicate convergence; above 1.05
        suggests insufficient mixing or burn-in. ``nan`` is returned when
        fewer than 2 chains or fewer than 4 common draws are available.
        When the within-chain variance is numerically zero the ratio is
        undefined: 1.0 is returned if the chain means also agree, and
        ``inf`` if the chains are frozen at different values.
    """
    m = len(chains)
    if m < 2:
        return float("nan")
    n = min(len(c) for c in chains)
    if n < 4:
        return float("nan")
    # Trim to common length
    chains_arr = np.array([c[:n] for c in chains])  # (m, n)
    chain_means = chains_arr.mean(axis=1)  # (m,)
    # Variance of the chain means; B = n * means_var is the classic
    # between-chain variance.
    means_var = np.var(chain_means, ddof=1)
    B = n * means_var
    # Within-chain variance
    W = np.mean(np.var(chains_arr, axis=1, ddof=1))
    if W < 1e-15:
        # Every chain is (numerically) constant. That only counts as
        # converged when the chains also agree with each other; chains stuck
        # at different values must not be reported as R-hat == 1.
        return 1.0 if means_var < 1e-15 else float("inf")
    # Pooled variance estimate
    var_hat = (1 - 1 / n) * W + B / n
    return float(np.sqrt(var_hat / W))
# ---------------------------------------------------------------------------
# Posterior stability (from bcf_diagnostics)
# ---------------------------------------------------------------------------
[docs]
def compute_posterior_stability(
    samples: np.ndarray, n_splits: int = 2,
) -> dict[str, float]:
    """Compare first-half vs second-half posterior means.

    For a (n, S) array, computes the per-row mean over the first ``S // 2``
    draws and over the remaining draws, then returns the RMSE between the two
    sets of means and their correlation.

    Parameters
    ----------
    samples : array-like, shape (n, S) or (S,)
        Posterior draws, one row per unit. A 1D input is treated as a single
        row.
    n_splits : int, default 2
        Kept for interface compatibility. The value is not read: the
        comparison is always first half vs second half.

    Returns
    -------
    dict
        ``split_rmse`` — RMSE between the two half-means across rows.
        ``split_corr`` — Pearson correlation between the half-means, or
        ``nan`` when there are 2 or fewer rows.
        Both are ``nan`` when there are no rows or fewer than 2 draws, since
        no split is possible.
    """
    draws = np.asarray(samples, dtype=float)
    if draws.ndim == 1:
        draws = draws[np.newaxis, :]
    n_units, S = draws.shape[0], draws.shape[1]
    half = S // 2
    if half == 0 or n_units == 0:
        # An empty half (or no rows) has no mean; report NaN directly rather
        # than reaching it through NumPy's empty-slice warnings.
        return {"split_rmse": float("nan"), "split_corr": float("nan")}
    mean_first = draws[:, :half].mean(axis=1)
    mean_second = draws[:, half:].mean(axis=1)
    rmse = float(np.sqrt(np.mean((mean_first - mean_second) ** 2)))
    if n_units > 2:
        corr = float(np.corrcoef(mean_first, mean_second)[0, 1])
    else:
        corr = float("nan")
    return {"split_rmse": rmse, "split_corr": corr}
# ---------------------------------------------------------------------------
# Vectorised split-R-hat for (num_samples, num_chains) trace (from gpu_bcf_diagnostics)
# ---------------------------------------------------------------------------
def _split_rhat(chains: np.ndarray) -> float:
    """Compute the basic split-R-hat for a (num_samples, num_chains) array.

    Each chain is split into a first and second half (an odd trailing draw is
    dropped) and the halves are treated as separate chains in the usual
    between/within variance ratio. This is the plain split-R-hat; no rank
    normalisation is applied.

    Parameters
    ----------
    chains : array-like, shape (num_samples, num_chains)

    Returns
    -------
    float
        Split-R-hat. ``nan`` when there are fewer than 4 samples or fewer
        than 2 chains. When the within-chain variance is numerically zero the
        ratio is undefined: 1.0 is returned if the half-chain means also
        agree, and ``inf`` if the chains are frozen at different values.
    """
    chains = np.asarray(chains, dtype=float)
    S, C = chains.shape
    if S < 4 or C < 2:
        return float("nan")
    # Split each chain in half
    half = S // 2
    split_chains = np.concatenate(
        [chains[:half, :], chains[half:2 * half, :]], axis=1,
    )  # (half, 2*C)
    N = half
    chain_means = split_chains.mean(axis=0)  # (2*C,)
    # Variance of the half-chain means; B = N * means_var.
    means_var = np.var(chain_means, ddof=1)
    B = N * means_var
    W = np.mean(np.var(split_chains, axis=0, ddof=1))
    if W < 1e-15:
        # All half-chains are (numerically) constant: converged only if they
        # also agree with each other.
        return 1.0 if means_var < 1e-15 else float("inf")
    var_hat = (N - 1) / N * W + B / N
    return float(np.sqrt(var_hat / W))
# ---------------------------------------------------------------------------
# Per-chain bulk ESS (from gpu_bcf_diagnostics)
# ---------------------------------------------------------------------------
def _bulk_ess(chains: np.ndarray) -> np.ndarray:
    """Compute an ESS per chain from initial positive autocorrelation pairs.

    For each chain the autocorrelations ``rho[1] .. rho[S//2 - 1]`` are
    estimated (each lag normalised by its own number of overlapping pairs),
    then summed in consecutive pairs ``rho[1]+rho[2]``, ``rho[3]+rho[4]``, ...
    until the first negative pair. With ``tau = 1 + 2 * sum(pairs)`` the
    estimate is ``S / tau``.

    Parameters
    ----------
    chains : array-like, shape (num_samples, num_chains)

    Returns
    -------
    np.ndarray
        Array of shape (num_chains,). A chain with variance below 1e-15, or
        too short to form a single autocorrelation pair, gets ``S``.
    """
    chains = np.asarray(chains, dtype=float)
    S, C = chains.shape
    ess = np.zeros(C)
    # Autocorrelations are evaluated up to (but excluding) lag S // 2.
    max_lag = S // 2
    for c in range(C):
        x = chains[:, c]
        x = x - x.mean()
        var = np.var(x, ddof=0)
        if var < 1e-15 or max_lag < 3:
            # Constant chain, or no (lag, lag + 1) pair fits below max_lag,
            # so tau stays at 1.
            ess[c] = float(S)
            continue
        # Lagged product sums via FFT (O(S log S) instead of one O(S) pass
        # per lag). Zero-padding to 2*S keeps the circular correlation equal
        # to the linear one.
        spectrum = np.fft.rfft(x, n=2 * S)
        lag_sums = np.fft.irfft(spectrum * np.conj(spectrum), n=2 * S)[:max_lag]
        # Lag k has S - k overlapping products; dividing by that count gives
        # the mean lagged product, which is then scaled by the variance.
        rho = lag_sums / ((S - np.arange(max_lag)) * var)
        # Sum positive pairs of autocorrelations: (1, 2), (3, 4), ...
        pairs = rho[1:max_lag - 1:2] + rho[2:max_lag:2]
        negative = np.flatnonzero(pairs < 0)
        stop = int(negative[0]) if negative.size else pairs.size
        tau = 1.0 + 2.0 * float(pairs[:stop].sum())
        ess[c] = S / tau
    return ess
# ---------------------------------------------------------------------------
# Tau-hat quantile extraction (from gpu_bcf_diagnostics)
# ---------------------------------------------------------------------------
def _tau_hat_quantiles_from_carry(carry) -> np.ndarray:
    """Summarise the tau_hat slot of a carry tuple as five quantiles.

    ``carry[14]`` is read as tau_hat with shape (num_chains, n). The 5%, 25%,
    50%, 75% and 95% quantiles are taken within each chain (along axis 1) and
    then averaged over chains.

    Returns
    -------
    np.ndarray
        Shape (5,): ``[q05, q25, q50, q75, q95]``, chain-averaged.
    """
    levels = [0.05, 0.25, 0.5, 0.75, 0.95]
    tau_hat = np.array(carry[14])  # (num_chains, n)
    per_chain = np.quantile(tau_hat, levels, axis=1)  # (5, num_chains)
    # Collapse the chain axis so one value remains per quantile level.
    return per_chain.mean(axis=1)
# ---------------------------------------------------------------------------
# Between-chunk diagnostic accumulator (from gpu_bcf_diagnostics)
# ---------------------------------------------------------------------------
def _between_chunk_diagnostics(
    sigma2_trace: np.ndarray,
    chunk_idx: int,
    total_chunks: int,
    phase: str,
    history: list,
    tau_hat_quantiles: np.ndarray | None = None,
) -> dict:
    """Build the diagnostics entry for one chunk (pure — callers own presentation).

    Parameters
    ----------
    sigma2_trace : (chunk_length, num_chains)
        sigma2 draws produced by this chunk.
    chunk_idx, total_chunks :
        Copied into the entry as chunk context for consumers rendering
        progress; ``total_chunks`` takes part in no computation here.
    phase :
        Phase label of this chunk. Only ``history`` entries carrying the same
        label are pooled with it.
    history :
        Entries returned for earlier chunks. It is only read, never modified.
    tau_hat_quantiles : (5,), optional
        ``[q05, q25, q50, q75, q95]`` of tau_hat. Elements 0, 2 and 4 are
        recorded.

    Returns
    -------
    dict
        Always holds ``chunk``, ``total_chunks``, ``phase``, ``sigma2_mean``
        and ``_sigma2_raw`` (the trace itself, kept for later pooling). Adds
        the ``tau_hat_*`` keys when quantiles are supplied, and
        ``rhat_sigma2`` / ``ess_per_chain`` when there are at least 2 chains.
    """
    n_chains = sigma2_trace.shape[1]
    record = {
        "chunk": chunk_idx,
        "total_chunks": total_chunks,
        "phase": phase,
        "sigma2_mean": float(sigma2_trace.mean()),
    }

    if tau_hat_quantiles is not None:
        low = tau_hat_quantiles[0]
        high = tau_hat_quantiles[4]
        record["tau_hat_q05"] = float(low)
        record["tau_hat_q50"] = float(tau_hat_quantiles[2])
        record["tau_hat_q95"] = float(high)
        # A q95/q05 ratio above 1000 (with q05 positive) is flagged as a
        # divergent tau_hat spread so downstream consumers can surface it.
        record["tau_hat_divergent"] = bool(low > 0 and high / low > 1000)

    if n_chains < 2:
        # A single chain has no between-chain statistics; just keep the raw
        # trace for later accumulation.
        record["_sigma2_raw"] = sigma2_trace
        return record

    # Pool this chunk with the earlier chunks of the same phase only, so that
    # draws from another phase (e.g. burn-in) do not inflate this phase's
    # R-hat.
    pieces = []
    for past in history:
        if past.get("phase") == phase:
            pieces.append(past.get("_sigma2_raw", np.empty((0, n_chains))))
    pieces.append(sigma2_trace)
    pooled = np.concatenate(pieces, axis=0)

    record["_sigma2_raw"] = sigma2_trace  # store for accumulation
    record["rhat_sigma2"] = _split_rhat(pooled)
    record["ess_per_chain"] = _bulk_ess(pooled).tolist()
    return record
# ---------------------------------------------------------------------------
# Final diagnostics compiler (from gpu_bcf_diagnostics)
# ---------------------------------------------------------------------------
def _compile_diagnostics(diag_log: list, s2_all: np.ndarray) -> dict:
    """Compile final diagnostics dict from chunk log.

    Preserves the chunk-by-chunk trajectory for ESS accumulation curves
    and burn-in adequacy assessment.

    Parameters
    ----------
    diag_log : list of dict
        Per-chunk entries. Each must provide ``chunk``, ``phase`` and
        ``sigma2_mean``; R-hat, ESS, acceptance-rate and ``tau_hat_*`` keys
        are copied into the trajectory when present. Any other keys (such as
        raw trace arrays) are left out.
    s2_all : (S_per_chain, num_chains)
        The full MCMC sigma2 trace.

    Returns
    -------
    dict
        ``sigma2_mean``, ``sigma2_std``, ``num_chunks`` and ``trajectory``
        are always present. ``rhat_sigma2`` and ``ess_per_chain`` are added
        when there are at least 2 chains. ``mu_accept_rate`` and
        ``tau_accept_rate`` are each the mean over the ``"mcmc"``-phase
        entries that recorded that rate, and are omitted when none did.
    """
    result = {}
    C = s2_all.shape[1]
    if C >= 2:
        result["rhat_sigma2"] = _split_rhat(s2_all)
        result["ess_per_chain"] = _bulk_ess(s2_all).tolist()
    result["sigma2_mean"] = float(s2_all.mean())
    result["sigma2_std"] = float(s2_all.std())
    result["num_chunks"] = len(diag_log)

    # Chunk-by-chunk trajectory (strip raw arrays, keep scalars)
    required_keys = ("chunk", "phase", "sigma2_mean")
    optional_keys = (
        "rhat_sigma2",
        "ess_per_chain",
        "mu_accept_rate",
        "tau_accept_rate",
        "tau_hat_q05",
        "tau_hat_q50",
        "tau_hat_q95",
        "tau_hat_divergent",
    )
    trajectory = []
    for entry in diag_log:
        step = {key: entry[key] for key in required_keys}
        step.update((key, entry[key]) for key in optional_keys if key in entry)
        trajectory.append(step)
    result["trajectory"] = trajectory

    # Acceptance rate summaries (mean over MCMC chunks only). Each rate is
    # handled on its own and only over the entries that recorded it, so one
    # chunk missing a rate cannot raise KeyError.
    mcmc_entries = [e for e in diag_log if e["phase"] == "mcmc"]
    for key in ("mu_accept_rate", "tau_accept_rate"):
        rates = [e[key] for e in mcmc_entries if key in e]
        if rates:
            result[key] = float(np.mean(rates))
    return result