"""BCF posterior calibration diagnostics.
Structured diagnostic workflow for BCF model development, following
Betancourt's principled Bayesian workflow adapted for BART/BCF:
1. Computational faithfulness — σ² convergence, ESS, stability.
2. Retrodictive checks — posterior predictive vs observed.
3. Calibration — coverage, P(τ>0) calibration (requires ground truth).
4. Model critique — channel attribution, quintile calibration.
Public API
----------
- ``BCFDiagnosticData`` — structured container for posterior samples
and metadata.
- ``extract_diagnostics_joint(model, ...)`` — populate from
``HurdleBCFModel``.
- ``extract_diagnostics_proto(result, ...)`` — populate from
``HurdleProtoResult``.
- ``extract_diagnostics_gpu(result, ...)`` — populate from
``HurdleBCFResult``.
- ``compute_*(...)`` — pure diagnostic computation functions.
- ``render_diagnostic_report(diag, run_dir)`` — terminal report to disk.
"""
from __future__ import annotations
import dataclasses
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
import numpy as np
from scipy.stats import kstest, spearmanr
from scipy.stats import norm as sp_norm
if TYPE_CHECKING:
from pytyche.bcf.diagnostics.topology import TopologyHistory
from pytyche.calibrate._coverage_levels import COVERAGE_LEVELS, level_suffix # noqa: F401
from pytyche.diagnostics.convergence import (
compute_ess,
compute_posterior_stability, # noqa: F401 — re-exported
compute_rhat,
)
# ---------------------------------------------------------------------------
# Pre-sort helpers for fast quantile computation
# ---------------------------------------------------------------------------
[docs]
def presort_samples(samples: np.ndarray) -> np.ndarray:
    """Return a row-wise ascending copy of an (n, S) sample matrix.

    Sorting the draws once up front lets every later quantile request be
    served by plain column indexing (``quantiles_from_sorted``) instead of
    re-sorting the full matrix on each call.
    """
    ordered = np.sort(samples, axis=1)
    return ordered
[docs]
def quantiles_from_sorted(
    sorted_samples: np.ndarray,
    alphas: Sequence[float],
) -> dict[float, np.ndarray]:
    """Look up several quantiles at once from row-sorted (n, S) samples.

    For each ``alpha`` the quantile is the order statistic at index
    ``floor(alpha * S)``, clamped to ``S - 1``. This simple floor rule is
    NOT numpy's ``method='lower'`` (which floors ``alpha * (S - 1)``), so a
    result can differ from ``np.quantile`` by one order statistic. For
    boolean coverage checks that difference is usually small.

    A tiny tolerance is added before flooring so that floating-point noise
    in ``alpha`` does not drop the index by one. For example,
    ``(1 - 0.9) / 2`` evaluates to ``0.04999999999999999``; without the
    tolerance it would map to index 4 instead of 5 when S = 100. That would
    make algebraically identical alphas select different draws.

    Parameters
    ----------
    sorted_samples : (n, S)
        Samples sorted ascending along axis 1 (e.g. from ``presort_samples``).
    alphas :
        Quantile probabilities in [0, 1].

    Returns
    -------
    dict
        Maps each alpha, exactly as passed, to its (n,) quantile vector.
    """
    S = sorted_samples.shape[1]
    # Absorbs representation error in alpha (~1e-17) scaled by S.
    tol = 1e-9
    indices = [min(int(np.floor(a * S + tol)), S - 1) for a in alphas]
    picked = sorted_samples[:, indices]  # (n, len(alphas)) in one gather
    return {a: picked[:, i] for i, a in enumerate(alphas)}
# ---------------------------------------------------------------------------
# Diagnostic data container
# ---------------------------------------------------------------------------
[docs]
@dataclasses.dataclass
class BCFDiagnosticData:
    """Structured container for all posterior samples from a BCF fit.

    Normalizes the outputs of all three estimators (``"joint"``, ``"proto"``,
    ``"gpu"``) into a common format consumed by the ``compute_*`` diagnostic
    functions in this module.

    Per-visitor posterior samples have shape (n, S) where n is the number
    of visitors and S is the number of retained MCMC samples.

    Channel convention:
    - Binary channel: probit scale (mu_b, tau_b → Phi(mu_b + b*tau_b))
    - Continuous channel: log-revenue scale (mu_c, tau_c)
    - Composed: RPV CATE in $/visitor

    The diagnostics in this module use the treatment basis b = -0.5 for
    control and b = +0.5 for treated when recombining mu and tau.

    For ``estimator == "proto"``, ``compute_ppc_continuous`` treats
    ``mu_cont_samples`` as the composite severity prediction and ignores
    ``tau_cont_samples``. In GPU summary mode, ``p0_samples``/``p1_samples``
    may be None and ``mu_cont_samples`` may have zero columns. The
    ``*_mean`` fields then carry the channel posterior means instead.
    """
    # Per-visitor posterior samples (n, S)
    tau_binary_samples: np.ndarray  # probit-scale binary treatment effect
    tau_cont_samples: np.ndarray  # log-scale continuous treatment effect
    rpv_cate_samples: np.ndarray  # composed RPV CATE samples ($/visitor)
    mu_binary_samples: np.ndarray  # probit-scale prognostic
    mu_cont_samples: np.ndarray  # log-scale prognostic (may have 0 columns in GPU summary mode)
    # Per-visitor derived channel samples (n, S) — None if not retained
    p0_samples: np.ndarray | None  # P(convert | control) per sample
    p1_samples: np.ndarray | None  # P(convert | treated) per sample
    # Scalar traces (S,)
    sigma2_samples: np.ndarray  # global variance trace
    # Observed data
    Z: np.ndarray  # (n,) treatment assignment (treated where Z > 0.5 / truthy)
    Y_obs: np.ndarray  # (n,) observed revenue
    converter_mask: np.ndarray  # (n,) boolean converter indicator
    # Metadata
    n_burn: int  # burn-in iterations already discarded
    estimator: str  # "joint" | "proto" | "gpu"
    # Chain count. compute_chain_diagnostics splits sigma2_samples into
    # num_chains contiguous, equal-length blocks (chain 0's draws, then
    # chain 1's, ...). NOTE(review): confirm producers concatenate per-chain
    # draws; an interleaved layout would make per-chain diagnostics wrong.
    num_chains: int = 1
    # Per-visitor severity channel samples (n, S) — None if not retained
    sev0_samples: np.ndarray | None = None  # E[sev | control, convert] per sample
    sev1_samples: np.ndarray | None = None  # E[sev | treated, convert] per sample
    # Per-visitor propensity (optional — needed for conformal calibration)
    propensity: np.ndarray | None = None  # (n,) P(Z=1|X)
    # Channel-level posterior means (n,) — populated by GPU path (summary mode)
    p0_mean: np.ndarray | None = None  # E[Φ(μ_b + b₀·τ_b)]
    p1_mean: np.ndarray | None = None  # E[Φ(μ_b + b₁·τ_b)]
    sev0_mean: np.ndarray | None = None  # E[exp(μ_c + b₀·τ_c + σ²/2)]
    sev1_mean: np.ndarray | None = None  # E[exp(μ_c + b₁·τ_c + σ²/2)]
    # GPU-path diagnostics (PPC, per-chain stats computed during scan)
    gpu_diagnostics: dict | None = None
    # Ground truth (optional — available in simulation). Per-visitor (n,)
    # arrays; true_m0/true_m1 play the role of control/treated severity in
    # compute_channel_attribution.
    true_rpv_cate: np.ndarray | None = None
    true_p0: np.ndarray | None = None
    true_p1: np.ndarray | None = None
    true_m0: np.ndarray | None = None
    true_m1: np.ndarray | None = None
    cluster_id: np.ndarray | None = None
    # Optional retained topology trace (populated when the producing fit was
    # configured with GPUBCFConfig.retain_topology_history=True). When None
    # (default), compute_slim_diagnostics emits no topology.* keys.
    topology_history: TopologyHistory | None = None
# ---------------------------------------------------------------------------
# Extraction: Joint hurdle BCF (HurdleBCFModel)
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Extraction: Proto hurdle BCF (HurdleProtoResult)
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Extraction: GPU joint hurdle BCF (HurdleBCFResult)
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Computational faithfulness diagnostics
# ---------------------------------------------------------------------------
[docs]
def compute_sigma2_trace(diag: BCFDiagnosticData) -> dict[str, float]:
    """Convergence summary for the global variance (σ²) trace.

    For multi-chain runs the concatenated trace is analysed as one series.

    Returns
    -------
    dict
        ``mean``: mean of the whole trace (NaN if empty).
        ``final``: last retained draw (NaN if empty).
        ``split_half_ratio``: first-half mean / second-half mean. It is near
        1.0 for a stationary trace. NaN if the second-half mean is not
        positive or there are fewer than two draws.
        ``trend``: least-squares slope times the trace length, divided by
        the overall mean. This is the fitted change across the whole trace
        as a fraction of its mean. NaN if the mean is not positive or there
        are fewer than two draws.
        ``n_samples``: number of draws S.
    """
    s2 = diag.sigma2_samples
    S = len(s2)
    if S < 2:
        # A split-half ratio and a linear trend need at least two draws.
        # Without this guard, np.polyfit raises on an empty trace, s2[-1]
        # raises IndexError, and a single draw yields a degenerate fit.
        return {
            "mean": float(s2.mean()) if S else np.nan,
            "final": float(s2[-1]) if S else np.nan,
            "split_half_ratio": np.nan,
            "trend": np.nan,
            "n_samples": S,
        }
    half = S // 2
    first_half_mean = float(s2[:half].mean())
    second_half_mean = float(s2[half:].mean())
    overall_mean = float(s2.mean())
    # Split-half ratio: should be near 1.0 if converged
    ratio = first_half_mean / second_half_mean if second_half_mean > 0 else np.nan
    # Linear trend: total fitted change over the trace, relative to the mean
    x = np.arange(S, dtype=float)
    slope = float(np.polyfit(x, s2, 1)[0])
    trend = slope * S / overall_mean if overall_mean > 0 else np.nan
    return {
        "mean": overall_mean,
        "final": float(s2[-1]),
        "split_half_ratio": ratio,
        "trend": trend,
        "n_samples": S,
    }
# ---------------------------------------------------------------------------
# Retrodictive checks (posterior predictive)
# ---------------------------------------------------------------------------
[docs]
def compute_ppc_binary(diag: BCFDiagnosticData, n_bins: int = 10) -> dict[str, Any] | None:
    """Compare predicted and observed conversion across bins of predicted probability.

    Each visitor's posterior mean P(convert) is taken under the arm they were
    actually assigned. Visitors are then grouped into ``n_bins`` equal-count
    (quantile) bins of that prediction. For every non-empty bin, the mean
    prediction, the observed conversion rate and the bin size are reported.

    Returns None when per-sample p0/p1 are unavailable (GPU summary mode).
    """
    p0, p1 = diag.p0_samples, diag.p1_samples
    if p0 is None or p1 is None:
        return None

    # Select the assigned arm per visitor, then average over draws -> (n,)
    assigned_draws = np.where(diag.Z[:, np.newaxis], p1, p0)
    p_hat = assigned_draws.mean(axis=1)
    converted = diag.converter_mask.astype(float)

    edges = np.quantile(p_hat, np.linspace(0, 1, n_bins + 1))
    edges[-1] += 1e-10  # nudge the top edge so the maximum falls inside the last bin
    which_bin = np.clip(np.digitize(p_hat, edges) - 1, 0, n_bins - 1)

    summary: dict[str, list] = {
        "predicted_rates": [],
        "observed_rates": [],
        "bin_sizes": [],
    }
    for b in range(n_bins):
        members = which_bin == b
        count = int(members.sum())
        if count == 0:
            continue  # duplicate quantile edges can leave bins empty
        summary["predicted_rates"].append(float(p_hat[members].mean()))
        summary["observed_rates"].append(float(converted[members].mean()))
        summary["bin_sizes"].append(count)
    return summary
[docs]
def compute_ppc_continuous(diag: BCFDiagnosticData) -> dict[str, Any] | None:
    """Posterior predicted vs observed log-revenue for converters.

    For the "joint" and "gpu" estimators, the prediction for each converter
    is the posterior mean of ``mu_c + b * tau_c``, where b is -0.5
    (control) or +0.5 (treated) for the arm actually assigned. For "proto",
    ``mu_cont_samples`` already holds the composite severity prediction.

    Returns
    -------
    dict or None
        None when channel-level samples are unavailable (GPU summary mode,
        i.e. ``mu_cont_samples`` has zero columns). Otherwise a dict with:

        - ``observed_log_y``: list of observed log revenues.
        - ``predicted_log_y``: list of predicted log revenues.
        - ``observed_mean`` and ``predicted_mean``: their means.

        With no converters, the lists are empty and the means are NaN. The
        legacy key ``predicted_log_y_mean`` (NaN) is also kept for older
        readers.
    """
    # Channel-level samples are required; GPU summary mode leaves them empty.
    if diag.mu_cont_samples.shape[1] == 0:
        return None
    mask = np.asarray(diag.converter_mask, dtype=bool)
    if not mask.any():
        # Same key schema as the populated case so consumers can index
        # uniformly.
        return {
            "observed_log_y": [],
            "predicted_log_y": [],
            "observed_mean": np.nan,
            "predicted_mean": np.nan,
            "predicted_log_y_mean": np.nan,
        }
    # Clip before the log so zero revenue cannot produce -inf.
    observed_log_y = np.log(np.clip(diag.Y_obs[mask], 1e-10, None))
    if diag.estimator in ("joint", "gpu"):
        # Predicted log-revenue: mu_c + basis * tau_c for the assigned arm
        b0, b1 = -0.5, 0.5
        basis = np.where(diag.Z, b1, b0)
        pred_log_y = (
            diag.mu_cont_samples[mask]
            + basis[mask, np.newaxis] * diag.tau_cont_samples[mask]
        ).mean(axis=1)
    else:
        # Proto: mu_cont_samples IS the composite f_sev
        pred_log_y = diag.mu_cont_samples[mask].mean(axis=1)
    return {
        "observed_log_y": observed_log_y.tolist(),
        "predicted_log_y": pred_log_y.tolist(),
        "observed_mean": float(observed_log_y.mean()),
        "predicted_mean": float(pred_log_y.mean()),
    }
# ---------------------------------------------------------------------------
# Calibration diagnostics (require ground truth)
# ---------------------------------------------------------------------------
[docs]
def compute_coverage(
    samples: np.ndarray,
    truth: np.ndarray,
    levels: tuple[float, ...] = (0.50, 0.80, 0.90, 0.95),
    *,
    sorted_samples: np.ndarray | None = None,
) -> dict[str, float]:
    """Empirical vs nominal coverage of central credible intervals.

    For each nominal level, the central interval [q(α), q(1 - α)] with
    α = (1 - level) / 2 is built per row. The fraction of rows whose true
    value lies inside it (inclusive on both ends) is reported.

    Parameters
    ----------
    samples : (n, S)
        Posterior samples per visitor.
    truth : (n,)
        True values.
    levels :
        Nominal coverage levels.
    sorted_samples : (n, S), optional
        Row-sorted samples. When given, interval endpoints come from the
        floor-index lookup in ``quantiles_from_sorted`` and ``samples`` is
        not used. Otherwise ``np.quantile`` (linear interpolation) is used.

    Returns
    -------
    dict
        ``coverage_{sfx}`` (empirical fraction) and ``nominal_{sfx}`` (the
        level) for every level, where ``sfx = level_suffix(level)``.
    """
    plan = [(level, (1 - level) / 2) for level in levels]

    if sorted_samples is None:
        def bounds(tail):
            return (
                np.quantile(samples, tail, axis=1),
                np.quantile(samples, 1 - tail, axis=1),
            )
    else:
        wanted: list[float] = []
        for _, tail in plan:
            wanted.extend((tail, 1 - tail))
        table = quantiles_from_sorted(sorted_samples, wanted)

        def bounds(tail):
            return table[tail], table[1 - tail]

    summary: dict[str, float] = {}
    for level, tail in plan:
        lower, upper = bounds(tail)
        inside = (truth >= lower) & (truth <= upper)
        suffix = level_suffix(level)
        summary[f"coverage_{suffix}"] = float(inside.mean())
        summary[f"nominal_{suffix}"] = level
    return summary
[docs]
def compute_calibration_curve(
    p_positive: np.ndarray,
    truly_positive: np.ndarray,
    n_bins: int = 10,
) -> dict[str, list[float]]:
    """Reliability curve for P(τ>0): mean prediction vs realized positive rate.

    Predictions are assigned to ``n_bins`` equal-width bins on [0, 1]; a value
    of exactly 1.0 is folded into the top bin. Bins holding fewer than five
    visitors are dropped as too noisy.

    Parameters
    ----------
    p_positive : (n,)
        Posterior P(τ>0) per visitor.
    truly_positive : (n,)
        Boolean: is true τ > 0?
    n_bins :
        Number of bins.

    Returns
    -------
    dict
        ``predicted`` (mean P(τ>0) per kept bin), ``actual`` (fraction truly
        positive) and ``sizes`` (visitors per kept bin).
    """
    edges = np.linspace(0, 1, n_bins + 1)
    assignment = np.clip(np.digitize(p_positive, edges) - 1, 0, n_bins - 1)

    curve: dict[str, list] = {"predicted": [], "actual": [], "sizes": []}
    for b in range(n_bins):
        members = assignment == b
        count = int(members.sum())
        if count < 5:
            continue
        curve["predicted"].append(float(p_positive[members].mean()))
        curve["actual"].append(float(truly_positive[members].mean()))
        curve["sizes"].append(count)
    return curve
[docs]
def compute_segment_diagnostics(
    rpv_cate_samples: np.ndarray,
    true_rpv_cate: np.ndarray,
    X: np.ndarray,
    feature_names: list[str],
    *,
    depth: int = 3,
    levels: tuple[float, ...] = (0.05, 0.10, 0.20, 0.50, 0.80, 0.90, 0.95, 0.975, 0.99, 0.995, 0.999, 0.9995),
    n_ptau_bins: int = 10,
    random_state: int | None = 0,
) -> dict[str, float]:
    """Segment-level coverage and P(τ>0) calibration.

    Fits a depth-limited decision tree that predicts the sign of the BCF
    posterior-mean CATE from ``X``. Each leaf is a segment. For every leaf
    with at least 10 visitors, the posterior of the *segment-mean* CATE (the
    per-draw mean over the leaf's visitors) is compared with the true
    segment mean. This produces the calibration data needed for segment-mean
    SBC corrections.

    Parameters
    ----------
    rpv_cate_samples : (n, S)
        Per-visitor posterior CATE samples (S = MCMC draws).
    true_rpv_cate : (n,)
        True CATE per visitor.
    X : (n, d)
        Feature matrix (same as used for tree fitting).
    feature_names :
        Column names for X. Not used by the current implementation; kept
        for API compatibility.
    depth :
        Policy tree max depth.
    levels :
        Coverage levels to evaluate.
    n_ptau_bins :
        Number of bins for P(τ>0) calibration curve.
    random_state :
        Seed passed to ``DecisionTreeClassifier``. scikit-learn permutes
        candidate features at every split, so without a fixed seed tied
        splits can resolve differently between runs. The segmentation, and
        every ``seg_*`` metric, would then not be reproducible. Pass None
        to restore unseeded behavior.

    Returns
    -------
    dict
        Keys include:

        - ``seg_n_segments``: number of tree leaves.
        - ``seg_n_valid_segments``: leaves with at least 10 visitors (the
          only ones evaluated).
        - ``seg_detail_{k}_{size,true_cate,post_mean,post_sd,ci90_width,
          bias,z_score}``: per valid segment k.
        - ``seg_coverage_{sfx}``: size-weighted fraction of valid segments
          whose central interval contains the true segment mean.
          ``seg_nominal_{sfx}`` holds the level itself
          (``sfx = level_suffix(level)``).
        - ``seg_ptau_bin{k}_predicted``, ``seg_ptau_bin{k}_actual``,
          ``seg_ptau_bin{k}_n``: per-bin segment-mean P(τ>0) calibration,
          emitted only for bins containing at least 2 segments.
    """
    from sklearn.tree import DecisionTreeClassifier

    tau_hat = rpv_cate_samples.mean(axis=1)
    labels = (tau_hat > 0).astype(int)
    tree = DecisionTreeClassifier(
        max_depth=depth, min_samples_leaf=50, random_state=random_state,
    )
    tree.fit(X, labels)
    leaf_ids = tree.apply(X)
    unique_leaves = np.unique(leaf_ids)
    result: dict[str, float] = {"seg_n_segments": len(unique_leaves)}
    # Per-segment data for coverage and ptau
    seg_sizes: list[int] = []
    seg_coverages: dict[float, list[bool]] = {lvl: [] for lvl in levels}
    seg_ptau_predicted: list[float] = []
    seg_ptau_actual: list[float] = []
    seg_idx = 0
    for lid in unique_leaves:
        mask = leaf_ids == lid
        size = int(mask.sum())
        if size < 10:
            continue
        seg_sizes.append(size)
        # Segment-mean posterior: (S,) — mean CATE per posterior draw
        seg_means = rpv_cate_samples[mask].mean(axis=0)
        true_seg_mean = float(true_rpv_cate[mask].mean())
        # Per-segment detail fields for z-score / error analysis
        post_mean = float(seg_means.mean())
        post_sd = float(seg_means.std())
        ci90_lo = float(np.quantile(seg_means, 0.05))
        ci90_hi = float(np.quantile(seg_means, 0.95))
        result[f"seg_detail_{seg_idx}_size"] = size
        result[f"seg_detail_{seg_idx}_true_cate"] = round(true_seg_mean, 6)
        result[f"seg_detail_{seg_idx}_post_mean"] = round(post_mean, 6)
        result[f"seg_detail_{seg_idx}_post_sd"] = round(post_sd, 6)
        result[f"seg_detail_{seg_idx}_ci90_width"] = round(ci90_hi - ci90_lo, 6)
        result[f"seg_detail_{seg_idx}_bias"] = round(post_mean - true_seg_mean, 6)
        # Floor on the SD avoids division by zero for a degenerate posterior.
        result[f"seg_detail_{seg_idx}_z_score"] = round(
            (true_seg_mean - post_mean) / max(post_sd, 1e-10), 6,
        )
        seg_idx += 1
        # Coverage: does CI at each level contain the true segment mean?
        for lvl in levels:
            alpha = (1 - lvl) / 2
            ci_lo = float(np.quantile(seg_means, alpha))
            ci_hi = float(np.quantile(seg_means, 1 - alpha))
            seg_coverages[lvl].append(ci_lo <= true_seg_mean <= ci_hi)
        # Segment-mean P(τ>0)
        p_pos = float((seg_means > 0).mean())
        truly_pos = 1.0 if true_seg_mean > 0 else 0.0
        seg_ptau_predicted.append(p_pos)
        seg_ptau_actual.append(truly_pos)
    result["seg_n_valid_segments"] = seg_idx
    # Size-weighted coverage per level
    total_size = sum(seg_sizes)
    for lvl in levels:
        if not seg_coverages[lvl]:
            continue
        weighted_cov = sum(
            c * s for c, s in zip(seg_coverages[lvl], seg_sizes, strict=False)
        ) / total_size
        sfx = level_suffix(lvl)
        result[f"seg_coverage_{sfx}"] = round(weighted_cov, 6)
        result[f"seg_nominal_{sfx}"] = lvl
    # Binned segment-mean P(τ>0) calibration
    if seg_ptau_predicted:
        pred_arr = np.array(seg_ptau_predicted)
        act_arr = np.array(seg_ptau_actual)
        bin_edges = np.linspace(0, 1, n_ptau_bins + 1)
        bin_idx = np.clip(np.digitize(pred_arr, bin_edges) - 1, 0, n_ptau_bins - 1)
        for b in range(n_ptau_bins):
            bmask = bin_idx == b
            n_in_bin = int(bmask.sum())
            if n_in_bin < 2:
                continue
            result[f"seg_ptau_bin{b}_predicted"] = round(float(pred_arr[bmask].mean()), 6)
            result[f"seg_ptau_bin{b}_actual"] = round(float(act_arr[bmask].mean()), 6)
            result[f"seg_ptau_bin{b}_n"] = n_in_bin
    return result
[docs]
def compute_miscalibration_area(
    samples: np.ndarray,
    truth: np.ndarray,
    n_levels: int = 20,
    *,
    sorted_samples: np.ndarray | None = None,
) -> float:
    r"""Mean \|actual - nominal\| coverage over a grid of central-interval levels.

    Evaluates ``n_levels`` nominal levels evenly spaced on [0.05, 0.95]. At
    each level, the central credible interval per row is checked against
    ``truth``, and the absolute gap between empirical and nominal coverage
    is recorded. The result is the average gap.

    Parameters
    ----------
    samples : (n, S)
        Posterior samples per row. Used only when ``sorted_samples`` is None.
    truth : (n,)
        True values.
    n_levels :
        Number of nominal levels in the grid.
    sorted_samples : (n, S), optional
        Row-sorted samples. When given, interval endpoints come from the
        floor-index lookup in ``quantiles_from_sorted`` rather than
        ``np.quantile``'s linear interpolation, so values can differ
        slightly between the two paths.

    Returns
    -------
    float
        0 means perfectly calibrated. The value is not normalized to reach
        1. The level grid is symmetric about 0.5, so intervals that never
        (or always) contain the truth score 0.5.
    """
    levels = np.linspace(0.05, 0.95, n_levels)
    alphas = [(1 - level) / 2 for level in levels]

    if sorted_samples is not None:
        q_map = quantiles_from_sorted(
            sorted_samples, [a for alpha in alphas for a in (alpha, 1 - alpha)],
        )

        def bounds(alpha):
            return q_map[alpha], q_map[1 - alpha]
    else:
        def bounds(alpha):
            return (
                np.quantile(samples, alpha, axis=1),
                np.quantile(samples, 1 - alpha, axis=1),
            )

    deviations = []
    for level, alpha in zip(levels, alphas):
        lo, hi = bounds(alpha)
        actual = float(((truth >= lo) & (truth <= hi)).mean())
        deviations.append(abs(actual - level))
    return float(np.mean(deviations))
[docs]
def compute_channel_calibration(
    diag: BCFDiagnosticData,
) -> dict[str, dict[str, float]] | None:
    """Coverage of the conversion (binary) channel and the composed RPV CATE.

    The binary channel is evaluated on the probability scale as the
    conversion CATE p1 - p0, against true_p1 - true_p0. The composed entry
    covers the RPV CATE samples against ``true_rpv_cate`` when that is
    available. The continuous (severity) channel is not evaluated here.
    Coverage levels are 50/80/90/95%.

    Returns None unless the estimator is "joint" or "gpu", both true_p0 and
    true_p1 are present, and full (n, S) p0/p1 samples are retained. GPU
    summary mode therefore yields None.
    """
    if diag.estimator not in ("joint", "gpu"):
        return None
    # Both truths are needed for the conversion CATE; a missing true_p1
    # would otherwise fail on None arithmetic.
    if diag.true_p0 is None or diag.true_p1 is None:
        return None
    # Need full (n, S) p0/p1 samples for coverage computation
    if diag.p0_samples is None or diag.p1_samples is None:
        return None
    levels = (0.50, 0.80, 0.90, 0.95)
    result: dict[str, dict[str, float]] = {}
    # Binary channel: conversion CATE on the probability scale
    conv_cate_samples = diag.p1_samples - diag.p0_samples  # (n, S)
    true_conv_cate = diag.true_p1 - diag.true_p0
    result["binary"] = compute_coverage(conv_cate_samples, true_conv_cate, levels)
    # Composed RPV CATEs
    true_rpv = diag.true_rpv_cate
    if true_rpv is not None:
        result["composed"] = compute_coverage(diag.rpv_cate_samples, true_rpv, levels)
    return result
# ---------------------------------------------------------------------------
# Model critique (require ground truth)
# ---------------------------------------------------------------------------
[docs]
def compute_channel_attribution(diag: BCFDiagnosticData) -> dict[str, float] | None:
    """Which channel drives RPV CATE error?

    Uses the exact identity ``p1*m1 - p0*m0 = (p1-p0)*m0 + p1*(m1-m0)`` to
    split the RPV CATE into:

    - conversion channel: (p1-p0)*m0 contribution
    - AOV channel: p1*(m1-m0) contribution

    and reports per-channel and total RMSE and bias against ground truth.
    Estimates are built from posterior means of p and m (a plug-in, not the
    posterior mean of the RPV CATE samples).

    Sources for the estimates, in order of preference:

    1. Pre-computed channel means (GPU summary mode), used when all four of
       p0_mean, p1_mean, sev0_mean and sev1_mean are present.
    2. Full posterior samples for the "joint" estimator. Severity means are
       reconstructed as exp(mu_c + b*tau_c + sigma²/2) with b = -0.5
       (control) or +0.5 (treated).

    Returns None without full ground truth (true_rpv_cate, true_p0, true_p1,
    true_m0, true_m1) or when neither source is available (e.g. "proto").
    ``conv_channel_frac`` is the conversion share of the two channel RMSEs
    (NaN when both are zero).
    """
    if (
        diag.true_rpv_cate is None
        or diag.true_p0 is None
        or diag.true_p1 is None
        or diag.true_m0 is None
        or diag.true_m1 is None
    ):
        return None
    if (
        diag.p0_mean is not None
        and diag.p1_mean is not None
        and diag.sev0_mean is not None
        and diag.sev1_mean is not None
    ):
        # Pre-computed channel means (GPU summary mode)
        p0_hat = diag.p0_mean
        p1_hat = diag.p1_mean
        m0_hat = diag.sev0_mean
        m1_hat = diag.sev1_mean
    elif (
        diag.estimator == "joint"
        and diag.p0_samples is not None
        and diag.p1_samples is not None
    ):
        # Full posterior samples available (CPU joint estimator)
        p0_hat = diag.p0_samples.mean(axis=1)
        p1_hat = diag.p1_samples.mean(axis=1)
        b0, b1 = -0.5, 0.5
        # Log-normal mean per draw; sigma² is a per-draw scalar broadcast
        # across visitors.
        half_var = diag.sigma2_samples[np.newaxis, :] / 2.0
        sev0_samples = np.exp(
            diag.mu_cont_samples + b0 * diag.tau_cont_samples + half_var
        )
        sev1_samples = np.exp(
            diag.mu_cont_samples + b1 * diag.tau_cont_samples + half_var
        )
        m0_hat = sev0_samples.mean(axis=1)
        m1_hat = sev1_samples.mean(axis=1)
    else:
        # Proto or insufficient data: can't decompose severity
        return None
    # Estimated channel contributions
    est_conv_channel = (p1_hat - p0_hat) * m0_hat  # conversion lift
    est_aov_channel = p1_hat * (m1_hat - m0_hat)  # AOV lift
    est_rpv = est_conv_channel + est_aov_channel
    # True channel contributions
    true_conv_channel = (diag.true_p1 - diag.true_p0) * diag.true_m0
    true_aov_channel = diag.true_p1 * (diag.true_m1 - diag.true_m0)
    # Per-channel RMSE
    conv_rmse = float(np.sqrt(np.mean((est_conv_channel - true_conv_channel) ** 2)))
    aov_rmse = float(np.sqrt(np.mean((est_aov_channel - true_aov_channel) ** 2)))
    total_rmse = float(np.sqrt(np.mean((est_rpv - diag.true_rpv_cate) ** 2)))
    # Per-channel bias
    conv_bias = float((est_conv_channel - true_conv_channel).mean())
    aov_bias = float((est_aov_channel - true_aov_channel).mean())
    total_bias = float((est_rpv - diag.true_rpv_cate).mean())
    return {
        "conv_channel_rmse": conv_rmse,
        "aov_channel_rmse": aov_rmse,
        "total_rmse": total_rmse,
        "conv_channel_bias": conv_bias,
        "aov_channel_bias": aov_bias,
        "total_bias": total_bias,
        "conv_channel_frac": conv_rmse / (conv_rmse + aov_rmse) if (conv_rmse + aov_rmse) > 0 else np.nan,
    }
[docs]
def compute_quintile_calibration(
    diag: BCFDiagnosticData,
    n_quantiles: int = 5,
    *,
    sorted_samples: np.ndarray | None = None,
    est_mean: np.ndarray | None = None,
) -> list[dict[str, Any]] | None:
    """Per-group calibration after ranking visitors by posterior mean CATE.

    Visitors are split into ``n_quantiles`` equal-count groups by the rank of
    their posterior mean RPV CATE (default 5, i.e. quintiles). For each
    non-empty group, the following are reported:

    - estimated vs true mean and standard deviation;
    - sign agreement;
    - 90% central-interval coverage of the truth;
    - within-group Spearman rank correlation (NaN for groups of 3 or fewer;
      0.0 when the correlation is undefined);
    - fractions of negative true and estimated effects.

    This is the primary diagnostic for the success criterion: accurate
    population-level quintile rank ordering and sign correctness.

    Parameters
    ----------
    diag :
        Diagnostic data; returns None when ``true_rpv_cate`` is missing.
    n_quantiles :
        Number of rank groups.
    sorted_samples : (n, S), optional
        Pre-sorted RPV CATE samples. Row subsets stay sorted, so the 90%
        bounds come from ``quantiles_from_sorted`` without re-sorting.
    est_mean : (n,), optional
        Pre-computed posterior mean, avoiding an O(nS) recomputation.
    """
    truth = diag.true_rpv_cate
    if truth is None:
        return None

    centers: np.ndarray = (
        diag.rpv_cate_samples.mean(axis=1) if est_mean is None else est_mean
    )
    count = len(centers)
    # Double argsort turns values into 0-based ranks.
    order_rank = np.argsort(np.argsort(centers))
    group_of = np.minimum(order_rank * n_quantiles // count, n_quantiles - 1)

    rows = []
    for g in range(n_quantiles):
        members = group_of == g
        size = int(members.sum())
        if not size:
            continue
        est_g = centers[members]
        truth_g = truth[members]

        if sorted_samples is None:
            draws = diag.rpv_cate_samples[members]
            lo = np.quantile(draws, 0.05, axis=1)
            hi = np.quantile(draws, 0.95, axis=1)
        else:
            bounds = quantiles_from_sorted(sorted_samples[members], [0.05, 0.95])
            lo, hi = bounds[0.05], bounds[0.95]
        coverage_90 = float(((truth_g >= lo) & (truth_g <= hi)).mean())

        sign_acc = float((np.sign(est_g) == np.sign(truth_g)).mean())

        rank_corr = np.nan
        if size > 3:
            rho = spearmanr(est_g, truth_g).statistic  # pyright: ignore[reportAttributeAccessIssue] — scipy stub gap
            rank_corr = 0.0 if np.isnan(rho) else float(rho)

        rows.append({
            "quintile": g + 1,
            "n": size,
            "est_mean": float(est_g.mean()),
            "true_mean": float(truth_g.mean()),
            "est_std": float(est_g.std()),
            "true_std": float(truth_g.std()),
            "sign_accuracy": sign_acc,
            "coverage_90": coverage_90,
            "rank_corr": rank_corr,
            "frac_neg_true": float((truth_g < 0).mean()),
            "frac_neg_est": float((est_g < 0).mean()),
        })
    return rows
[docs]
def compute_decile_calibration(
    diag: BCFDiagnosticData,
    *,
    sorted_samples: np.ndarray | None = None,
    est_mean: np.ndarray | None = None,
) -> list[dict[str, Any]] | None:
    """Finer-grained decile (10-group) version of quintile calibration.

    Delegates to ``compute_quintile_calibration`` with ``n_quantiles=10``.
    The optional keyword arguments are forwarded unchanged, so callers that
    already hold pre-sorted samples or a posterior mean can avoid
    recomputing them.

    Parameters
    ----------
    diag :
        Diagnostic data; the result is None when ``true_rpv_cate`` is missing.
    sorted_samples : (n, S), optional
        Pre-sorted RPV CATE samples.
    est_mean : (n,), optional
        Pre-computed posterior mean RPV CATE.
    """
    return compute_quintile_calibration(
        diag,
        n_quantiles=10,
        sorted_samples=sorted_samples,
        est_mean=est_mean,
    )
[docs]
def compute_selection_bias(diag: BCFDiagnosticData) -> dict[str, float] | None:
    """Does treatment change converter composition?

    Compares mean observed log-revenue among treated converters (Z > 0.5)
    with that among control converters (Z < 0.5), i.e.
    E[log(Y)|convert, Z=1] vs E[log(Y)|convert, Z=0]. Returns None when
    either arm has fewer than 10 converters.
    """
    converters = diag.converter_mask
    treated = converters & (diag.Z > 0.5)
    control = converters & (diag.Z < 0.5)
    n_treated = int(treated.sum())
    n_control = int(control.sum())
    if n_treated < 10 or n_control < 10:
        return None

    def log_revenue(selector):
        # Clip before the log so zero revenue cannot produce -inf.
        return np.log(np.clip(diag.Y_obs[selector], 1e-10, None))

    treated_mean = log_revenue(treated).mean()
    control_mean = log_revenue(control).mean()
    return {
        "mean_log_y_treated": float(treated_mean),
        "mean_log_y_control": float(control_mean),
        "diff": float(treated_mean - control_mean),
        "n_treated_conv": n_treated,
        "n_control_conv": n_control,
    }
# ---------------------------------------------------------------------------
# Summary scorecard
# ---------------------------------------------------------------------------
def _split_chains(trace: np.ndarray, num_chains: int) -> list[np.ndarray]:
    """Cut a concatenated trace into ``num_chains`` equal contiguous blocks.

    Assumes the layout chain-0-all, chain-1-all, ... (not interleaved). Each
    block has ``len(trace) // num_chains`` draws, so any remainder at the
    end is dropped. With ``num_chains <= 1`` the whole trace is returned as
    a single block.
    """
    if num_chains > 1:
        width = len(trace) // num_chains
        blocks = []
        for k in range(num_chains):
            start = k * width
            blocks.append(trace[start:start + width])
        return blocks
    return [trace]
[docs]
def compute_chain_diagnostics(diag: BCFDiagnosticData) -> dict[str, Any]:
    """Per-chain convergence diagnostics for the σ² trace.

    The trace is split into ``diag.num_chains`` contiguous blocks. The
    result reports:

    - per-chain ESS, mean, standard deviation and lag-1 autocorrelation;
    - pooled ESS over the whole trace;
    - R-hat across chains (NaN for a single chain);
    - the per-chain sample count.
    """
    n_chains = diag.num_chains
    segments = _split_chains(diag.sigma2_samples, n_chains)

    def lag1(seg):
        # Biased lag-1 autocorrelation; degenerate (constant or too short)
        # chains report 0.0.
        centered = seg - seg.mean()
        variance = float(np.var(centered))
        if variance < 1e-15 or len(seg) < 2:
            return 0.0
        return float(np.sum(centered[:-1] * centered[1:]) / (variance * len(seg)))

    return {
        "num_chains": n_chains,
        "rhat_sigma2": compute_rhat(segments) if n_chains >= 2 else float("nan"),
        "per_chain_ess": [compute_ess(seg) for seg in segments],
        "per_chain_mean": [float(seg.mean()) for seg in segments],
        "per_chain_std": [float(seg.std()) for seg in segments],
        "per_chain_acf1": [lag1(seg) for seg in segments],
        "pooled_ess": compute_ess(diag.sigma2_samples),
        "samples_per_chain": len(segments[0]),
    }
[docs]
def compute_scorecard(
    diag: BCFDiagnosticData,
    *,
    coverage: dict[str, float] | None = None,
    miscalibration_area: float | None = None,
    chain_diagnostics: dict[str, Any] | None = None,
    est_mean: np.ndarray | None = None,
    sorted_rpv: np.ndarray | None = None,
) -> dict[str, dict[str, Any]]:
    """Summary scorecard with traffic-light grading.

    Returns a dict of metric_name → {"value", "grade", "label"[, "detail"]}.
    Grades are "green", "yellow" or "red". RMSE and bias are informational
    and graded "dim". A NaN value always grades "red".

    The convergence entries (σ² split-half ratio, R-hat for multi-chain runs,
    minimum per-chain ESS, maximum lag-1 autocorrelation) are always
    present. Calibration entries are added only when ``true_rpv_cate`` is
    known.

    The keyword arguments accept pre-computed values (coverage dict,
    miscalibration area, chain diagnostics, posterior mean, pre-sorted RPV
    samples) to avoid redundant work when called from render_report().
    """

    def within(value, green, yellow):
        # Value should lie inside a closed band.
        if green[0] <= value <= green[1]:
            return "green"
        if yellow[0] <= value <= yellow[1]:
            return "yellow"
        return "red"

    def below(value, green_lt, yellow_lt):
        # Lower is better.
        if value < green_lt:
            return "green"
        if value < yellow_lt:
            return "yellow"
        return "red"

    def above(value, green_gt, yellow_gt):
        # Higher is better.
        if value > green_gt:
            return "green"
        if value > yellow_gt:
            return "yellow"
        return "red"

    card: dict[str, dict[str, Any]] = {}
    chains = chain_diagnostics or compute_chain_diagnostics(diag)
    n_chains = chains["num_chains"]

    # σ² convergence
    ratio = compute_sigma2_trace(diag)["split_half_ratio"]
    card["sigma2_split_half"] = {
        "value": ratio,
        "grade": within(ratio, (0.9, 1.1), (0.8, 1.2)),
        "label": "σ² split-half ratio",
    }

    # R-hat (multi-chain only)
    if n_chains >= 2:
        rhat = chains["rhat_sigma2"]
        card["rhat_sigma2"] = {
            "value": rhat,
            "grade": below(rhat, 1.05, 1.10),
            "label": "σ² R-hat",
            "detail": f"{n_chains} chains",
        }

    # ESS: weakest chain as a fraction of its length
    worst_ess = min(chains["per_chain_ess"])
    per_chain_n = chains["samples_per_chain"]
    ess_frac = worst_ess / per_chain_n if per_chain_n > 0 else 0.0
    if n_chains > 1:
        ess_detail = f"min of {n_chains} chains, {ess_frac:.0%} of {per_chain_n}"
    else:
        ess_detail = f"{ess_frac:.0%} of {per_chain_n} samples"
    card["sigma2_ess"] = {
        "value": worst_ess,
        "grade": above(ess_frac, 0.3, 0.1),
        "label": "σ² ESS (min chain)",
        "detail": ess_detail,
    }

    # Lag-1 autocorrelation: worst-mixing chain
    worst_acf = max(chains["per_chain_acf1"])
    card["sigma2_acf1"] = {
        "value": worst_acf,
        "grade": below(worst_acf, 0.5, 0.8),
        "label": "σ² lag-1 autocorr",
        "detail": "max across chains" if n_chains > 1 else "",
    }

    truth = diag.true_rpv_cate
    if truth is None:
        return card

    # 95% coverage
    if coverage is None:
        coverage = compute_coverage(
            diag.rpv_cate_samples, truth,
            sorted_samples=sorted_rpv,
        )
    cov_95 = coverage.get("coverage_95", np.nan)
    card["coverage_95"] = {
        "value": cov_95,
        "grade": within(cov_95, (0.90, 0.98), (0.85, 1.0)),
        "label": "95% coverage",
    }

    # Rank correlation
    point = est_mean if est_mean is not None else diag.rpv_cate_samples.mean(axis=1)
    rho = spearmanr(point, truth).statistic  # pyright: ignore[reportAttributeAccessIssue] — scipy stub gap
    rho_val = 0.0 if np.isnan(rho) else float(rho)
    card["rank_corr"] = {
        "value": rho_val,
        "grade": above(rho_val, 0.80, 0.60),
        "label": "Rank correlation",
    }

    # Sign accuracy
    sign_acc = float((np.sign(point) == np.sign(truth)).mean())
    card["sign_accuracy"] = {
        "value": sign_acc,
        "grade": above(sign_acc, 0.80, 0.65),
        "label": "Sign accuracy",
    }

    # Miscalibration area
    if miscalibration_area is None:
        miscalibration_area = compute_miscalibration_area(
            diag.rpv_cate_samples, truth,
            sorted_samples=sorted_rpv,
        )
    card["miscalibration_area"] = {
        "value": miscalibration_area,
        "grade": below(miscalibration_area, 0.05, 0.15),
        "label": "Miscalibration area",
    }

    # Informational error summaries
    errors = point - truth
    card["rmse"] = {
        "value": float(np.sqrt(np.mean(errors ** 2))),
        "grade": "dim",
        "label": "CATE RMSE",
    }
    card["bias"] = {
        "value": float(errors.mean()),
        "grade": "dim",
        "label": "Mean bias",
    }
    return card
# ---------------------------------------------------------------------------
# Slim diagnostics (fast path for sweep)
# ---------------------------------------------------------------------------
[docs]
def compute_slim_diagnostics(diag: BCFDiagnosticData) -> dict[str, Any]:
    """Fast diagnostic summary: pre-sort once, compute key metrics.

    Intended for sweep mode: the returned flat dict is suitable for writing
    to ``summary.json``. The RPV CATE posterior is sorted once and the
    sorted array is reused by every quantile-based metric below.

    Parameters
    ----------
    diag
        Populated diagnostic container. ``diag.true_rpv_cate`` (ground
        truth) is required; without it only an error sentinel is returned.

    Returns
    -------
    dict[str, Any]
        ``{"error": "no ground truth available"}`` when
        ``diag.true_rpv_cate`` is ``None``. Otherwise a flat dict with, in
        insertion order:

        - RPV CATE coverage at 12 nominal levels (0.05 ... 0.9995);
        - headline scalars: rank correlation, sign accuracy, RMSE, bias,
          miscalibration area, sigma^2 trace mean / split-half ratio, and
          heterogeneity spread / SNR;
        - per-quintile (by posterior mean) coverage, sign accuracy, RMSE,
          bias and means, when ``compute_quintile_calibration`` returns a
          breakdown;
        - SBC-style rank KS statistic and p-value;
        - 10-bin P(tau>0) calibration curve plus MAE / bias summaries;
        - 90% coverage in the near-zero ("sign boundary") band;
        - per-channel (p0/p1/sev0/sev1) coverage and RMSE, when both
          samples and truth are present;
        - distribution of per-row posterior SDs (q25 / q75 / CV);
        - PPC, channel-calibration, channel-attribution and topology
          metrics when the corresponding inputs are available.

        Entries whose data is missing (absent calibration bins, too few
        sign-boundary points) hold ``""`` instead of being omitted, so those
        keys are always present.
    """
    if diag.true_rpv_cate is None:
        return {"error": "no ground truth available"}

    samples = diag.rpv_cate_samples
    truth = diag.true_rpv_cate
    sorted_rpv = presort_samples(samples)
    est_mean = samples.mean(axis=1)

    # Coverage at expanded levels (presort fast-path makes this cheap).
    rpv_levels = (0.05, 0.10, 0.20, 0.50, 0.80, 0.90, 0.95, 0.975, 0.99, 0.995, 0.999, 0.9995)
    cov = compute_coverage(
        samples, truth,
        levels=rpv_levels,
        sorted_samples=sorted_rpv,
    )

    # Rank correlation of posterior means with truth; an undefined (NaN)
    # correlation is reported as 0.0.
    rho = spearmanr(est_mean, truth).statistic  # pyright: ignore[reportAttributeAccessIssue] — scipy stub gap
    rank_rho = float(rho) if not np.isnan(rho) else 0.0

    err = est_mean - truth
    sign_acc = float((np.sign(est_mean) == np.sign(truth)).mean())
    rmse = float(np.sqrt(np.mean(err ** 2)))
    bias = float(err.mean())

    miscal = compute_miscalibration_area(
        samples, truth, sorted_samples=sorted_rpv,
    )

    # σ² convergence
    s2_trace = compute_sigma2_trace(diag)

    # Heterogeneity detection: across-row spread of posterior means vs the
    # average per-row posterior standard deviation. NOTE: "post_width" is a
    # mean posterior SD, not an interval width.
    post_sd = samples.std(axis=1)  # (n,) — reused for the width distribution below
    est_spread = float(est_mean.std())
    post_width = float(post_sd.mean())
    hte_snr = est_spread / post_width if post_width > 1e-10 else 0.0
    true_spread = float(truth.std())

    # dict[str, Any] because some entries fall back to "" (sentinel for
    # missing data) — see the P(τ>0) bins and sign-boundary coverage.
    result: dict[str, Any] = {
        **cov,
        "rank_rho": rank_rho,
        "sign_accuracy": sign_acc,
        "rmse": rmse,
        "bias": bias,
        "miscalibration_area": miscal,
        "sigma2_mean": s2_trace["mean"],
        "sigma2_split_half_ratio": s2_trace["split_half_ratio"],
        "est_spread": est_spread,
        "post_width": post_width,
        "hte_snr": hte_snr,
        "true_spread": true_spread,
    }

    # Order of these updates fixes the key order of the summary dict.
    result.update(_slim_quintile_metrics(diag, sorted_rpv, est_mean, truth))
    result.update(_slim_sbc_rank_ks(sorted_rpv, truth))
    result.update(_slim_ptau_calibration(samples, truth))
    result.update(_slim_sign_boundary_coverage(sorted_rpv, truth))
    result.update(_slim_channel_coverage(diag))

    # Distribution of per-row posterior SDs.
    result["post_width_q25"] = float(np.quantile(post_sd, 0.25))
    result["post_width_q75"] = float(np.quantile(post_sd, 0.75))
    sd_mean = post_sd.mean()
    result["post_width_cv"] = float(post_sd.std() / sd_mean) if sd_mean > 1e-10 else 0.0

    # PPC summary passthrough (GPU path stores PPC in gpu_diagnostics).
    gpu_diag = diag.gpu_diagnostics
    if gpu_diag is not None and "ppc" in gpu_diag:
        ppc = gpu_diag["ppc"]
        result["ppc_binary_mae"] = ppc.get("binary_mae", float("nan"))
        result["ppc_continuous_mae"] = ppc.get("continuous_mae", float("nan"))
        result["ppc_n_converters"] = ppc.get("n_converters", 0)

    # Channel calibration (available when full sample-level p0/p1 exist).
    chan_cal = compute_channel_calibration(diag)
    if chan_cal is not None:
        for part in ("binary", "composed"):
            if part in chan_cal:
                result[f"{part}_coverage_90"] = chan_cal[part].get("coverage_90", float("nan"))
                result[f"{part}_coverage_95"] = chan_cal[part].get("coverage_95", float("nan"))

    # Channel attribution (available when true_p0/p1/m0/m1 are provided).
    chan_attr = compute_channel_attribution(diag)
    if chan_attr is not None:
        result.update(chan_attr)

    # Topology mobility metrics — surfaced only when the producing fit
    # retained per-iter topology hashes via GPUBCFConfig.retain_topology_history.
    if diag.topology_history is not None:
        from pytyche.bcf.diagnostics.topology import compute_topology_metrics  # noqa: PLC0415
        result.update(compute_topology_metrics(diag.topology_history))

    return result


def _slim_quintile_metrics(
    diag: BCFDiagnosticData,
    sorted_rpv: np.ndarray,
    est_mean: np.ndarray,
    truth: np.ndarray,
) -> dict[str, Any]:
    """Per-quintile ``q{q}_*`` calibration scalars (empty if no breakdown)."""
    quintiles = compute_quintile_calibration(
        diag, sorted_samples=sorted_rpv, est_mean=est_mean,
    )
    if quintiles is None:
        return {}
    # Quintile membership by rank of the posterior mean (0..4). Assumes this
    # matches the binning inside compute_quintile_calibration, whose
    # "quintile" field is treated as 1-based — TODO confirm.
    n = len(est_mean)
    mean_ranks = np.argsort(np.argsort(est_mean))
    q_ids = np.minimum(mean_ranks * 5 // n, 4)
    out: dict[str, Any] = {}
    for qd in quintiles:
        q = qd["quintile"]
        mask = q_ids == (q - 1)
        q_rmse = (
            float(np.sqrt(np.mean((est_mean[mask] - truth[mask]) ** 2)))
            if mask.any()
            else 0.0
        )
        out[f"q{q}_coverage_90"] = qd["coverage_90"]
        out[f"q{q}_sign_accuracy"] = qd["sign_accuracy"]
        out[f"q{q}_rmse"] = q_rmse
        out[f"q{q}_bias"] = qd["est_mean"] - qd["true_mean"]
        out[f"q{q}_est_mean"] = qd["est_mean"]
        out[f"q{q}_true_mean"] = qd["true_mean"]
    return out


def _slim_sbc_rank_ks(sorted_rpv: np.ndarray, truth: np.ndarray) -> dict[str, Any]:
    """KS test of the rank of truth within each row's posterior draws."""
    # If posteriors are calibrated, the rank of truth among S draws is
    # uniform on {0, ..., S}; the KS test detects systematic over/under-
    # coverage. NOTE: ranks are integers while the reference is the
    # continuous Uniform(0, S); the mismatch is small for large S.
    n_draws = sorted_rpv.shape[1]
    sbc_ranks = (sorted_rpv < truth[:, np.newaxis]).sum(axis=1)  # (n,)
    ks_stat, ks_pval = kstest(sbc_ranks, "uniform", args=(0, n_draws))
    return {"rank_ks_stat": float(ks_stat), "rank_ks_pval": float(ks_pval)}


def _slim_ptau_calibration(
    rpv_samples: np.ndarray, truth: np.ndarray, n_bins: int = 10,
) -> dict[str, Any]:
    """Binned P(τ>0) calibration curve plus MAE / bias summary scalars."""
    p_positive = (rpv_samples > 0).mean(axis=1)
    truly_positive = (truth > 0).astype(float)
    curve = compute_calibration_curve(p_positive, truly_positive, n_bins=n_bins)
    predicted, actual, sizes = curve["predicted"], curve["actual"], curve["sizes"]

    out: dict[str, Any] = {}
    # Emit per-bin metrics for the isotonic fitting step; bins beyond what the
    # curve returned get the "" sentinel so every bin key is present. The same
    # n_bins drives both the curve and this loop so they cannot drift apart.
    for k in range(n_bins):
        present = k < len(predicted)
        out[f"ptau_bin{k}_predicted"] = predicted[k] if present else ""
        out[f"ptau_bin{k}_actual"] = actual[k] if present else ""
        out[f"ptau_bin{k}_n"] = sizes[k] if present else ""

    # len() rather than truthiness so an ndarray-valued curve also works.
    if len(predicted) > 0:
        pred_arr = np.array(predicted)
        act_arr = np.array(actual)
        out["ptau_calibration_mae"] = float(np.abs(pred_arr - act_arr).mean())
        out["ptau_calibration_bias"] = float((act_arr - pred_arr).mean())
    else:
        out["ptau_calibration_mae"] = ""
        out["ptau_calibration_bias"] = ""
    return out


def _slim_sign_boundary_coverage(
    sorted_rpv: np.ndarray, truth: np.ndarray, min_n: int = 20,
) -> dict[str, Any]:
    """90% coverage restricted to rows whose true effect is near zero."""
    # Near-zero band: |true τ| < median(|true τ|) / 2 — adaptive to effect scale.
    abs_truth = np.abs(truth)
    threshold = float(np.median(abs_truth)) / 2.0
    near_zero = abs_truth < threshold
    n_near = int(near_zero.sum())
    out: dict[str, Any] = {"sign_boundary_n": n_near}
    if n_near >= min_n:
        q_map = quantiles_from_sorted(sorted_rpv[near_zero], [0.05, 0.95])
        boundary_truth = truth[near_zero]
        inside = (boundary_truth >= q_map[0.05]) & (boundary_truth <= q_map[0.95])
        out["sign_boundary_coverage_90"] = float(inside.mean())
    else:
        out["sign_boundary_coverage_90"] = ""
    return out


def _slim_channel_coverage(diag: BCFDiagnosticData) -> dict[str, Any]:
    """Individual-level coverage and RMSE for each hurdle channel with truth."""
    chan_levels = (0.50, 0.80, 0.90, 0.95)
    channels = (
        ("p0", diag.p0_samples, diag.true_p0),
        ("p1", diag.p1_samples, diag.true_p1),
        ("sev0", diag.sev0_samples, diag.true_m0),
        ("sev1", diag.sev1_samples, diag.true_m1),
    )
    out: dict[str, Any] = {}
    for label, chan_samples, chan_truth in channels:
        # Skip channels without posterior draws or without ground truth.
        if chan_samples is None or chan_truth is None or chan_samples.shape[1] == 0:
            continue
        chan_cov = compute_coverage(
            chan_samples, chan_truth,
            levels=chan_levels,
            sorted_samples=presort_samples(chan_samples),
        )
        out.update({f"chan_{label}_{key}": val for key, val in chan_cov.items()})
        chan_err = chan_samples.mean(axis=1) - chan_truth
        out[f"chan_{label}_rmse"] = float(np.sqrt(np.mean(chan_err ** 2)))
    return out