Source code for pytyche.calibrate.sbc

"""v2 calibration evaluator.

OracleConfig, oracle decision logic, and the ``calibrate`` function
that bridges analysis output to calibration records.

``calibrate`` consumes the v0.2 ``AnalysisResult`` SUMMARY surface:
lean ``Comparison`` entries (point estimates and probabilities, no
sample arrays) plus the ``RecommendationSummary`` decision-theoretic
snapshot.  Anything needing posterior samples goes through the
posterior itself, not through ``calibrate``.
"""

from __future__ import annotations

import dataclasses

from pytyche.contracts import (
    AnalysisResult,
    CalibrationRecord,
    CalibrationTruth,
    ClaimLevel,
    Decision,
)
from pytyche.metrics import get_metric_config


@dataclasses.dataclass(frozen=True)
class OracleConfig:
    """Oracle decision thresholds for calibration evaluation.

    Fields:
        ship_threshold: effect > ship_threshold → SHIP.
        null_epsilon: effect < -null_epsilon → STOP.

    Both must be non-negative.  No ordering constraint between them.
    NaN is rejected as well: a NaN threshold would make every threshold
    comparison False and silently turn the oracle into "always CONTINUE".

    Raises:
        ValueError: If either threshold is negative or NaN.
    """

    ship_threshold: float
    null_epsilon: float

    def __post_init__(self) -> None:
        # ``not (x >= 0)`` rather than ``x < 0``: any ordering comparison
        # against NaN is False, so ``x < 0`` would let NaN through.
        if not (self.ship_threshold >= 0):
            raise ValueError(
                f"OracleConfig: ship_threshold must be non-negative, got {self.ship_threshold}"
            )
        if not (self.null_epsilon >= 0):
            raise ValueError(
                f"OracleConfig: null_epsilon must be non-negative, got {self.null_epsilon}"
            )
def _oracle_decision(effect: float, config: OracleConfig) -> Decision:
    """Map a true effect onto the oracle's decision using ``config``.

    The SHIP test is evaluated first; STOP is only considered when the
    effect does not exceed ``config.ship_threshold``.  Anything that
    satisfies neither strict inequality is CONTINUE.
    """
    if effect > config.ship_threshold:
        verdict = Decision.SHIP
    elif effect < -config.null_epsilon:
        verdict = Decision.STOP
    else:
        verdict = Decision.CONTINUE
    return verdict


def _decision_correct(oracle: Decision, actual: Decision) -> bool:
    """Tell whether ``actual`` is an acceptable decision given ``oracle``.

    Ship-is-irreversible semantics:
    - oracle SHIP → only SHIP is correct
    - oracle STOP or CONTINUE → any non-SHIP decision is correct

    Args:
        oracle: What the oracle says should happen (ground-truth decision).
        actual: The decision the analysis actually recommended.
    """
    should_ship = oracle == Decision.SHIP
    did_ship = actual == Decision.SHIP
    return did_ship if should_ship else not did_ship


def _regret(effect: float, oracle: Decision, actual: Decision) -> float:
    """Compute the regret of ``actual`` given the true effect and oracle.

    Regret is always non-negative and in metric-native units.

    - Correct decision → 0.0
    - False ship (shipped when shouldn't) → max(-effect, 0.0)
    - Missed win (didn't ship when should) → max(effect, 0.0)

    Args:
        effect: True treatment effect in metric-native units.
        oracle: What the oracle says should have happened.
        actual: The decision that was actually made.

    Raises:
        RuntimeError: If an incorrect non-SHIP decision is seen while the
            oracle is not SHIP (internal invariant violation).
    """
    if _decision_correct(oracle, actual):
        return 0.0

    false_ship = actual == Decision.SHIP
    # An incorrect decision that is not a false ship must be a missed win,
    # which is only possible when the oracle said SHIP.
    if not false_ship and oracle != Decision.SHIP:
        raise RuntimeError(
            f"_regret invariant violated: incorrect non-SHIP decision "
            f"with oracle={oracle!r}, actual={actual!r}"
        )

    # False ship loses the downside of the effect; missed win loses the upside.
    signed_loss = -effect if false_ship else effect
    return max(signed_loss, 0.0)
def calibrate(
    result: AnalysisResult,
    truth: CalibrationTruth,
    oracle: OracleConfig,
    scenario_id: str,
    seed: int,
) -> CalibrationRecord:
    """Score a single analysis result against ground truth.

    Pure function: no side effects, no MCMC, no data generation.
    Works from the v0.2 ``AnalysisResult`` SUMMARY surface only — point
    estimates and probabilities, never posterior sample arrays.

    Args:
        result: Summary analysis output from ``posterior.analyze()``.
        truth: Ground truth for the simulated experiment.
        oracle: Decision thresholds for correctness evaluation.
        scenario_id: Identifier for the simulation scenario.
        seed: Random seed for this run.

    Returns:
        A ``CalibrationRecord`` with all evaluation fields populated.

    Raises:
        ValueError: If any input validation check fails (fail-closed).
    """
    # D7: fail closed — every input is validated before a record is built.

    # Check 1: exactly one comparison (2-arm contract).
    n_comparisons = len(result.comparisons)
    if n_comparisons != 1:
        raise ValueError(
            f"calibrate requires a single comparison (2-arm contract), "
            f"got {n_comparisons}"
        )

    # Check 2: the analysed metric must be the metric the truth describes.
    if result.metric != truth.metric_id:
        raise ValueError(
            f"calibrate: metric mismatch — result.metric={result.metric!r} "
            f"does not match truth.metric_id={truth.metric_id!r}"
        )

    # Check 3: scenario_id must be a str with at least one non-whitespace char.
    if not (isinstance(scenario_id, str) and scenario_id.strip()):
        raise ValueError(
            f"calibrate: scenario_id must be a non-empty, non-whitespace string, "
            f"got {scenario_id!r}"
        )

    # Check 4: seed must be a real int; bool is rejected explicitly because
    # bool is a subclass of int.
    if isinstance(seed, bool) or not isinstance(seed, int):
        raise ValueError(
            f"calibrate: seed must be an int, got {type(seed).__name__!r} ({seed!r})"
        )

    # CI ordering is deliberately not re-checked here: Comparison.__post_init__
    # validates it, so an inverted-CI Comparison cannot be constructed.
    comparison = result.comparisons[0]
    interval = comparison.lift_ci
    recommendation = result.recommendation

    # D6: field mapping.
    true_effect = truth.effect
    expected = _oracle_decision(true_effect, oracle)
    recommended = recommendation.decision

    return CalibrationRecord(
        scenario_id=scenario_id,
        seed=seed,
        analysis_mode=ClaimLevel.HONEST_ESTIMATE,
        effect=true_effect,
        metric_id=truth.metric_id,
        metric_family=truth.metric_family,
        effect_components=truth.effect_components,
        estimator_id=get_metric_config(result.metric).model_id,
        estimated_lift=comparison.lift_estimate,
        ci_low=interval[0],
        ci_high=interval[1],
        # Pinned credible-interval convention: Comparison.lift_ci is always
        # the 80% interval (no per-comparison ci_level on the v0.2 surface).
        ci_level=0.80,
        probability_positive=comparison.probability_positive,
        probability_better=recommendation.probability_better,
        probability_harmful=recommendation.probability_harmful,
        expected_loss_baseline=recommendation.expected_loss_baseline,
        expected_loss_comparison=recommendation.expected_loss_comparison,
        decision=recommended,
        oracle_decision=expected,
        decision_correct=_decision_correct(expected, recommended),
        regret=_regret(true_effect, expected, recommended),
    )