"""Sim-mode evaluation of a fitted policy against ground truth.
Implementation behind ``posterior.evaluate_against_truth(tree, truth)``.
Only the simulation/calibration path has a :class:`~pytyche.contracts.CalibrationTruth`;
real-data mode (``truth is None``) raises ``RuntimeError`` — there is no
ground truth to evaluate against.
"""
from __future__ import annotations
import dataclasses
import numpy as np
from pytyche._internal.extraction import extract_fit_arrays
from pytyche.analysis._contrasts import (
AnyBCFResult,
best_arms,
contrast_samples,
require_observed,
)
from pytyche.analysis._policy_tree import PolicyTreeResult
from pytyche.contracts import CalibrationTruth
[docs]
@dataclasses.dataclass(frozen=True)
class TruthComparison:
    """Ground-truth scorecard for a fitted policy (simulation runs only).

    Built by ``posterior.evaluate_against_truth(tree, truth)``. Every field
    is required; there are no optional metrics.
    Level: experiment (simulation only).

    Fields:
        cate_rmse: Root-mean-squared error between the posterior-mean
            contrasts and the truth contrasts, pooled across all
            ``n × (K − 1)`` (visitor, contrast) cells.
        policy_accuracy: Share of visitors for whom the tree's arm matches
            the truth best arm (the shared best-arm rule evaluated on the
            truth contrast vectors).
        rpv_policy: Realized revenue-per-visitor when every visitor gets
            the tree's arm, scored on the truth potential-outcome surfaces
            (``p_a · m_a`` per arm).
        rpv_uniform: Realized RPV when arms are assigned uniformly at random.
        rpv_oracle: Realized RPV when every visitor gets their own best arm.
        oracle_gap_rpv: ``rpv_oracle − rpv_policy``, the RPV the policy
            forgoes relative to an oracle that knows the truth.
    """

    cate_rmse: float
    policy_accuracy: float
    rpv_policy: float
    rpv_uniform: float
    rpv_oracle: float
    oracle_gap_rpv: float

    def __repr__(self) -> str:
        # Header line, then an error/accuracy row, then the RPV comparison row.
        pieces = [
            "TruthComparison\n",
            f" cate_rmse: {self.cate_rmse:.4f}",
            f" policy_accuracy: {self.policy_accuracy:.0%}\n",
            f" rpv — policy: {self.rpv_policy:.4f}",
            f" uniform: {self.rpv_uniform:.4f}",
            f" oracle: {self.rpv_oracle:.4f}",
            f" oracle gap: {self.oracle_gap_rpv:.4f}",
        ]
        return "".join(pieces)
def _truth_contrasts(truth: CalibrationTruth, k: int) -> np.ndarray:
    """Assemble the truth per-visitor contrasts as an ``(n, K − 1)`` matrix.

    For ``k == 2`` the single ``cate_per_visitor`` series becomes the one
    column. Any other ``k`` reads ``contrast_cate_per_visitor`` and stacks
    its entries, in order, as columns.

    Raises:
        ValueError: When the field required for this ``k`` is ``None``.
    """
    if k != 2:
        per_contrast = truth.contrast_cate_per_visitor
        if per_contrast is None:
            raise ValueError(
                "evaluate_against_truth: truth.contrast_cate_per_visitor is "
                "None — the K >= 3 truth must carry per-contrast CATEs."
            )
        columns = [np.asarray(series.values, dtype=float) for series in per_contrast]
        return np.stack(columns, axis=1)

    single = truth.cate_per_visitor
    if single is None:
        raise ValueError(
            "evaluate_against_truth: truth.cate_per_visitor is None — "
            "the K = 2 truth must carry per-visitor CATEs."
        )
    # Append a trailing axis so the single contrast is one column.
    return np.asarray(single.values, dtype=float)[:, np.newaxis]
def _truth_rpv_surfaces(truth: CalibrationTruth, k: int) -> np.ndarray:
"""Per-arm realized RPV ``(n, K)`` from the truth potential outcomes."""
if k == 2:
p0, p1 = truth.p0_per_visitor, truth.p1_per_visitor
m0, m1 = truth.m0_per_visitor, truth.m1_per_visitor
if p0 is None or p1 is None or m0 is None or m1 is None:
named = {
"p0_per_visitor": p0,
"p1_per_visitor": p1,
"m0_per_visitor": m0,
"m1_per_visitor": m1,
}
missing = [name for name, v in named.items() if v is None]
raise ValueError(
"evaluate_against_truth: the truth carries no potential-"
f"outcome surfaces ({', '.join(missing)} missing) — the "
"realized-RPV metrics need them."
)
return np.stack(
[
np.asarray(p0.values, dtype=float)
* np.asarray(m0.values, dtype=float),
np.asarray(p1.values, dtype=float)
* np.asarray(m1.values, dtype=float),
],
axis=1,
)
if truth.p_per_visitor is None or truth.m_per_visitor is None:
raise ValueError(
"evaluate_against_truth: the K >= 3 truth carries no "
"potential-outcome surfaces (p_per_visitor/m_per_visitor "
"missing) — the realized-RPV metrics need them."
)
return np.stack(
[
np.asarray(p.values, dtype=float) * np.asarray(m.values, dtype=float)
for p, m in zip(truth.p_per_visitor, truth.m_per_visitor, strict=True)
],
axis=1,
)
[docs]
def evaluate_against_truth(
    posterior: AnyBCFResult,
    tree: PolicyTreeResult,
    truth: CalibrationTruth | None,
) -> TruthComparison:
    """Evaluate the posterior + tree policy against ground truth.

    Sim-mode only: generators pair observed data with a
    :class:`~pytyche.contracts.CalibrationTruth`; real-data mode has no
    truth and raises ``RuntimeError``.

    Args:
        posterior: One of the three posterior result types, carrying
            observed data (raises otherwise).
        tree: The fitted policy whose assignments are evaluated. Its
            predicted labels are interpreted as arm ids (0 = control).
        truth: Ground truth for the simulated experiment, or ``None`` in
            real-data mode.

    Returns:
        :class:`TruthComparison` with all six metrics populated.

    Raises:
        RuntimeError: When *truth* is ``None`` (real-data mode — nothing
            to evaluate against).
        ValueError: When ``posterior.observed`` is ``None``, when the
            truth lacks the K-appropriate contrast or potential-outcome
            fields, when the truth contrast matrix or RPV surfaces do not
            match the posterior's ``(n, K − 1)`` / ``(n, K)`` shapes, or
            when the tree does not assign exactly one valid arm id
            (``0 <= id < K``) to every visitor.
    """
    if truth is None:
        raise RuntimeError(
            "evaluate_against_truth is sim-mode only: truth is None "
            "(real-data mode has no ground truth to evaluate against)."
        )

    observed = require_observed(posterior)
    k = len(observed.variants)
    mean_contrasts = contrast_samples(posterior).mean(axis=1)  # (n, K-1)
    n = mean_contrasts.shape[0]

    # --- Truth-side inputs: validate every shape before scoring anything.
    truth_contrasts = _truth_contrasts(truth, k)
    if truth_contrasts.shape != mean_contrasts.shape:
        raise ValueError(
            f"evaluate_against_truth: truth contrasts have shape "
            f"{truth_contrasts.shape}, posterior contrasts "
            f"{mean_contrasts.shape} — misaligned truth."
        )
    rpv = _truth_rpv_surfaces(truth, k)  # (n, K)
    # Extra rows or arm columns would otherwise silently skew the uniform
    # and oracle RPVs; missing ones would surface as an IndexError below.
    if rpv.shape != (n, k):
        raise ValueError(
            f"evaluate_against_truth: truth RPV surfaces have shape "
            f"{rpv.shape}, expected {(n, k)} (visitors × arms) — "
            "misaligned truth."
        )

    cate_rmse = float(
        np.sqrt(np.mean((mean_contrasts - truth_contrasts) ** 2))
    )

    # Tree policy: the multiclass labels ARE arm ids (0 = control).
    arrays = extract_fit_arrays(observed)
    assigned = np.asarray(tree.tree.predict(arrays.X), dtype=int)
    # A wrong-length assignment vector would broadcast silently in the
    # comparisons and fancy indexing below.
    if assigned.shape != (n,):
        raise ValueError(
            f"evaluate_against_truth: tree assignments have shape "
            f"{assigned.shape}, expected {(n,)} (one arm per visitor)."
        )
    # Negative ids would wrap around to the last arms instead of failing.
    if assigned.size and (assigned.min() < 0 or assigned.max() >= k):
        raise ValueError(
            f"evaluate_against_truth: tree assigned arm ids in "
            f"[{assigned.min()}, {assigned.max()}], expected 0..{k - 1}."
        )

    truth_best = best_arms(truth_contrasts)
    policy_accuracy = float((assigned == truth_best).mean())

    rpv_policy = float(rpv[np.arange(n), assigned].mean())
    # Uniform random assignment: each visitor's expected RPV is the mean
    # over arms, so the overall expectation is the grand mean.
    rpv_uniform = float(rpv.mean())
    rpv_oracle = float(rpv.max(axis=1).mean())
    return TruthComparison(
        cate_rmse=cate_rmse,
        policy_accuracy=policy_accuracy,
        rpv_policy=rpv_policy,
        rpv_uniform=rpv_uniform,
        rpv_oracle=rpv_oracle,
        oracle_gap_rpv=rpv_oracle - rpv_policy,
    )