Source code for pytyche.bcf.diagnostics

"""BCF posterior calibration diagnostics.

Structured diagnostic workflow for BCF model development, following
Betancourt's principled Bayesian workflow adapted for BART/BCF:

1. Computational faithfulness — σ² convergence, ESS, stability.
2. Retrodictive checks — posterior predictive vs observed.
3. Calibration — coverage, P(τ>0) calibration (requires ground truth).
4. Model critique — channel attribution, quintile calibration.

Public API
----------
- ``BCFDiagnosticData`` — structured container for posterior samples
  and metadata.
- ``extract_diagnostics_joint(model, ...)`` — populate from
  ``HurdleBCFModel``.
- ``extract_diagnostics_proto(result, ...)`` — populate from
  ``HurdleProtoResult``.
- ``extract_diagnostics_gpu(result, ...)`` — populate from
  ``HurdleBCFResult``.
- ``compute_*(...)`` — pure diagnostic computation functions.
- ``render_diagnostic_report(diag, run_dir)`` — terminal report to disk.
"""

from __future__ import annotations

import dataclasses
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import numpy as np
from scipy.stats import kstest, spearmanr
from scipy.stats import norm as sp_norm

if TYPE_CHECKING:
    from pytyche.bcf.diagnostics.topology import TopologyHistory

from pytyche.calibrate._coverage_levels import COVERAGE_LEVELS, level_suffix  # noqa: F401
from pytyche.diagnostics.convergence import (
    compute_ess,
    compute_posterior_stability,  # noqa: F401 — re-exported
    compute_rhat,
)

# ---------------------------------------------------------------------------
# Pre-sort helpers for fast quantile computation
# ---------------------------------------------------------------------------


def presort_samples(samples: np.ndarray) -> np.ndarray:
    """Return a copy of ``samples`` with each row sorted in ascending order.

    Sorting once up front lets later quantile computations on the same
    array pick values by column index instead of re-sorting every time.
    The input array itself is not modified.
    """
    # Copy first, then sort the copy in place along the sample axis.
    ordered = np.asanyarray(samples).copy(order="K")
    ordered.sort(axis=1)
    return ordered
def quantiles_from_sorted(
    sorted_samples: np.ndarray,
    alphas: Sequence[float],
) -> dict[float, np.ndarray]:
    """Look up several quantiles at once in row-sorted samples.

    For every ``alpha`` the column ``int(alpha * S)`` (capped at ``S - 1``,
    where ``S`` is the number of columns) is selected; no interpolation
    between neighbouring order statistics is performed.

    Parameters
    ----------
    sorted_samples :
        Two-dimensional array whose rows are already sorted ascending.
    alphas :
        Quantile probabilities to extract.

    Returns
    -------
    dict
        Maps each ``alpha`` to the selected column, one value per row.
    """
    n_draws = sorted_samples.shape[1]
    last_col = n_draws - 1

    columns = []
    for alpha in alphas:
        pos = int(alpha * n_draws)
        columns.append(pos if pos < last_col else last_col)

    # A single fancy-index gather pulls every requested column together.
    picked = sorted_samples[:, columns]

    lookup = {}
    for col, alpha in enumerate(alphas):
        lookup[alpha] = picked[:, col]
    return lookup
# --------------------------------------------------------------------------- # Diagnostic data container # ---------------------------------------------------------------------------
@dataclasses.dataclass
class BCFDiagnosticData:
    """Structured container for all posterior samples from a BCF fit.

    Normalizes the three estimator result types (joint, proto, gpu) into a
    common format for downstream diagnostic functions.

    Per-visitor posterior samples have shape (n, S) where n is the number
    of visitors and S is the number of retained MCMC samples.

    Channel convention:
    - Binary channel: probit scale (mu_b, tau_b → Phi(mu_b + b*tau_b))
    - Continuous channel: log-revenue scale (mu_c, tau_c)
    - Composed: RPV CATE in $/visitor

    Not every producer fills every field with real samples: the proto
    extractor stores zero placeholders for the tau fields and the GPU
    extractor stores (n, 0) placeholders for the mu/tau fields.
    """

    # Per-visitor posterior samples (n, S)
    tau_binary_samples: np.ndarray  # probit-scale binary treatment effect
    tau_cont_samples: np.ndarray  # log-scale continuous treatment effect
    rpv_cate_samples: np.ndarray  # composed RPV CATE samples ($/visitor)
    mu_binary_samples: np.ndarray  # probit-scale prognostic
    mu_cont_samples: np.ndarray  # log-scale prognostic

    # Per-visitor derived channel samples (n, S) — None if not retained
    p0_samples: np.ndarray | None  # P(convert | control) per sample
    p1_samples: np.ndarray | None  # P(convert | treated) per sample

    # Scalar traces (S,)
    sigma2_samples: np.ndarray  # global variance trace

    # Observed data
    Z: np.ndarray  # (n,) treatment assignment
    Y_obs: np.ndarray  # (n,) observed revenue
    converter_mask: np.ndarray  # (n,) boolean converter indicator

    # Metadata
    n_burn: int  # burn-in iterations already discarded
    estimator: str  # "joint" | "proto" | "gpu"
    # Chain count.
    # NOTE(review): this field was previously annotated "samples are
    # interleaved", but _split_chains() slices traces into contiguous
    # per-chain segments (chain 0 first, then chain 1, ...). Confirm the
    # actual storage order with the producers.
    num_chains: int = 1

    # Per-visitor severity channel samples (n, S) — None if not retained
    sev0_samples: np.ndarray | None = None  # E[sev | control, convert] per sample
    sev1_samples: np.ndarray | None = None  # E[sev | treated, convert] per sample

    # Per-visitor propensity (optional — needed for conformal calibration)
    propensity: np.ndarray | None = None  # (n,) P(Z=1|X)

    # Channel-level posterior means (n,) — populated by GPU path (summary mode)
    p0_mean: np.ndarray | None = None  # E[Φ(μ_b + b₀·τ_b)]
    p1_mean: np.ndarray | None = None  # E[Φ(μ_b + b₁·τ_b)]
    sev0_mean: np.ndarray | None = None  # E[exp(μ_c + b₀·τ_c + σ²/2)]
    sev1_mean: np.ndarray | None = None  # E[exp(μ_c + b₁·τ_c + σ²/2)]

    # GPU-path diagnostics (PPC, per-chain stats computed during scan)
    gpu_diagnostics: dict | None = None

    # Ground truth (optional — available in simulation)
    true_rpv_cate: np.ndarray | None = None
    true_p0: np.ndarray | None = None
    true_p1: np.ndarray | None = None
    true_m0: np.ndarray | None = None
    true_m1: np.ndarray | None = None
    cluster_id: np.ndarray | None = None

    # Optional retained topology trace (populated when the producing fit was
    # configured with GPUBCFConfig.retain_topology_history=True). When None
    # (default), compute_slim_diagnostics emits no topology.* keys.
    topology_history: TopologyHistory | None = None
# --------------------------------------------------------------------------- # Extraction: Joint hurdle BCF (HurdleBCFModel) # ---------------------------------------------------------------------------
def extract_diagnostics_joint(
    model: Any,
    Z: np.ndarray,
    Y_obs: np.ndarray,
    num_gfr: int,
    num_burnin: int,
    *,
    true_rpv_cate: np.ndarray | None = None,
    true_p0: np.ndarray | None = None,
    true_p1: np.ndarray | None = None,
    true_m0: np.ndarray | None = None,
    true_m1: np.ndarray | None = None,
    cluster_id: np.ndarray | None = None,
) -> BCFDiagnosticData:
    """Build a ``BCFDiagnosticData`` from a fitted joint hurdle model.

    The first ``num_gfr + num_burnin`` draws of every posterior array are
    dropped, then per-arm conversion probabilities, expected severities and
    the composed RPV CATE are derived from the retained draws.

    Parameters
    ----------
    model : HurdleBCFModel
        Fitted model exposing ``mu_binary_samples``, ``mu_cont_samples``,
        ``tau_binary_samples``, ``tau_cont_samples`` (indexed as
        ``[:, draw]``) and ``sigma2_samples`` (indexed as ``[draw]``).
    Z : (n,)
        Treatment assignment.
    Y_obs : (n,)
        Observed revenue; ``Y_obs > 0`` becomes the converter mask.
    num_gfr :
        Number of GFR warmstart iterations to discard.
    num_burnin :
        Number of additional burn-in iterations to discard.
    true_rpv_cate, true_p0, true_p1, true_m0, true_m1, cluster_id :
        Optional ground-truth / grouping arrays, stored unchanged.
    """
    discard = num_gfr + num_burnin
    retained = slice(discard, None)

    prog_bin = model.mu_binary_samples[:, retained]  # (n, S)
    prog_cont = model.mu_cont_samples[:, retained]
    eff_bin = model.tau_binary_samples[:, retained]
    eff_cont = model.tau_cont_samples[:, retained]
    sigma2 = model.sigma2_samples[retained]  # (S,)

    # Treatment coding: control arm at -0.5, treated arm at +0.5.
    ctrl_code, trt_code = -0.5, 0.5

    # Probit link for the conversion channel.
    conv_ctrl = sp_norm.cdf(prog_bin + ctrl_code * eff_bin)
    conv_trt = sp_norm.cdf(prog_bin + trt_code * eff_bin)

    # Lognormal mean: exp(location + sigma^2 / 2), sigma^2 broadcast over rows.
    sev_ctrl = np.exp(prog_cont + ctrl_code * eff_cont + sigma2[np.newaxis, :] / 2.0)
    sev_trt = np.exp(prog_cont + trt_code * eff_cont + sigma2[np.newaxis, :] / 2.0)

    composed_cate = conv_trt * sev_trt - conv_ctrl * sev_ctrl

    ground_truth = {
        "true_rpv_cate": true_rpv_cate,
        "true_p0": true_p0,
        "true_p1": true_p1,
        "true_m0": true_m0,
        "true_m1": true_m1,
        "cluster_id": cluster_id,
    }
    return BCFDiagnosticData(
        tau_binary_samples=eff_bin,
        tau_cont_samples=eff_cont,
        rpv_cate_samples=composed_cate,
        mu_binary_samples=prog_bin,
        mu_cont_samples=prog_cont,
        p0_samples=conv_ctrl,
        p1_samples=conv_trt,
        sigma2_samples=sigma2,
        Z=Z,
        Y_obs=Y_obs,
        converter_mask=(Y_obs > 0),
        n_burn=discard,
        estimator="joint",
        **ground_truth,
    )
# --------------------------------------------------------------------------- # Extraction: Proto hurdle BCF (HurdleProtoResult) # ---------------------------------------------------------------------------
def extract_diagnostics_proto(
    result: Any,
    X: np.ndarray,
    Z: np.ndarray,
    Y_obs: np.ndarray,
    config: Any,
    *,
    true_rpv_cate: np.ndarray | None = None,
    true_p0: np.ndarray | None = None,
    true_p1: np.ndarray | None = None,
    true_m0: np.ndarray | None = None,
    true_m1: np.ndarray | None = None,
    cluster_id: np.ndarray | None = None,
) -> BCFDiagnosticData:
    """Build a ``BCFDiagnosticData`` from a ``HurdleProtoResult``.

    The proto result only exposes composite predictions
    (``f_conv_samples`` on the probit scale, ``f_sev_samples`` on the
    log-severity scale), ``sigma2_samples`` and ``rpv_cate_samples``. The
    prognostic (mu) and treatment (tau) parts are not separated here, so:

    - ``mu_binary_samples`` / ``mu_cont_samples`` hold the composites;
    - ``tau_binary_samples`` / ``tau_cont_samples`` are all-zero
      placeholders of the same (n, S) shape;
    - ``p0_samples`` / ``p1_samples`` are fixed-offset approximations
      derived from the composite (see the note in the body);
    - ``rpv_cate_samples`` is taken from the result unchanged.

    ``n_burn`` is set to ``config.num_gfr``. ``X`` is accepted but not
    read by this function.
    """
    conv_composite = result.f_conv_samples  # (n, S)
    sev_composite = result.f_sev_samples  # (n, S)
    variance_trace = result.sigma2_samples  # (S,)

    ctrl_code, trt_code = -0.5, 0.5
    # Per-visitor coding of the arm actually assigned.
    assigned_code = (1 - Z) * ctrl_code + Z * trt_code

    composed_cate = result.rpv_cate_samples  # (n, S)

    # tau cannot be recovered from the composite; store zeros instead.
    n, S = conv_composite.shape
    zero_tau_bin = np.zeros((n, S))
    zero_tau_cont = np.zeros((n, S))

    # NOTE(review): these are rough stand-ins, not Phi(mu + b*tau). The
    # composite already contains assigned_code * tau for the observed arm,
    # yet a fixed shift is applied to every row (including the observed
    # arm). Confirm whether downstream consumers of p0/p1 for the proto
    # estimator tolerate this approximation.
    code_col = assigned_code[:, np.newaxis]
    approx_p0 = sp_norm.cdf(conv_composite - code_col * 0.5)
    approx_p1 = sp_norm.cdf(conv_composite + (1 - code_col) * 0.5)

    ground_truth = {
        "true_rpv_cate": true_rpv_cate,
        "true_p0": true_p0,
        "true_p1": true_p1,
        "true_m0": true_m0,
        "true_m1": true_m1,
        "cluster_id": cluster_id,
    }
    return BCFDiagnosticData(
        tau_binary_samples=zero_tau_bin,
        tau_cont_samples=zero_tau_cont,
        rpv_cate_samples=composed_cate,
        mu_binary_samples=conv_composite,
        mu_cont_samples=sev_composite,
        p0_samples=approx_p0,
        p1_samples=approx_p1,
        sigma2_samples=variance_trace,
        Z=Z,
        Y_obs=Y_obs,
        converter_mask=(Y_obs > 0),
        n_burn=config.num_gfr,
        estimator="proto",
        **ground_truth,
    )
# --------------------------------------------------------------------------- # Extraction: GPU joint hurdle BCF (HurdleBCFResult) # ---------------------------------------------------------------------------
def extract_diagnostics_gpu(
    result: Any,
    Z: np.ndarray,
    Y_obs: np.ndarray,
    *,
    true_rpv_cate: np.ndarray | None = None,
    true_p0: np.ndarray | None = None,
    true_p1: np.ndarray | None = None,
    true_m0: np.ndarray | None = None,
    true_m1: np.ndarray | None = None,
    cluster_id: np.ndarray | None = None,
) -> BCFDiagnosticData:
    """Build a ``BCFDiagnosticData`` from a ``HurdleBCFResult``.

    Everything is copied straight from attributes of ``result``; nothing
    is recomposed here. The mu/tau sample fields, which this function does
    not read from the result, are filled with one shared float32 array of
    shape ``(len(Z), 0)``.

    Parameters
    ----------
    result : HurdleBCFResult
        Fitted GPU joint hurdle model. Must expose ``rpv_cate_samples``,
        ``p0_samples``, ``p1_samples``, ``sev0_samples``, ``sev1_samples``,
        ``sigma2_samples``, ``num_chains``, ``diagnostics`` and the four
        ``*_mean`` arrays; ``topology_history`` is optional.
    Z : (n,)
        Treatment assignment.
    Y_obs : (n,)
        Observed revenue; ``Y_obs > 0`` becomes the converter mask.
    """
    # Zero-width stand-in for the channel-level sample matrices.
    empty_channel = np.zeros((len(Z), 0), dtype=np.float32)

    posterior = {
        "tau_binary_samples": empty_channel,
        "tau_cont_samples": empty_channel,
        "rpv_cate_samples": result.rpv_cate_samples,
        "mu_binary_samples": empty_channel,
        "mu_cont_samples": empty_channel,
        "p0_samples": result.p0_samples,
        "p1_samples": result.p1_samples,
        "sev0_samples": result.sev0_samples,
        "sev1_samples": result.sev1_samples,
        "sigma2_samples": result.sigma2_samples,
    }
    observed = {
        "Z": Z,
        "Y_obs": Y_obs,
        "converter_mask": (Y_obs > 0),
    }
    metadata = {
        "n_burn": 0,  # burn-in already discarded by the scan
        "estimator": "gpu",
        "num_chains": result.num_chains,
        "gpu_diagnostics": result.diagnostics,
    }
    summaries = {
        "p0_mean": result.p0_mean,
        "p1_mean": result.p1_mean,
        "sev0_mean": result.sev0_mean,
        "sev1_mean": result.sev1_mean,
    }
    ground_truth = {
        "true_rpv_cate": true_rpv_cate,
        "true_p0": true_p0,
        "true_p1": true_p1,
        "true_m0": true_m0,
        "true_m1": true_m1,
        "cluster_id": cluster_id,
    }
    return BCFDiagnosticData(
        **posterior,
        **observed,
        **metadata,
        **summaries,
        **ground_truth,
        topology_history=getattr(result, "topology_history", None),
    )
# --------------------------------------------------------------------------- # Computational faithfulness diagnostics # ---------------------------------------------------------------------------
def compute_sigma2_trace(diag: BCFDiagnosticData) -> dict[str, float]:
    """Convergence summary for the global variance trace.

    Parameters
    ----------
    diag :
        Diagnostic container; only ``sigma2_samples`` is read.

    Returns
    -------
    dict
        - ``mean`` — mean of the trace.
        - ``final`` — last value of the trace.
        - ``split_half_ratio`` — first-half mean divided by second-half
          mean (NaN when the second-half mean is not positive).
        - ``trend`` — least-squares slope times the trace length, divided
          by the mean, i.e. the fitted drift over the whole trace relative
          to its level (NaN when the mean is not positive).
        - ``n_samples`` — trace length.

        Statistics that are undefined for the trace length are NaN:
        everything but ``n_samples`` for an empty trace, and
        ``split_half_ratio`` / ``trend`` for a single-sample trace.
    """
    s2 = np.asarray(diag.sigma2_samples)
    n_samples = len(s2)

    # An empty trace has no mean or final value; report NaN instead of
    # failing on the s2[-1] lookup.
    if n_samples == 0:
        return {
            "mean": np.nan,
            "final": np.nan,
            "split_half_ratio": np.nan,
            "trend": np.nan,
            "n_samples": 0,
        }

    overall_mean = float(s2.mean())
    final_value = float(s2[-1])

    # With one sample the first half is empty and a line cannot be fitted.
    if n_samples < 2:
        return {
            "mean": overall_mean,
            "final": final_value,
            "split_half_ratio": np.nan,
            "trend": np.nan,
            "n_samples": n_samples,
        }

    half = n_samples // 2
    first_half_mean = float(s2[:half].mean())
    second_half_mean = float(s2[half:].mean())

    # Split-half ratio: should be near 1.0 if converged
    ratio = first_half_mean / second_half_mean if second_half_mean > 0 else np.nan

    # Linear trend: slope of a degree-1 fit against the sample index.
    x = np.arange(n_samples, dtype=float)
    slope = float(np.polyfit(x, s2, 1)[0])
    trend = slope * n_samples / overall_mean if overall_mean > 0 else np.nan

    return {
        "mean": overall_mean,
        "final": final_value,
        "split_half_ratio": ratio,
        "trend": trend,
        "n_samples": n_samples,
    }
# --------------------------------------------------------------------------- # Retrodictive checks (posterior predictive) # ---------------------------------------------------------------------------
def compute_ppc_binary(diag: BCFDiagnosticData, n_bins: int = 10) -> dict[str, Any] | None:
    """Compare predicted and observed conversion rates across probability bins.

    Each visitor's predicted probability is the posterior mean of the
    conversion probability for the arm selected by ``diag.Z`` (``p1`` where
    ``Z`` is truthy, ``p0`` otherwise). Visitors are grouped into
    ``n_bins`` bins whose edges are quantiles of those predictions, and
    the mean prediction and the observed converter fraction are reported
    for every non-empty bin.

    Returns
    -------
    dict or None
        ``predicted_rates``, ``observed_rates`` and ``bin_sizes`` as
        parallel lists (empty bins are omitted), or ``None`` when
        ``p0_samples`` or ``p1_samples`` is ``None``.
    """
    p0, p1 = diag.p0_samples, diag.p1_samples
    if p0 is None or p1 is None:
        return None

    # Pick the sample matrix row-wise according to the assigned arm.
    assigned_arm = np.where(diag.Z[:, np.newaxis], p1, p0)
    p_hat = assigned_arm.mean(axis=1)  # (n,)
    converted = diag.converter_mask.astype(float)

    edges = np.quantile(p_hat, np.linspace(0, 1, n_bins + 1))
    edges[-1] += 1e-10  # nudge the top edge so the maximum is included
    bin_of_visitor = np.clip(np.digitize(p_hat, edges) - 1, 0, n_bins - 1)

    memberships = [bin_of_visitor == k for k in range(n_bins)]
    occupied = [m for m in memberships if m.sum() != 0]

    return {
        "predicted_rates": [float(p_hat[m].mean()) for m in occupied],
        "observed_rates": [float(converted[m].mean()) for m in occupied],
        "bin_sizes": [int(m.sum()) for m in occupied],
    }
def compute_ppc_continuous(diag: BCFDiagnosticData) -> dict[str, Any] | None:
    """Posterior predicted vs observed log-revenue for converters.

    For the ``"joint"`` and ``"gpu"`` estimators the prediction is the
    posterior mean of ``mu_c + basis * tau_c`` with ``basis`` = +0.5 where
    ``Z`` is truthy and -0.5 otherwise. For any other estimator
    ``mu_cont_samples`` is treated as the composite prediction and averaged
    directly.

    Returns
    -------
    dict or None
        ``observed_log_y`` and ``predicted_log_y`` (lists, one entry per
        converter), ``observed_mean`` and ``predicted_mean`` (floats).
        When there are no converters the lists are empty and the means are
        NaN; the legacy ``predicted_log_y_mean`` key (NaN) is kept in that
        case for existing consumers. ``None`` is returned when there are
        converters but ``mu_cont_samples`` has zero sample columns.
    """
    mask = diag.converter_mask
    if not mask.any():
        # Same keys as the regular return so consumers never hit a KeyError;
        # "predicted_log_y_mean" is retained for backward compatibility.
        return {
            "observed_log_y": [],
            "predicted_log_y": [],
            "observed_mean": np.nan,
            "predicted_mean": np.nan,
            "predicted_log_y_mean": np.nan,
        }

    # Need channel-level samples for PPC — not available in GPU summary mode
    if diag.mu_cont_samples.shape[1] == 0:
        return None

    # Clip before the log so non-positive revenue cannot produce -inf/NaN.
    observed_log_y = np.log(np.clip(diag.Y_obs[mask], 1e-10, None))

    # Predicted log-revenue: mu_c + basis * tau_c for the assigned arm
    b0, b1 = -0.5, 0.5
    if diag.estimator in ("joint", "gpu"):
        basis = np.where(diag.Z, b1, b0)
        pred_log_y = (
            diag.mu_cont_samples[mask]
            + basis[mask, np.newaxis] * diag.tau_cont_samples[mask]
        ).mean(axis=1)
    else:
        # Proto: mu_cont_samples IS the composite f_sev
        pred_log_y = diag.mu_cont_samples[mask].mean(axis=1)

    return {
        "observed_log_y": observed_log_y.tolist(),
        "predicted_log_y": pred_log_y.tolist(),
        "observed_mean": float(observed_log_y.mean()),
        "predicted_mean": float(pred_log_y.mean()),
    }
# --------------------------------------------------------------------------- # Calibration diagnostics (require ground truth) # ---------------------------------------------------------------------------
def compute_coverage(
    samples: np.ndarray,
    truth: np.ndarray,
    levels: tuple[float, ...] = (0.50, 0.80, 0.90, 0.95),
    *,
    sorted_samples: np.ndarray | None = None,
) -> dict[str, float]:
    """Empirical coverage of central credible intervals at several levels.

    For each nominal level the central interval is formed per row from the
    ``(1 - level) / 2`` and ``1 - (1 - level) / 2`` quantiles, and the
    fraction of rows whose ``truth`` lies inside (bounds inclusive) is
    reported.

    Parameters
    ----------
    samples : (n, S)
        Posterior samples per visitor. Only used when ``sorted_samples``
        is not given.
    truth : (n,)
        True values.
    levels :
        Nominal coverage levels.
    sorted_samples : (n, S), optional
        Row-sorted samples. When provided, interval bounds come from
        ``quantiles_from_sorted`` instead of ``np.quantile``.

    Returns
    -------
    dict
        For every level, ``coverage_<suffix>`` (empirical coverage) and
        ``nominal_<suffix>`` (the level itself), where ``<suffix>`` is
        produced by ``level_suffix``.
    """
    use_presorted = sorted_samples is not None

    bound_lookup: dict = {}
    if use_presorted:
        # Collect both tail probabilities of every level, then gather once.
        tail_probs = []
        for nominal in levels:
            tail = (1 - nominal) / 2
            tail_probs += [tail, 1 - tail]
        bound_lookup = quantiles_from_sorted(sorted_samples, tail_probs)

    report = {}
    for nominal in levels:
        tail = (1 - nominal) / 2
        if use_presorted:
            lower, upper = bound_lookup[tail], bound_lookup[1 - tail]
        else:
            lower = np.quantile(samples, tail, axis=1)
            upper = np.quantile(samples, 1 - tail, axis=1)

        inside = (truth >= lower) & (truth <= upper)
        sfx = level_suffix(nominal)
        report[f"coverage_{sfx}"] = float(inside.mean())
        report[f"nominal_{sfx}"] = nominal
    return report
def compute_calibration_curve(
    p_positive: np.ndarray,
    truly_positive: np.ndarray,
    n_bins: int = 10,
) -> dict[str, list[float]]:
    """Reliability curve: binned P(τ>0) against the realised positive rate.

    ``p_positive`` is cut into ``n_bins`` equal-width bins on [0, 1]. For
    every bin holding at least five entries the mean predicted probability,
    the fraction of truly positive entries and the bin size are recorded.

    Parameters
    ----------
    p_positive : (n,)
        Posterior P(τ>0) per visitor.
    truly_positive : (n,)
        Boolean: is true τ > 0?
    n_bins :
        Number of bins.

    Returns
    -------
    dict
        Parallel lists under ``predicted``, ``actual`` and ``sizes``.
    """
    edges = np.linspace(0, 1, n_bins + 1)
    # digitize is 1-based; clipping folds values on the top edge into the last bin.
    slot = np.clip(np.digitize(p_positive, edges) - 1, 0, n_bins - 1)

    curve = {"predicted": [], "actual": [], "sizes": []}
    for k in range(n_bins):
        members = slot == k
        count = members.sum()
        if count >= 5:
            curve["predicted"].append(float(p_positive[members].mean()))
            curve["actual"].append(float(truly_positive[members].mean()))
            curve["sizes"].append(int(count))
    return curve
def compute_segment_diagnostics(
    rpv_cate_samples: np.ndarray,
    true_rpv_cate: np.ndarray,
    X: np.ndarray,
    feature_names: list[str],
    *,
    depth: int = 3,
    levels: tuple[float, ...] = (0.05, 0.10, 0.20, 0.50, 0.80, 0.90, 0.95, 0.975, 0.99, 0.995, 0.999, 0.9995),
    n_ptau_bins: int = 10,
) -> dict[str, float]:
    """Coverage and P(τ>0) calibration evaluated on segment means.

    Segments are the leaves of a scikit-learn ``DecisionTreeClassifier``
    (``max_depth=depth``, ``min_samples_leaf=50``) fitted on ``X`` to
    predict the sign of the per-visitor posterior-mean CATE. Leaves with
    fewer than ten visitors are ignored. For every remaining segment the
    posterior of the segment-mean CATE (one value per draw) is compared
    with the true segment mean.

    Parameters
    ----------
    rpv_cate_samples : (n, S)
        Per-visitor posterior CATE samples (S = MCMC draws).
    true_rpv_cate : (n,)
        True CATE per visitor.
    X : (n, d)
        Feature matrix used to fit the tree and to assign leaves.
    feature_names :
        Column names for X. Not read by this function.
    depth :
        Maximum tree depth.
    levels :
        Nominal coverage levels to evaluate.
    n_ptau_bins :
        Number of equal-width bins for the P(τ>0) calibration curve.

    Returns
    -------
    dict
        - ``seg_n_segments`` — number of leaves.
        - ``seg_n_valid_segments`` — leaves with at least ten visitors.
        - ``seg_detail_{i}_*`` — per-segment ``size``, ``true_cate``,
          ``post_mean``, ``post_sd``, ``ci90_width``, ``bias``, ``z_score``.
        - ``seg_coverage_<suffix>`` / ``seg_nominal_<suffix>`` —
          size-weighted segment-mean coverage and its nominal level.
        - ``seg_ptau_bin{k}_predicted`` / ``_actual`` / ``_n`` — binned
          segment-mean P(τ>0) calibration, for bins with two or more
          segments.
    """
    from sklearn.tree import DecisionTreeClassifier

    # Segment definition: leaves of a shallow tree on the estimated sign.
    sign_labels = (rpv_cate_samples.mean(axis=1) > 0).astype(int)
    splitter = DecisionTreeClassifier(max_depth=depth, min_samples_leaf=50)
    splitter.fit(X, sign_labels)
    leaf_of_visitor = splitter.apply(X)
    leaves = np.unique(leaf_of_visitor)

    out: dict[str, float] = {"seg_n_segments": len(leaves)}

    kept_sizes: list[int] = []
    hits_by_level: dict[float, list[bool]] = {lvl: [] for lvl in levels}
    ptau_predicted: list[float] = []
    ptau_actual: list[float] = []

    seg_idx = 0
    for leaf in leaves:
        members = leaf_of_visitor == leaf
        n_members = int(members.sum())
        if n_members < 10:
            continue
        kept_sizes.append(n_members)

        # One segment-mean CATE per posterior draw: shape (S,).
        draw_means = rpv_cate_samples[members].mean(axis=0)
        truth_mean = float(true_rpv_cate[members].mean())

        centre = float(draw_means.mean())
        spread = float(draw_means.std())
        lo90 = float(np.quantile(draw_means, 0.05))
        hi90 = float(np.quantile(draw_means, 0.95))

        out.update({
            f"seg_detail_{seg_idx}_size": n_members,
            f"seg_detail_{seg_idx}_true_cate": round(truth_mean, 6),
            f"seg_detail_{seg_idx}_post_mean": round(centre, 6),
            f"seg_detail_{seg_idx}_post_sd": round(spread, 6),
            f"seg_detail_{seg_idx}_ci90_width": round(hi90 - lo90, 6),
            f"seg_detail_{seg_idx}_bias": round(centre - truth_mean, 6),
            # Floor the sd so a degenerate posterior cannot divide by zero.
            f"seg_detail_{seg_idx}_z_score": round(
                (truth_mean - centre) / max(spread, 1e-10),
                6,
            ),
        })
        seg_idx += 1

        # Does the central interval at each level contain the true mean?
        for lvl in levels:
            tail = (1 - lvl) / 2
            lower = float(np.quantile(draw_means, tail))
            upper = float(np.quantile(draw_means, 1 - tail))
            hits_by_level[lvl].append(lower <= truth_mean <= upper)

        # Segment-level P(τ>0) and whether the true segment mean is positive.
        ptau_predicted.append(float((draw_means > 0).mean()))
        ptau_actual.append(float(truth_mean > 0))

    out["seg_n_valid_segments"] = seg_idx

    # Coverage weighted by segment size.
    total_size = sum(kept_sizes)
    for lvl in levels:
        hits = hits_by_level[lvl]
        if not hits:
            continue
        covered_mass = sum(
            size for hit, size in zip(hits, kept_sizes, strict=False) if hit
        )
        sfx = level_suffix(lvl)
        out[f"seg_coverage_{sfx}"] = round(covered_mass / total_size, 6)
        out[f"seg_nominal_{sfx}"] = lvl

    # Binned segment-mean P(τ>0) calibration.
    if ptau_predicted:
        pred_arr = np.array(ptau_predicted)
        act_arr = np.array(ptau_actual)
        edges = np.linspace(0, 1, n_ptau_bins + 1)
        slot = np.clip(np.digitize(pred_arr, edges) - 1, 0, n_ptau_bins - 1)
        for b in range(n_ptau_bins):
            in_bin = slot == b
            n_in_bin = int(in_bin.sum())
            if n_in_bin >= 2:
                out[f"seg_ptau_bin{b}_predicted"] = round(float(pred_arr[in_bin].mean()), 6)
                out[f"seg_ptau_bin{b}_actual"] = round(float(act_arr[in_bin].mean()), 6)
                out[f"seg_ptau_bin{b}_n"] = n_in_bin

    return out
def compute_miscalibration_area(
    samples: np.ndarray,
    truth: np.ndarray,
    n_levels: int = 20,
    *,
    sorted_samples: np.ndarray | None = None,
) -> float:
    r"""Mean \|actual - nominal\| coverage over a grid of interval levels.

    The grid is ``n_levels`` evenly spaced nominal levels from 0.05 to
    0.95. At each level the central credible interval is formed per row
    and the fraction of rows whose ``truth`` falls inside (bounds
    inclusive) is compared with the nominal level. A value of 0 means the
    empirical coverage matched the nominal level at every grid point.

    Parameters
    ----------
    samples : (n, S)
        Posterior samples; only used when ``sorted_samples`` is not given.
    truth : (n,)
        True values.
    n_levels :
        Number of grid points.
    sorted_samples : (n, S), optional
        Row-sorted samples; when given, bounds are looked up with
        ``quantiles_from_sorted`` instead of ``np.quantile``.
    """
    grid = np.linspace(0.05, 0.95, n_levels)
    use_presorted = sorted_samples is not None

    bound_lookup: dict = {}
    if use_presorted:
        tail_probs = []
        for nominal in grid:
            tail = (1 - nominal) / 2
            tail_probs += [tail, 1 - tail]
        bound_lookup = quantiles_from_sorted(sorted_samples, tail_probs)

    gaps = []
    for nominal in grid:
        tail = (1 - nominal) / 2
        if use_presorted:
            lower, upper = bound_lookup[tail], bound_lookup[1 - tail]
        else:
            lower = np.quantile(samples, tail, axis=1)
            upper = np.quantile(samples, 1 - tail, axis=1)
        empirical = float(((truth >= lower) & (truth <= upper)).mean())
        gaps.append(abs(empirical - nominal))
    return float(np.mean(gaps))
def compute_channel_calibration(
    diag: BCFDiagnosticData,
) -> dict[str, dict[str, float]] | None:
    """Coverage of the conversion-channel CATE and the composed RPV CATE.

    Coverage is evaluated with ``compute_coverage`` at the 50/80/90/95%
    levels. No separate coverage is computed for the continuous channel.

    Returns
    -------
    dict or None
        ``{"binary": ..., "composed": ...}`` where ``binary`` covers
        ``p1 - p0`` on the probability scale and ``composed`` (present only
        when ``true_rpv_cate`` is available) covers the RPV CATE.
        ``None`` when the estimator is neither ``"joint"`` nor ``"gpu"``,
        when ``true_p0`` or ``true_p1`` is missing, or when sample-level
        ``p0_samples`` / ``p1_samples`` are missing.
    """
    if diag.estimator not in ("joint", "gpu"):
        return None

    # Both ground-truth probabilities are needed for the true conversion
    # CATE; checking only true_p0 would fail on the subtraction below.
    if diag.true_p0 is None or diag.true_p1 is None:
        return None

    # Need full (n, S) p0/p1 samples for coverage computation
    if diag.p0_samples is None or diag.p1_samples is None:
        return None

    levels = (0.50, 0.80, 0.90, 0.95)
    result: dict[str, dict[str, float]] = {}

    # Binary channel: conversion CATE on the probability scale
    conv_cate_samples = diag.p1_samples - diag.p0_samples  # (n, S)
    true_conv_cate = diag.true_p1 - diag.true_p0
    result["binary"] = compute_coverage(conv_cate_samples, true_conv_cate, levels)

    # Composed RPV CATEs
    true_rpv = diag.true_rpv_cate
    if true_rpv is not None:
        result["composed"] = compute_coverage(diag.rpv_cate_samples, true_rpv, levels)

    return result
# --------------------------------------------------------------------------- # Model critique (require ground truth) # ---------------------------------------------------------------------------
def compute_channel_attribution(diag: BCFDiagnosticData) -> dict[str, float] | None:
    """Which channel drives RPV CATE error?

    The RPV CATE ``p1*m1 - p0*m0`` is split into

    - a conversion channel ``(p1 - p0) * m0`` and
    - an AOV channel ``p1 * (m1 - m0)``,

    for both the posterior-mean estimate and the ground truth, and the
    per-channel RMSE and bias are reported.

    Channel estimates come from the pre-computed means (``p0_mean``,
    ``p1_mean``, ``sev0_mean``, ``sev1_mean``) when all four are present.
    Otherwise, for the ``"joint"`` estimator with ``p0_samples`` and
    ``p1_samples`` present, they are computed from the posterior samples.

    Returns
    -------
    dict or None
        ``conv_channel_rmse``, ``aov_channel_rmse``, ``total_rmse``,
        ``conv_channel_bias``, ``aov_channel_bias``, ``total_bias`` and
        ``conv_channel_frac`` (conversion RMSE as a share of the two
        channel RMSEs; NaN when both are zero). ``None`` when any required
        ground-truth array is missing or no channel estimate is available.
    """
    # Every ground-truth array used below must be present; a partial set
    # would otherwise fail inside the arithmetic.
    required_truth = (
        diag.true_rpv_cate,
        diag.true_p0,
        diag.true_p1,
        diag.true_m0,
        diag.true_m1,
    )
    if any(arr is None for arr in required_truth):
        return None

    summary_means = (diag.p0_mean, diag.p1_mean, diag.sev0_mean, diag.sev1_mean)
    if all(arr is not None for arr in summary_means):
        # Use pre-computed channel means if available (GPU summary mode)
        p0_hat, p1_hat, m0_hat, m1_hat = summary_means
    elif (
        diag.estimator == "joint"
        and diag.p0_samples is not None
        and diag.p1_samples is not None
    ):
        # Full posterior samples available (CPU joint estimator)
        p0_hat = diag.p0_samples.mean(axis=1)
        p1_hat = diag.p1_samples.mean(axis=1)
        b0, b1 = -0.5, 0.5
        # Lognormal mean per draw: exp(location + sigma^2 / 2).
        half_var = diag.sigma2_samples[np.newaxis, :] / 2.0
        sev0_samples = np.exp(diag.mu_cont_samples + b0 * diag.tau_cont_samples + half_var)
        sev1_samples = np.exp(diag.mu_cont_samples + b1 * diag.tau_cont_samples + half_var)
        m0_hat = sev0_samples.mean(axis=1)
        m1_hat = sev1_samples.mean(axis=1)
    else:
        # Proto or insufficient data: can't decompose severity
        return None

    # Estimated channel contributions
    est_conv_channel = (p1_hat - p0_hat) * m0_hat  # conversion lift
    est_aov_channel = p1_hat * (m1_hat - m0_hat)  # AOV lift
    est_rpv = est_conv_channel + est_aov_channel

    # True channel contributions
    true_conv_channel = (diag.true_p1 - diag.true_p0) * diag.true_m0
    true_aov_channel = diag.true_p1 * (diag.true_m1 - diag.true_m0)

    conv_err = est_conv_channel - true_conv_channel
    aov_err = est_aov_channel - true_aov_channel
    total_err = est_rpv - diag.true_rpv_cate

    # Per-channel RMSE
    conv_rmse = float(np.sqrt(np.mean(conv_err ** 2)))
    aov_rmse = float(np.sqrt(np.mean(aov_err ** 2)))
    total_rmse = float(np.sqrt(np.mean(total_err ** 2)))

    rmse_sum = conv_rmse + aov_rmse
    return {
        "conv_channel_rmse": conv_rmse,
        "aov_channel_rmse": aov_rmse,
        "total_rmse": total_rmse,
        "conv_channel_bias": float(conv_err.mean()),
        "aov_channel_bias": float(aov_err.mean()),
        "total_bias": float(total_err.mean()),
        "conv_channel_frac": conv_rmse / rmse_sum if rmse_sum > 0 else np.nan,
    }
def compute_quintile_calibration(
    diag: BCFDiagnosticData,
    n_quantiles: int = 5,
    *,
    sorted_samples: np.ndarray | None = None,
    est_mean: np.ndarray | None = None,
) -> list[dict[str, Any]] | None:
    """Calibration summary per rank-group of the estimated CATE.

    Visitors are ranked by posterior-mean RPV CATE and split into
    ``n_quantiles`` groups of (near-)equal size. For each non-empty group
    the function reports estimate/truth means and standard deviations,
    sign agreement, 90% interval coverage, the within-group Spearman
    correlation (NaN for groups of three or fewer; 0.0 when the
    correlation itself is NaN) and the fraction of negative values.

    Parameters
    ----------
    diag :
        Diagnostic container; ``true_rpv_cate`` and ``rpv_cate_samples``
        are read.
    n_quantiles :
        Number of rank groups.
    sorted_samples : (n, S), optional
        Row-sorted RPV CATE samples. When given, the 90% bounds are looked
        up with ``quantiles_from_sorted`` on the group's rows.
    est_mean : (n,), optional
        Pre-computed posterior mean; computed from ``rpv_cate_samples``
        when omitted.

    Returns
    -------
    list of dict or None
        One dict per non-empty group, or ``None`` without ground truth.
    """
    truth = diag.true_rpv_cate
    if truth is None:
        return None

    if est_mean is None:
        point_est: np.ndarray = diag.rpv_cate_samples.mean(axis=1)
    else:
        point_est = est_mean
    n_visitors = len(point_est)

    # Double argsort turns values into 0-based ranks; ranks map to groups.
    ordinal = np.argsort(np.argsort(point_est))
    group_of = np.minimum(ordinal * n_quantiles // n_visitors, n_quantiles - 1)

    rows = []
    for g in range(n_quantiles):
        members = group_of == g
        count = int(members.sum())
        if count == 0:
            continue

        est_g = point_est[members]
        truth_g = truth[members]

        # 90% central interval per visitor in this group.
        if sorted_samples is None:
            draws_g = diag.rpv_cate_samples[members]
            lower = np.quantile(draws_g, 0.05, axis=1)
            upper = np.quantile(draws_g, 0.95, axis=1)
        else:
            bounds = quantiles_from_sorted(sorted_samples[members], [0.05, 0.95])
            lower = bounds[0.05]
            upper = bounds[0.95]
        inside = (truth_g >= lower) & (truth_g <= upper)

        same_sign = np.sign(est_g) == np.sign(truth_g)

        # Rank correlation needs more than three points to be reported.
        rank_corr = np.nan
        if count > 3:
            rho = spearmanr(est_g, truth_g).statistic  # pyright: ignore[reportAttributeAccessIssue] — scipy stub gap
            rank_corr = 0.0 if np.isnan(rho) else float(rho)

        rows.append({
            "quintile": g + 1,
            "n": count,
            "est_mean": float(est_g.mean()),
            "true_mean": float(truth_g.mean()),
            "est_std": float(est_g.std()),
            "true_std": float(truth_g.std()),
            "sign_accuracy": float(same_sign.mean()),
            "coverage_90": float(inside.mean()),
            "rank_corr": rank_corr,
            "frac_neg_true": float((truth_g < 0).mean()),
            "frac_neg_est": float((est_g < 0).mean()),
        })
    return rows
def compute_decile_calibration(
    diag: BCFDiagnosticData,
) -> list[dict[str, Any]] | None:
    """Run ``compute_quintile_calibration`` with ten rank groups.

    Returns whatever ``compute_quintile_calibration`` returns for
    ``n_quantiles=10``.
    """
    n_deciles = 10
    return compute_quintile_calibration(diag, n_quantiles=n_deciles)
def compute_selection_bias(diag: BCFDiagnosticData) -> dict[str, float] | None:
    """Compare mean log-revenue of treated and control converters.

    Converters are split by arm (``Z > 0.5`` treated, ``Z < 0.5`` control)
    and the mean of ``log(Y_obs)`` is computed in each group, with revenue
    clipped from below at 1e-10 before taking the log.

    Returns
    -------
    dict or None
        ``mean_log_y_treated``, ``mean_log_y_control``, their difference
        ``diff`` (treated minus control) and the two group sizes. ``None``
        when either group has fewer than ten converters.
    """
    converters = diag.converter_mask
    in_treated = converters & (diag.Z > 0.5)
    in_control = converters & (diag.Z < 0.5)

    n_treated = in_treated.sum()
    n_control = in_control.sum()
    if min(n_treated, n_control) < 10:
        return None

    def _mean_log_revenue(rows: np.ndarray):
        # Clip so the log stays finite for non-positive revenue.
        return np.log(np.clip(diag.Y_obs[rows], 1e-10, None)).mean()

    treated_mean = _mean_log_revenue(in_treated)
    control_mean = _mean_log_revenue(in_control)

    return {
        "mean_log_y_treated": float(treated_mean),
        "mean_log_y_control": float(control_mean),
        "diff": float(treated_mean - control_mean),
        "n_treated_conv": int(n_treated),
        "n_control_conv": int(n_control),
    }
# ---------------------------------------------------------------------------
# Summary scorecard
# ---------------------------------------------------------------------------


def _split_chains(trace: np.ndarray, num_chains: int) -> list[np.ndarray]:
    """Cut a concatenated trace into consecutive equal-length chain segments.

    The trace is assumed to hold chain 0 in full, then chain 1, and so on
    (not interleaved). Each segment has ``len(trace) // num_chains``
    entries; any remainder at the end of the trace is not included. With
    ``num_chains <= 1`` the trace is returned as the only segment.
    """
    if num_chains <= 1:
        return [trace]

    per_chain = len(trace) // num_chains
    segments = []
    start = 0
    for _ in range(num_chains):
        stop = start + per_chain
        segments.append(trace[start:stop])
        start = stop
    return segments
def compute_chain_diagnostics(diag: BCFDiagnosticData) -> dict[str, Any]:
    """Per-chain convergence statistics for the σ² trace.

    The σ² trace is split into ``diag.num_chains`` segments with
    ``_split_chains``. For each segment the ESS (``compute_ess``), mean,
    standard deviation and lag-1 autocorrelation are computed. R-hat
    (``compute_rhat``) is computed across segments when there are at least
    two chains and is NaN otherwise.

    Returns
    -------
    dict
        ``num_chains``, ``rhat_sigma2``, ``per_chain_ess``,
        ``per_chain_mean``, ``per_chain_std``, ``per_chain_acf1``,
        ``pooled_ess`` (ESS of the unsplit trace) and
        ``samples_per_chain`` (length of the first segment).
    """

    def _lag1_autocorr(chain: np.ndarray) -> float:
        # Reported as 0.0 for (near-)constant or too-short chains.
        centred = chain - chain.mean()
        variance = float(np.var(centred))
        if variance < 1e-15 or len(chain) < 2:
            return 0.0
        return float(np.sum(centred[:-1] * centred[1:]) / (variance * len(chain)))

    n_chains = diag.num_chains
    traces = _split_chains(diag.sigma2_samples, n_chains)

    ess_each = [compute_ess(trace) for trace in traces]
    mean_each = [float(trace.mean()) for trace in traces]
    std_each = [float(trace.std()) for trace in traces]
    acf1_each = [_lag1_autocorr(trace) for trace in traces]

    if n_chains >= 2:
        rhat = compute_rhat(traces)
    else:
        rhat = float("nan")

    return {
        "num_chains": n_chains,
        "rhat_sigma2": rhat,
        "per_chain_ess": ess_each,
        "per_chain_mean": mean_each,
        "per_chain_std": std_each,
        "per_chain_acf1": acf1_each,
        "pooled_ess": compute_ess(diag.sigma2_samples),
        "samples_per_chain": len(traces[0]),
    }
def compute_scorecard(
    diag: BCFDiagnosticData,
    *,
    coverage: dict[str, float] | None = None,
    miscalibration_area: float | None = None,
    chain_diagnostics: dict[str, Any] | None = None,
    est_mean: np.ndarray | None = None,
    sorted_rpv: np.ndarray | None = None,
) -> dict[str, dict[str, Any]]:
    """Build the traffic-light summary scorecard.

    The result maps a metric name to a dict with ``value``, ``grade`` and
    ``label`` (plus ``detail`` for some metrics).  Graded metrics receive
    "green", "yellow" or "red"; RMSE and mean bias are informational and
    carry the grade "dim".

    Convergence metrics (σ² split-half ratio, R-hat for multi-chain runs,
    minimum per-chain ESS, maximum lag-1 autocorrelation) are always
    present.  Calibration metrics are added only when
    ``diag.true_rpv_cate`` is available.

    The keyword arguments accept values the caller has already computed
    (coverage dict, miscalibration area, chain diagnostics, posterior mean,
    pre-sorted samples) so they are not recomputed here.
    """

    def _grade_in_band(
        value: Any, green: tuple[float, float], yellow: tuple[float, float]
    ) -> str:
        """Green/yellow when ``value`` lies in the respective closed interval."""
        if green[0] <= value <= green[1]:
            return "green"
        if yellow[0] <= value <= yellow[1]:
            return "yellow"
        return "red"

    def _grade_below(value: Any, green_limit: float, yellow_limit: float) -> str:
        """Smaller is better: strictly below the limit earns the colour."""
        if value < green_limit:
            return "green"
        if value < yellow_limit:
            return "yellow"
        return "red"

    def _grade_above(value: Any, green_limit: float, yellow_limit: float) -> str:
        """Larger is better: strictly above the limit earns the colour."""
        if value > green_limit:
            return "green"
        if value > yellow_limit:
            return "yellow"
        return "red"

    card: dict[str, dict[str, Any]] = {}

    chains = chain_diagnostics or compute_chain_diagnostics(diag)
    n_chains = chains["num_chains"]

    # σ² convergence
    split_ratio = compute_sigma2_trace(diag)["split_half_ratio"]
    card["sigma2_split_half"] = {
        "value": split_ratio,
        "grade": _grade_in_band(split_ratio, (0.9, 1.1), (0.8, 1.2)),
        "label": "σ² split-half ratio",
    }

    # R-hat (multi-chain only)
    if n_chains >= 2:
        rhat = chains["rhat_sigma2"]
        card["rhat_sigma2"] = {
            "value": rhat,
            "grade": _grade_below(rhat, 1.05, 1.10),
            "label": "σ² R-hat",
            "detail": f"{n_chains} chains",
        }

    # Per-chain ESS — the weakest chain decides the grade
    ess_by_chain = chains["per_chain_ess"]
    chain_len = chains["samples_per_chain"]
    weakest_ess = min(ess_by_chain)
    ess_frac = weakest_ess / chain_len if chain_len > 0 else 0.0
    if n_chains > 1:
        ess_detail = f"min of {n_chains} chains, {ess_frac:.0%} of {chain_len}"
    else:
        ess_detail = f"{ess_frac:.0%} of {chain_len} samples"
    card["sigma2_ess"] = {
        "value": weakest_ess,
        "grade": _grade_above(ess_frac, 0.3, 0.1),
        "label": "σ² ESS (min chain)",
        "detail": ess_detail,
    }

    # Lag-1 autocorrelation — the worst-mixing chain decides the grade
    worst_acf1 = max(chains["per_chain_acf1"])
    card["sigma2_acf1"] = {
        "value": worst_acf1,
        "grade": _grade_below(worst_acf1, 0.5, 0.8),
        "label": "σ² lag-1 autocorr",
        "detail": "max across chains" if n_chains > 1 else "",
    }

    # Everything below needs ground truth.
    truth = diag.true_rpv_cate
    if truth is None:
        return card

    # 95% coverage — prefer the caller-supplied coverage dict
    cov = coverage
    if cov is None:
        cov = compute_coverage(
            diag.rpv_cate_samples,
            truth,
            sorted_samples=sorted_rpv,
        )
    cov_95 = cov.get("coverage_95", np.nan)
    card["coverage_95"] = {
        "value": cov_95,
        "grade": _grade_in_band(cov_95, (0.90, 0.98), (0.85, 1.0)),
        "label": "95% coverage",
    }

    # Rank correlation — prefer the caller-supplied posterior mean
    point_est = est_mean if est_mean is not None else diag.rpv_cate_samples.mean(axis=1)
    rho = spearmanr(point_est, truth).statistic  # pyright: ignore[reportAttributeAccessIssue] — scipy stub gap
    rho_val = 0.0 if np.isnan(rho) else float(rho)
    card["rank_corr"] = {
        "value": rho_val,
        "grade": _grade_above(rho_val, 0.80, 0.60),
        "label": "Rank correlation",
    }

    # Sign accuracy
    sign_acc = float((np.sign(point_est) == np.sign(truth)).mean())
    card["sign_accuracy"] = {
        "value": sign_acc,
        "grade": _grade_above(sign_acc, 0.80, 0.65),
        "label": "Sign accuracy",
    }

    # Miscalibration area — prefer the caller-supplied value
    misc_area = miscalibration_area
    if misc_area is None:
        misc_area = compute_miscalibration_area(
            diag.rpv_cate_samples,
            truth,
            sorted_samples=sorted_rpv,
        )
    card["miscalibration_area"] = {
        "value": misc_area,
        "grade": _grade_below(misc_area, 0.05, 0.15),
        "label": "Miscalibration area",
    }

    # RMSE and bias are reported without a traffic-light grade.
    errors = point_est - truth
    card["rmse"] = {
        "value": float(np.sqrt(np.mean(errors ** 2))),
        "grade": "dim",
        "label": "CATE RMSE",
    }
    card["bias"] = {
        "value": float(errors.mean()),
        "grade": "dim",
        "label": "Mean bias",
    }

    return card
# --------------------------------------------------------------------------- # Slim diagnostics (fast path for sweep) # ---------------------------------------------------------------------------
def compute_slim_diagnostics(diag: BCFDiagnosticData) -> dict[str, Any]:
    """Compute a flat summary of key diagnostics from a single pre-sort.

    Returns ``{"error": "no ground truth available"}`` when
    ``diag.true_rpv_cate`` is ``None``.  Otherwise the result is a flat dict
    whose entries are added in this order: coverage at an expanded set of
    levels, accuracy scalars, σ² summary, heterogeneity spread, per-quintile
    breakdown, SBC rank KS test, P(τ>0) calibration bins, sign-boundary
    coverage, per-channel coverage, interval-width distribution, PPC
    passthrough, channel calibration, channel attribution and topology
    metrics.  The later groups appear only when the matching data is present
    on ``diag``.  Missing calibration bins and an undersized sign-boundary
    band are reported as the empty string ``""``.
    """
    truth = diag.true_rpv_cate
    if truth is None:
        return {"error": "no ground truth available"}

    rpv_samples = diag.rpv_cate_samples
    sorted_rpv = presort_samples(rpv_samples)
    est_mean = rpv_samples.mean(axis=1)

    # Coverage at expanded levels, reusing the pre-sorted samples
    coverage_levels = (
        0.05, 0.10, 0.20, 0.50, 0.80, 0.90,
        0.95, 0.975, 0.99, 0.995, 0.999, 0.9995,
    )
    cov = compute_coverage(
        rpv_samples,
        truth,
        levels=coverage_levels,
        sorted_samples=sorted_rpv,
    )

    # Rank correlation (NaN collapses to 0.0)
    rho = spearmanr(est_mean, truth).statistic  # pyright: ignore[reportAttributeAccessIssue] — scipy stub gap
    rank_rho = 0.0 if np.isnan(rho) else float(rho)

    # Sign accuracy
    sign_acc = float((np.sign(est_mean) == np.sign(truth)).mean())

    # RMSE + bias
    errors = est_mean - truth
    rmse = float(np.sqrt(np.mean(errors ** 2)))
    bias = float(errors.mean())

    # Miscalibration area
    miscal = compute_miscalibration_area(
        rpv_samples,
        truth,
        sorted_samples=sorted_rpv,
    )

    # σ² convergence
    s2_trace = compute_sigma2_trace(diag)

    # Heterogeneity detection: spread of the point estimates across rows
    # relative to the average per-row posterior standard deviation.
    est_spread = float(est_mean.std())
    widths = rpv_samples.std(axis=1)  # reused for the width distribution below
    post_width = float(widths.mean())
    hte_snr = est_spread / post_width if post_width > 1e-10 else 0.0
    true_spread = float(truth.std())

    # dict[str, Any] because some entries fall back to "" when the
    # corresponding data is absent (ptau bins, sign-boundary coverage).
    result: dict[str, Any] = dict(cov)
    result.update(
        {
            "rank_rho": rank_rho,
            "sign_accuracy": sign_acc,
            "rmse": rmse,
            "bias": bias,
            "miscalibration_area": miscal,
            "sigma2_mean": s2_trace["mean"],
            "sigma2_split_half_ratio": s2_trace["split_half_ratio"],
            "est_spread": est_spread,
            "post_width": post_width,
            "hte_snr": hte_snr,
            "true_spread": true_spread,
        }
    )

    # --- Quintile breakdown ---
    quintiles = compute_quintile_calibration(
        diag,
        sorted_samples=sorted_rpv,
        est_mean=est_mean,
    )
    if quintiles is not None:
        # Assign every row to one of five buckets by the rank of its
        # posterior mean, for the per-quintile RMSE.
        n_rows = len(est_mean)
        row_rank = np.argsort(np.argsort(est_mean))
        bucket = np.minimum(row_rank * 5 // n_rows, 4)
        for entry in quintiles:
            q = entry["quintile"]
            members = bucket == (q - 1)
            if members.sum() > 0:
                q_rmse = float(
                    np.sqrt(np.mean((est_mean[members] - truth[members]) ** 2))
                )
            else:
                q_rmse = 0.0
            result[f"q{q}_coverage_90"] = entry["coverage_90"]
            result[f"q{q}_sign_accuracy"] = entry["sign_accuracy"]
            result[f"q{q}_rmse"] = q_rmse
            result[f"q{q}_bias"] = entry["est_mean"] - entry["true_mean"]
            result[f"q{q}_est_mean"] = entry["est_mean"]
            result[f"q{q}_true_mean"] = entry["true_mean"]

    # --- SBC rank KS statistic ---
    # Count how many sorted draws fall strictly below the truth in each row
    # and test those counts against Uniform(0, S).
    n_draws = sorted_rpv.shape[1]
    truth_rank = (sorted_rpv < truth[:, np.newaxis]).sum(axis=1)  # (n,)
    ks_stat, ks_pval = kstest(truth_rank, "uniform", args=(0, n_draws))
    result["rank_ks_stat"] = float(ks_stat)
    result["rank_ks_pval"] = float(ks_pval)

    # --- P(τ>0) calibration curve (binned shape) ---
    p_positive = (rpv_samples > 0).mean(axis=1)
    truly_positive = (truth > 0).astype(float)
    ptau_cal = compute_calibration_curve(p_positive, truly_positive, n_bins=10)

    # Always emit ten bin slots; slots beyond the returned bins hold "".
    for k in range(10):
        has_bin = k < len(ptau_cal["predicted"])
        result[f"ptau_bin{k}_predicted"] = ptau_cal["predicted"][k] if has_bin else ""
        result[f"ptau_bin{k}_actual"] = ptau_cal["actual"][k] if has_bin else ""
        result[f"ptau_bin{k}_n"] = ptau_cal["sizes"][k] if has_bin else ""

    # Summary scalars
    if ptau_cal["predicted"]:
        # actual minus predicted; |actual - predicted| == |predicted - actual|
        gap = np.array(ptau_cal["actual"]) - np.array(ptau_cal["predicted"])
        result["ptau_calibration_mae"] = float(np.abs(gap).mean())
        result["ptau_calibration_bias"] = float(gap.mean())
    else:
        result["ptau_calibration_mae"] = ""
        result["ptau_calibration_bias"] = ""

    # --- Sign boundary coverage ---
    # Near-zero band: |true τ| < median(|true τ|) / 2
    abs_truth = np.abs(truth)
    near_zero = abs_truth < float(np.median(abs_truth)) / 2.0
    n_near_zero = int(near_zero.sum())
    result["sign_boundary_n"] = n_near_zero
    if n_near_zero >= 20:
        bounds = quantiles_from_sorted(sorted_rpv[near_zero], [0.05, 0.95])
        lower = bounds[0.05]
        upper = bounds[0.95]
        near_truth = truth[near_zero]
        inside = (near_truth >= lower) & (near_truth <= upper)
        result["sign_boundary_coverage_90"] = float(inside.mean())
    else:
        result["sign_boundary_coverage_90"] = ""

    # --- Per-channel individual-level coverage ---
    chan_levels = (0.50, 0.80, 0.90, 0.95)
    channels = (
        ("p0", diag.p0_samples, diag.true_p0),
        ("p1", diag.p1_samples, diag.true_p1),
        ("sev0", diag.sev0_samples, diag.true_m0),
        ("sev1", diag.sev1_samples, diag.true_m1),
    )
    for label, chan_samples, chan_truth in channels:
        # Skip channels lacking samples, truth, or any posterior draws.
        if chan_samples is None or chan_truth is None or chan_samples.shape[1] <= 0:
            continue
        chan_cov = compute_coverage(
            chan_samples,
            chan_truth,
            levels=chan_levels,
            sorted_samples=presort_samples(chan_samples),
        )
        for key, val in chan_cov.items():
            result[f"chan_{label}_{key}"] = val
        chan_err = chan_samples.mean(axis=1) - chan_truth
        result[f"chan_{label}_rmse"] = float(np.sqrt(np.mean(chan_err ** 2)))

    # --- Interval width distribution ---
    result["post_width_q25"] = float(np.quantile(widths, 0.25))
    result["post_width_q75"] = float(np.quantile(widths, 0.75))
    width_mean = widths.mean()
    result["post_width_cv"] = (
        float(widths.std() / width_mean) if width_mean > 1e-10 else 0.0
    )

    # PPC summary passthrough, read from gpu_diagnostics["ppc"] when present
    gpu_diag = diag.gpu_diagnostics
    if gpu_diag is not None and "ppc" in gpu_diag:
        ppc = gpu_diag["ppc"]
        result["ppc_binary_mae"] = ppc.get("binary_mae", float("nan"))
        result["ppc_continuous_mae"] = ppc.get("continuous_mae", float("nan"))
        result["ppc_n_converters"] = ppc.get("n_converters", 0)

    # Channel calibration (skipped when the helper returns None)
    chan_cal = compute_channel_calibration(diag)
    if chan_cal is not None:
        if "binary" in chan_cal:
            binary_cal = chan_cal["binary"]
            result["binary_coverage_90"] = binary_cal.get("coverage_90", float("nan"))
            result["binary_coverage_95"] = binary_cal.get("coverage_95", float("nan"))
        if "composed" in chan_cal:
            composed_cal = chan_cal["composed"]
            result["composed_coverage_90"] = composed_cal.get("coverage_90", float("nan"))
            result["composed_coverage_95"] = composed_cal.get("coverage_95", float("nan"))

    # Channel attribution (skipped when the helper returns None)
    chan_attr = compute_channel_attribution(diag)
    if chan_attr is not None:
        result.update(chan_attr)

    # Topology mobility metrics — only when the fit retained a topology history.
    if diag.topology_history is not None:
        from pytyche.bcf.diagnostics.topology import compute_topology_metrics  # noqa: PLC0415

        result.update(compute_topology_metrics(diag.topology_history))

    return result