# Source code for pytyche.calibrate.layered

"""Layered calibration deployment primitive — R(p) below ceiling, scale-family above.

The deployment path does NOT read the fitted ``c`` parameter; ``c`` is
persisted only for the matplotlib diagnostic plot. The past-ceiling regime
anchors on the SBC-measured pair ``(max_recorded_nominal, max_recorded_actual)``
and scales each per-test-point half-width (from the posterior mean out to the
empirical anchor quantile) by the Gaussian z-score ratio. This preserves the
per-point asymmetry of hurdle CATEs that a fully-Gaussian-σ inflation would
average away.
"""

from __future__ import annotations

import json
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Protocol

import jax.numpy as jnp
import numpy as np
import scipy.stats

from pytyche.calibrate._correction_io import invert_correction


class _HasCateDraws(Protocol):
    """Structural type for the result argument of :func:`apply_layered_calibration`.

    Both ``HurdleBCFResult`` and ad-hoc test/eval stubs (``SimpleNamespace``)
    that carry the posterior CATE draws satisfy this protocol.
    """

    rpv_cate_samples: np.ndarray


_COVERAGE_FILENAME = "coverage_correction.json"
_SCALE_FAMILY_FILENAME = "scale_family_correction.json"


@dataclass(frozen=True)
class CoverageCorrection:
    """Frozen view of a serialized R(p) isotonic correction.

    Mirrors the ``coverage_correction.json`` schema produced by
    ``scripts/fit_sbc_correction.py``: a pair of monotone arrays defining the
    isotonic step function ``nominal → actual``.

    Attributes
    ----------
    x_thresholds
        Nominal-coverage knots of the step function.
    y_values
        Actual coverage at each knot; paired element-wise with
        ``x_thresholds``.

    Raises
    ------
    ValueError
        If the two arrays differ in length or are empty — either makes the
        step function undefined, and would otherwise only surface later as an
        obscure failure during deployment.
    """

    x_thresholds: tuple[float, ...]
    y_values: tuple[float, ...]

    def __post_init__(self) -> None:
        """Validate that the knots form a well-defined, non-empty step function."""
        n_x = len(self.x_thresholds)
        n_y = len(self.y_values)
        if n_x != n_y:
            raise ValueError(
                f"CoverageCorrection requires x_thresholds and y_values of "
                f"equal length; got {n_x} and {n_y}."
            )
        # Use len() rather than truthiness so array-like inputs also work.
        if n_x == 0:
            raise ValueError(
                "CoverageCorrection requires at least one (x_threshold, "
                "y_value) knot; got empty arrays."
            )
@dataclass(frozen=True)
class ScaleFamilyCorrection:
    """Immutable record of a fitted scale-family (past-ceiling) correction.

    Deployment via :func:`apply_layered_calibration` consumes only the
    empirical anchor pair ``(max_recorded_nominal, max_recorded_actual)``.
    The Gaussian-scale least-squares fit ``c`` is carried along purely so the
    diagnostic plot can be rendered.
    """

    # Gaussian-scale LSQ fit; diagnostic-plot rendering only.
    c: float
    # Nominal level whose per-point quantiles anchor the past-ceiling intervals.
    max_recorded_nominal: float
    # Actual coverage paired with max_recorded_nominal; in layered mode this
    # is also the R(p) / scale-family switch point.
    max_recorded_actual: float
    # Fitting-procedure identifier, as read from the JSON artifact.
    fit_method: str
    # Point count recorded with the fit, as read from the JSON artifact.
    n_points_fit: int
@dataclass(frozen=True)
class CalibratedIntervals:
    """Calibrated per-test-point interval endpoints, plus provenance.

    Instances are produced by :func:`apply_layered_calibration`. ``method``
    names the regime that fired (``"r_of_p"`` or ``"scale_family"``), and
    ``applied_corrections`` holds that regime's provenance entries.
    """

    # Per-test-point endpoints.
    lower: np.ndarray
    upper: np.ndarray
    # Coverage level the caller requested.
    desired_coverage: float
    achieved_actual_coverage: float | None
    # Which deployment regime produced the endpoints.
    method: Literal["r_of_p", "scale_family"]
    # Regime-specific provenance keys.
    applied_corrections: dict
@dataclass(frozen=True)
class LayeredCalibrationCorrection:
    """Discriminated container for a fitted layered SBC correction.

    The ``mode`` field selects which deployment regimes are supported:

    - ``"layered"``: both R(p) (below ceiling) and scale-family (past ceiling)
      regimes available. ``scale_family`` MUST be non-None.
    - ``"r_of_p_only"``: only R(p) is available. Past-ceiling deployment
      raises rather than silently degrading. ``scale_family`` MUST be None.

    Any other ``mode`` value is rejected with ``ValueError``. The
    construction-time invariant is validated in ``__post_init__`` and at load
    time by :meth:`from_directory`. No silent fallback: a caller wanting
    r_of_p_only semantics on a sweep that DOES have a scale-family file must
    pass ``mode="r_of_p_only"`` explicitly and will see a ``UserWarning``
    noting the ignored file.
    """

    mode: Literal["layered", "r_of_p_only"]
    coverage_correction: CoverageCorrection
    scale_family: ScaleFamilyCorrection | None

    @staticmethod
    def _require_known_mode(mode: object) -> None:
        """Raise ``ValueError`` unless ``mode`` is a supported deployment mode.

        ``Literal`` annotations are not enforced at runtime. Without this
        check, a typo such as ``"layerd"`` would pass ``__post_init__`` (when
        ``scale_family`` is None). :meth:`from_directory` would then treat it
        as r_of_p_only, which is exactly the silent fallback this class
        forbids.
        """
        if mode not in ("layered", "r_of_p_only"):
            raise ValueError(
                f"Unknown calibration mode {mode!r}; expected 'layered' or "
                f"'r_of_p_only'."
            )

    def __post_init__(self) -> None:
        """Validate ``mode`` and the ``mode`` / ``scale_family`` invariant.

        ``mode`` must be one of the two supported values, and
        ``scale_family`` MUST be non-None iff ``mode == "layered"``. Every
        violation raises ``ValueError`` naming the offending fields (per spec
        "construction rejects mode/scale_family inconsistency").
        """
        self._require_known_mode(self.mode)
        if self.mode == "layered" and self.scale_family is None:
            raise ValueError(
                "Inconsistent mode / scale_family: mode='layered' requires "
                "a non-None scale_family, but scale_family is None."
            )
        if self.mode == "r_of_p_only" and self.scale_family is not None:
            raise ValueError(
                "Inconsistent mode / scale_family: mode='r_of_p_only' "
                "requires scale_family to be None, but a non-None "
                "scale_family was provided."
            )

    @classmethod
    def from_directory(
        cls,
        path: str | Path,
        *,
        mode: Literal["layered", "r_of_p_only"],
    ) -> LayeredCalibrationCorrection:
        """Load a correction pair from a directory.

        Parameters
        ----------
        path
            Directory containing ``coverage_correction.json`` and, for
            layered mode, ``scale_family_correction.json``.
        mode
            Explicit deployment mode. No default; the caller must declare
            their intent at the call site rather than receive silent
            fallback.

        Returns
        -------
        LayeredCalibrationCorrection
            Loaded correction with typed inner dataclasses.

        Raises
        ------
        ValueError
            If ``mode`` is not ``"layered"`` or ``"r_of_p_only"``. This is
            checked before any file is touched.
        FileNotFoundError
            If a required JSON file is missing for the requested mode. The
            message names the missing file.

        Warns
        -----
        UserWarning
            If ``mode == "r_of_p_only"`` and ``scale_family_correction.json``
            is present in the directory — the file is being ignored due to
            the explicit mode choice (visible signal, not silent).
        """
        # Reject unknown modes up front. A typo is then reported as a typo,
        # not as a missing file and not as a silent r_of_p_only downgrade.
        cls._require_known_mode(mode)

        directory = Path(path)
        coverage_path = directory / _COVERAGE_FILENAME
        scale_family_path = directory / _SCALE_FAMILY_FILENAME

        if not coverage_path.exists():
            raise FileNotFoundError(
                f"Required calibration artifact not found: {coverage_path} "
                f"({_COVERAGE_FILENAME} is required for both modes)."
            )
        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        coverage_dict = json.loads(coverage_path.read_text(encoding="utf-8"))
        coverage = CoverageCorrection(
            x_thresholds=tuple(float(v) for v in coverage_dict["x_thresholds"]),
            y_values=tuple(float(v) for v in coverage_dict["y_values"]),
        )

        if mode == "layered":
            if not scale_family_path.exists():
                raise FileNotFoundError(
                    f"Required calibration artifact not found: "
                    f"{scale_family_path} ({_SCALE_FAMILY_FILENAME} is "
                    f"required when mode='layered')."
                )
            sf_dict = json.loads(scale_family_path.read_text(encoding="utf-8"))
            scale_family = ScaleFamilyCorrection(
                c=float(sf_dict["c"]),
                max_recorded_nominal=float(sf_dict["max_recorded_nominal"]),
                max_recorded_actual=float(sf_dict["max_recorded_actual"]),
                fit_method=str(sf_dict["fit_method"]),
                n_points_fit=int(sf_dict["n_points_fit"]),
            )
            return cls(
                mode="layered",
                coverage_correction=coverage,
                scale_family=scale_family,
            )

        # mode == "r_of_p_only": the only other value _require_known_mode admits.
        if scale_family_path.exists():
            warnings.warn(
                f"{_SCALE_FAMILY_FILENAME} is present in {directory} but is "
                f"being ignored because mode='r_of_p_only' was requested "
                f"explicitly. Past-ceiling deployment will not be available "
                f"from this correction.",
                UserWarning,
                stacklevel=2,
            )
        return cls(
            mode="r_of_p_only",
            coverage_correction=coverage,
            scale_family=None,
        )
def apply_layered_calibration(
    result: _HasCateDraws,
    correction: LayeredCalibrationCorrection,
    desired_coverage: float,
) -> CalibratedIntervals:
    """Apply a fitted layered calibration to posterior CATE draws.

    Stateless: does not mutate ``result``. Reads only
    ``result.rpv_cate_samples`` (shape ``(n, S)``) — the per-test-point
    posterior draws produced by a joint hurdle BCF fit. All reductions run in
    NumPy, so the returned endpoints are ``np.ndarray`` of shape ``(n,)``.

    Regime selection (hard boundary, no smoothing — see spec "hard boundary
    switch with no smoothing"):

    - ``desired_coverage <= ceiling`` → R(p) regime: invert the isotonic
      curve to find nominal ``n*`` such that ``R(n*) == desired_coverage``,
      then return the symmetric ``(1-n*)/2`` / ``1-(1-n*)/2`` quantile pair
      on the per-test-point draws.
    - ``desired_coverage > ceiling`` → scale-family regime: anchor on the
      empirical ``(q_low, q_high)`` at ``max_recorded_nominal`` and inflate
      both half-widths around the per-point posterior mean by the Gaussian
      z-score ratio ``s``. This preserves per-point empirical asymmetry that
      a fully-Gaussian-σ inflation would average away.

    The ceiling is ``correction.scale_family.max_recorded_actual`` for layered
    mode and ``max(correction.coverage_correction.y_values)`` for
    r_of_p_only mode.

    Parameters
    ----------
    result
        Fit result providing ``rpv_cate_samples: ndarray[(n, S)]``.
    correction
        Fitted layered calibration loaded via
        :meth:`LayeredCalibrationCorrection.from_directory`.
    desired_coverage
        Requested coverage level in the open interval ``(0, 1)``.

    Returns
    -------
    CalibratedIntervals
        Calibrated per-test-point endpoints plus provenance.

    Raises
    ------
    ValueError
        If ``desired_coverage`` is outside ``(0, 1)``.
        If ``result.rpv_cate_samples`` is not a 2-D array with at least one
        draw per test point.
        If ``correction.mode == "r_of_p_only"`` and ``desired_coverage``
        exceeds the R(p) ceiling — past-ceiling calibration requires a
        scale-family correction.
        If the scale-family anchor coverage ``max_recorded_actual`` is not
        positive, which makes the z-score ratio undefined.
    """
    if not (0.0 < desired_coverage < 1.0):
        raise ValueError(
            f"desired_coverage must lie in the open interval (0, 1); "
            f"got {desired_coverage}."
        )

    # Use NumPy throughout, not jax.numpy. JAX computes in 32-bit floats
    # unless x64 is enabled, which silently downcast float64 draws (and
    # mixed precisions across the two regimes). NumPy output also matches
    # CalibratedIntervals' np.ndarray annotations.
    samples = np.asarray(result.rpv_cate_samples)
    if samples.ndim != 2 or samples.shape[1] == 0:
        raise ValueError(
            "result.rpv_cate_samples must be a 2-D (n, S) array with at "
            f"least one posterior draw per test point; got shape "
            f"{samples.shape}."
        )

    sf = correction.scale_family  # invariant: non-None iff mode == "layered"
    if sf is not None:
        ceiling = sf.max_recorded_actual
    else:
        ceiling = float(max(correction.coverage_correction.y_values))

    if desired_coverage <= ceiling:
        # R(p) regime: invert isotonic to find nominal n*, read quantiles.
        cov = correction.coverage_correction
        nominal_star = float(
            invert_correction(
                np.asarray(cov.x_thresholds),
                np.asarray(cov.y_values),
                desired_coverage,
            ).item()
        )
        tail = (1.0 - nominal_star) / 2.0
        # np.quantile with a list of q returns shape (2, n); unpack rows.
        lower, upper = np.quantile(samples, [tail, 1.0 - tail], axis=1)
        return CalibratedIntervals(
            lower=lower,
            upper=upper,
            desired_coverage=desired_coverage,
            achieved_actual_coverage=None,
            method="r_of_p",
            applied_corrections={"nominal_star": nominal_star},
        )

    # desired_coverage > ceiling → scale-family regime.
    if sf is None:
        raise ValueError(
            f"desired_coverage={desired_coverage} exceeds the R(p) ceiling "
            f"({ceiling}); past-ceiling calibration is not available in "
            f"mode='r_of_p_only'. Re-fit the SBC sweep to produce a "
            f"{_SCALE_FAMILY_FILENAME}, or request a lower desired_coverage."
        )

    max_recorded_nominal = sf.max_recorded_nominal
    max_recorded_actual = sf.max_recorded_actual
    # norm.ppf((1 + a) / 2) is 0 at a == 0 and negative below, which would
    # give an infinite or sign-flipped multiplier, so reject it explicitly.
    if max_recorded_actual <= 0.0:
        raise ValueError(
            f"Scale-family anchor coverage max_recorded_actual="
            f"{max_recorded_actual} must be positive; the Gaussian z-score "
            f"ratio is undefined otherwise."
        )

    anchor_tail = (1.0 - max_recorded_nominal) / 2.0
    q_low, q_high = np.quantile(
        samples,
        [anchor_tail, 1.0 - anchor_tail],
        axis=1,
    )
    mean = np.mean(samples, axis=1)

    # Ratio of symmetric Gaussian z-scores: target coverage vs. the coverage
    # actually achieved at the anchor.
    multiplier_s = float(
        scipy.stats.norm.ppf((1.0 + desired_coverage) / 2.0)
        / scipy.stats.norm.ppf((1.0 + max_recorded_actual) / 2.0)
    )
    # Scale each half-width independently to keep per-point asymmetry.
    lower = mean - multiplier_s * (mean - q_low)
    upper = mean + multiplier_s * (q_high - mean)

    return CalibratedIntervals(
        lower=lower,
        upper=upper,
        desired_coverage=desired_coverage,
        achieved_actual_coverage=None,
        method="scale_family",
        applied_corrections={
            "multiplier_s": multiplier_s,
            "max_recorded_nominal": float(max_recorded_nominal),
            "max_recorded_actual": float(max_recorded_actual),
            "q_low": q_low,
            "q_high": q_high,
        },
    )