# Source code for pytyche.analysis._truth

"""Sim-mode evaluation of a fitted policy against ground truth.

Implementation behind ``posterior.evaluate_against_truth(tree, truth)``.
Only the simulation/calibration path has a :class:`~pytyche.contracts.CalibrationTruth`;
real-data mode (``truth is None``) raises ``RuntimeError`` — there is no
ground truth to evaluate against.
"""

from __future__ import annotations

import dataclasses

import numpy as np

from pytyche._internal.extraction import extract_fit_arrays
from pytyche.analysis._contrasts import (
    AnyBCFResult,
    best_arms,
    contrast_samples,
    require_observed,
)
from pytyche.analysis._policy_tree import PolicyTreeResult
from pytyche.contracts import CalibrationTruth


@dataclasses.dataclass(frozen=True)
class TruthComparison:
    """Metrics from scoring a fitted policy against simulated ground truth.

    Every field is required — there are no optional metrics. Instances are
    the return value of ``posterior.evaluate_against_truth(tree, truth)``.

    Level: experiment (simulation only).

    Fields:
        cate_rmse: Root-mean-square error between the posterior-mean
            contrasts and the truth contrasts, pooled across every
            ``n × (K − 1)`` (visitor, contrast) entry.
        policy_accuracy: Share of visitors for whom the arm assigned by
            the tree matches the truth best arm (the shared best-arm rule
            applied to the truth contrast vectors).
        rpv_policy: Realized revenue-per-visitor obtained by following the
            tree policy on the truth potential-outcome surfaces
            (``p_a · m_a`` per arm).
        rpv_uniform: Realized RPV when arms are assigned uniformly at
            random.
        rpv_oracle: Realized RPV when every visitor receives their best
            arm.
        oracle_gap_rpv: ``rpv_oracle − rpv_policy`` — the revenue the
            policy forgoes relative to an oracle that knows the truth.
    """

    cate_rmse: float
    policy_accuracy: float
    rpv_policy: float
    rpv_uniform: float
    rpv_oracle: float
    oracle_gap_rpv: float

    def __repr__(self) -> str:
        # Three-line summary: header, fit-quality metrics, RPV metrics.
        fragments = [
            "TruthComparison\n",
            f" cate_rmse: {self.cate_rmse:.4f}",
            f" policy_accuracy: {self.policy_accuracy:.0%}\n",
            f" rpv — policy: {self.rpv_policy:.4f}",
            f" uniform: {self.rpv_uniform:.4f}",
            f" oracle: {self.rpv_oracle:.4f}",
            f" oracle gap: {self.oracle_gap_rpv:.4f}",
        ]
        return "".join(fragments)
def _truth_contrasts(truth: CalibrationTruth, k: int) -> np.ndarray: """Truth contrast matrix ``(n, K − 1)`` from the K-dispatched fields.""" if k == 2: if truth.cate_per_visitor is None: raise ValueError( "evaluate_against_truth: truth.cate_per_visitor is None — " "the K = 2 truth must carry per-visitor CATEs." ) return np.asarray(truth.cate_per_visitor.values, dtype=float)[:, None] if truth.contrast_cate_per_visitor is None: raise ValueError( "evaluate_against_truth: truth.contrast_cate_per_visitor is " "None — the K >= 3 truth must carry per-contrast CATEs." ) return np.stack( [ np.asarray(a.values, dtype=float) for a in truth.contrast_cate_per_visitor ], axis=1, ) def _truth_rpv_surfaces(truth: CalibrationTruth, k: int) -> np.ndarray: """Per-arm realized RPV ``(n, K)`` from the truth potential outcomes.""" if k == 2: p0, p1 = truth.p0_per_visitor, truth.p1_per_visitor m0, m1 = truth.m0_per_visitor, truth.m1_per_visitor if p0 is None or p1 is None or m0 is None or m1 is None: named = { "p0_per_visitor": p0, "p1_per_visitor": p1, "m0_per_visitor": m0, "m1_per_visitor": m1, } missing = [name for name, v in named.items() if v is None] raise ValueError( "evaluate_against_truth: the truth carries no potential-" f"outcome surfaces ({', '.join(missing)} missing) — the " "realized-RPV metrics need them." ) return np.stack( [ np.asarray(p0.values, dtype=float) * np.asarray(m0.values, dtype=float), np.asarray(p1.values, dtype=float) * np.asarray(m1.values, dtype=float), ], axis=1, ) if truth.p_per_visitor is None or truth.m_per_visitor is None: raise ValueError( "evaluate_against_truth: the K >= 3 truth carries no " "potential-outcome surfaces (p_per_visitor/m_per_visitor " "missing) — the realized-RPV metrics need them." ) return np.stack( [ np.asarray(p.values, dtype=float) * np.asarray(m.values, dtype=float) for p, m in zip(truth.p_per_visitor, truth.m_per_visitor, strict=True) ], axis=1, )
def evaluate_against_truth(
    posterior: AnyBCFResult,
    tree: PolicyTreeResult,
    truth: CalibrationTruth | None,
) -> TruthComparison:
    """Evaluate the posterior + tree policy against ground truth.

    Sim-mode only: generators pair observed data with a
    :class:`~pytyche.contracts.CalibrationTruth`; real-data mode has no
    truth and raises ``RuntimeError``.

    Args:
        posterior: One of the three posterior result types, carrying
            observed data (raises otherwise).
        tree: The fitted policy whose assignments are evaluated.
        truth: Ground truth for the simulated experiment, or ``None`` in
            real-data mode.

    Returns:
        :class:`TruthComparison` with all six metrics populated.

    Raises:
        RuntimeError: When *truth* is ``None`` (real-data mode — nothing
            to evaluate against).
        ValueError: When ``posterior.observed`` is ``None``, when the
            truth lacks the K-appropriate contrast or potential-outcome
            fields, when the truth contrast or RPV-surface shapes do not
            match the posterior, or when the tree's assignments are not
            one valid arm id (``0 .. K − 1``) per visitor.
    """
    if truth is None:
        raise RuntimeError(
            "evaluate_against_truth is sim-mode only: truth is None "
            "(real-data mode has no ground truth to evaluate against)."
        )

    observed = require_observed(posterior)
    k = len(observed.variants)

    mean_contrasts = contrast_samples(posterior).mean(axis=1)  # (n, K-1)
    n = mean_contrasts.shape[0]
    truth_contrasts = _truth_contrasts(truth, k)
    if truth_contrasts.shape != mean_contrasts.shape:
        raise ValueError(
            f"evaluate_against_truth: truth contrasts have shape "
            f"{truth_contrasts.shape}, posterior contrasts "
            f"{mean_contrasts.shape} — misaligned truth."
        )
    cate_rmse = float(
        np.sqrt(np.mean((mean_contrasts - truth_contrasts) ** 2))
    )

    # Tree policy: the multiclass labels ARE arm ids (0 = control).
    arrays = extract_fit_arrays(observed)
    assigned = np.asarray(tree.tree.predict(arrays.X), dtype=int)
    # A non-(n,) prediction array would broadcast in the comparison and in
    # the fancy index below instead of failing, so reject it up front.
    if assigned.shape != (n,):
        raise ValueError(
            f"evaluate_against_truth: tree assignments have shape "
            f"{assigned.shape}, expected {(n,)} — one arm id per visitor."
        )
    truth_best = best_arms(truth_contrasts)
    policy_accuracy = float((assigned == truth_best).mean())

    rpv = _truth_rpv_surfaces(truth, k)  # (n, K)
    # Only the contrasts were shape-checked above; surfaces with extra rows
    # would otherwise be silently averaged into rpv_uniform / rpv_oracle.
    if rpv.shape != (n, k):
        raise ValueError(
            f"evaluate_against_truth: truth RPV surfaces have shape "
            f"{rpv.shape}, expected {(n, k)} — misaligned truth."
        )
    # Negative labels would wrap around under NumPy indexing and labels
    # >= K would raise a bare IndexError; surface both as ValueError.
    if np.any((assigned < 0) | (assigned >= k)):
        raise ValueError(
            f"evaluate_against_truth: tree assignments must be arm ids in "
            f"[0, {k}); got values in "
            f"[{assigned.min()}, {assigned.max()}]."
        )

    rpv_policy = float(rpv[np.arange(n), assigned].mean())
    rpv_uniform = float(rpv.mean())
    rpv_oracle = float(rpv.max(axis=1).mean())

    return TruthComparison(
        cate_rmse=cate_rmse,
        policy_accuracy=policy_accuracy,
        rpv_policy=rpv_policy,
        rpv_uniform=rpv_uniform,
        rpv_oracle=rpv_oracle,
        oracle_gap_rpv=rpv_oracle - rpv_policy,
    )