"""Layered calibration deployment primitive — R(p) below ceiling, scale-family above.
The deployment path does NOT read the fitted ``c`` parameter; ``c`` is
persisted only for the matplotlib diagnostic plot. The past-ceiling regime
anchors on the SBC-measured pair ``(max_recorded_nominal, max_recorded_actual)``
and scales per-test-point empirical anchor quantiles by the Gaussian z-score
ratio. This preserves the per-point asymmetry of hurdle CATEs that a
fully-Gaussian-σ inflation would average away.
"""
from __future__ import annotations
import json
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Protocol
import jax.numpy as jnp
import numpy as np
import scipy.stats
from pytyche.calibrate._correction_io import invert_correction
class _HasCateDraws(Protocol):
    """Duck-typed contract for the ``result`` argument of :func:`apply_layered_calibration`.

    Anything that exposes posterior CATE draws under ``rpv_cate_samples``
    conforms: a full ``HurdleBCFResult`` as well as a lightweight test/eval
    stand-in such as ``SimpleNamespace``.
    """

    # Posterior CATE draws; the single attribute this protocol requires.
    rpv_cate_samples: np.ndarray
# File names of the serialized correction artifacts that
# ``LayeredCalibrationCorrection.from_directory`` looks up inside a
# calibration directory.
# R(p) isotonic correction; required for both modes.
_COVERAGE_FILENAME = "coverage_correction.json"
# Scale-family correction; required only when mode='layered'.
_SCALE_FAMILY_FILENAME = "scale_family_correction.json"
[docs]
@dataclass(frozen=True)
class CoverageCorrection:
    """Immutable in-memory form of a serialized R(p) isotonic correction.

    Field-for-field match of the ``coverage_correction.json`` schema written
    by ``scripts/fit_sbc_correction.py``. Taken together, the two monotone
    sequences describe the isotonic step function ``nominal → actual``.

    Attributes
    ----------
    x_thresholds
        Input side of the step function (nominal coverage levels).
    y_values
        Output side of the step function (actual coverage levels).
    """

    x_thresholds: tuple[float, ...]
    y_values: tuple[float, ...]
[docs]
@dataclass(frozen=True)
class ScaleFamilyCorrection:
    """Immutable in-memory form of a fitted scale-family correction.

    Deployment via :func:`apply_layered_calibration` works off the empirical
    anchor ``(max_recorded_nominal, max_recorded_actual)``. The Gaussian-scale
    LSQ fit ``c`` is kept solely so the diagnostic plot can be rendered.

    Attributes
    ----------
    c
        Gaussian-scale LSQ fit; diagnostic-plot use only.
    max_recorded_nominal
        Nominal-coverage coordinate of the empirical anchor.
    max_recorded_actual
        Actual-coverage coordinate of the empirical anchor.
    fit_method
        Fit-method label as stored in the artifact.
    n_points_fit
        Point count stored with the fit (presumably the number of points
        the fit used — confirm against the fitting script).
    """

    c: float
    max_recorded_nominal: float
    max_recorded_actual: float
    fit_method: str
    n_points_fit: int
[docs]
@dataclass(frozen=True)
class CalibratedIntervals:
    """Calibrated per-test-point interval endpoints together with provenance.

    This is the value :func:`apply_layered_calibration` hands back.

    Attributes
    ----------
    lower, upper
        Per-test-point interval endpoints.
    desired_coverage
        Coverage level the caller asked for.
    achieved_actual_coverage
        Optional achieved coverage; ``None`` when not available.
    method
        Regime that fired: ``"r_of_p"`` or ``"scale_family"``.
    applied_corrections
        Provenance entries; the keys depend on the regime.
    """

    lower: np.ndarray
    upper: np.ndarray
    desired_coverage: float
    achieved_actual_coverage: float | None
    method: Literal["r_of_p", "scale_family"]
    applied_corrections: dict
[docs]
@dataclass(frozen=True)
class LayeredCalibrationCorrection:
    """Discriminated container for a fitted layered SBC correction.

    The ``mode`` field selects which deployment regimes are supported:

    - ``"layered"``: both R(p) (below ceiling) and scale-family
      (past ceiling) regimes available. ``scale_family`` MUST be non-None.
    - ``"r_of_p_only"``: only R(p) is available. Past-ceiling deployment
      raises rather than silently degrading. ``scale_family`` MUST be None.

    The construction-time invariant is validated in ``__post_init__``
    and at load time by :meth:`from_directory`. No silent fallback: a
    caller wanting r_of_p_only semantics on a sweep that DOES have a
    scale-family file must pass ``mode="r_of_p_only"`` explicitly and
    will see a ``UserWarning`` noting the ignored file. Any ``mode`` other
    than the two literals above is rejected with ``ValueError``.
    """

    mode: Literal["layered", "r_of_p_only"]
    coverage_correction: CoverageCorrection
    scale_family: ScaleFamilyCorrection | None

    # Deliberately unannotated: dataclasses only turn annotated names into
    # fields, so this stays a plain class-level constant.
    _VALID_MODES = ("layered", "r_of_p_only")

    @classmethod
    def _validate_mode(cls, mode: object) -> None:
        """Raise ``ValueError`` unless ``mode`` is one of the supported literals."""
        if mode not in cls._VALID_MODES:
            raise ValueError(
                f"Unknown mode {mode!r}: expected 'layered' or 'r_of_p_only'."
            )

    def __post_init__(self) -> None:
        """Validate ``mode`` and the ``mode`` / ``scale_family`` invariant.

        ``scale_family`` MUST be non-None iff ``mode == "layered"``. Both
        directions of inconsistency raise ``ValueError`` naming the
        offending fields, and so does a ``mode`` outside the two supported
        literals (which would otherwise slip past both checks).
        """
        self._validate_mode(self.mode)
        if self.mode == "layered" and self.scale_family is None:
            raise ValueError(
                "Inconsistent mode / scale_family: mode='layered' requires "
                "a non-None scale_family, but scale_family is None."
            )
        if self.mode == "r_of_p_only" and self.scale_family is not None:
            raise ValueError(
                "Inconsistent mode / scale_family: mode='r_of_p_only' "
                "requires scale_family to be None, but a non-None "
                "scale_family was provided."
            )

    @classmethod
    def from_directory(
        cls,
        path: str | Path,
        *,
        mode: Literal["layered", "r_of_p_only"],
    ) -> LayeredCalibrationCorrection:
        """Load a correction pair from a directory.

        Parameters
        ----------
        path
            Directory containing ``coverage_correction.json`` and,
            for layered mode, ``scale_family_correction.json``.
        mode
            Explicit deployment mode. No default; the caller must declare
            their intent at the call site rather than receive silent fallback.

        Returns
        -------
        LayeredCalibrationCorrection
            Loaded correction with typed inner dataclasses.

        Raises
        ------
        ValueError
            If ``mode`` is not ``"layered"`` or ``"r_of_p_only"``. Checked
            before any file access so a typo never degrades to
            ``"r_of_p_only"``.
        FileNotFoundError
            If a required JSON file is missing for the requested mode.
            The message names the missing file.
        KeyError
            If a loaded JSON document lacks one of the expected keys.

        Warns
        -----
        UserWarning
            If ``mode == "r_of_p_only"`` and ``scale_family_correction.json``
            is present in the directory — the file is being ignored due to
            the explicit mode choice (visible signal, not silent).
        """
        # Reject unknown modes up front; without this any non-"layered"
        # string would fall through to the r_of_p_only branch below.
        cls._validate_mode(mode)

        directory = Path(path)
        coverage_path = directory / _COVERAGE_FILENAME
        scale_family_path = directory / _SCALE_FAMILY_FILENAME

        if not coverage_path.exists():
            raise FileNotFoundError(
                f"Required calibration artifact not found: {coverage_path} "
                f"({_COVERAGE_FILENAME} is required for both modes)."
            )
        # JSON is UTF-8 on the wire; do not depend on the locale default.
        coverage_dict = json.loads(coverage_path.read_text(encoding="utf-8"))
        coverage = CoverageCorrection(
            x_thresholds=tuple(float(v) for v in coverage_dict["x_thresholds"]),
            y_values=tuple(float(v) for v in coverage_dict["y_values"]),
        )

        if mode == "layered":
            if not scale_family_path.exists():
                raise FileNotFoundError(
                    f"Required calibration artifact not found: "
                    f"{scale_family_path} ({_SCALE_FAMILY_FILENAME} is "
                    f"required when mode='layered')."
                )
            sf_dict = json.loads(scale_family_path.read_text(encoding="utf-8"))
            scale_family = ScaleFamilyCorrection(
                c=float(sf_dict["c"]),
                max_recorded_nominal=float(sf_dict["max_recorded_nominal"]),
                max_recorded_actual=float(sf_dict["max_recorded_actual"]),
                fit_method=str(sf_dict["fit_method"]),
                n_points_fit=int(sf_dict["n_points_fit"]),
            )
            return cls(
                mode="layered",
                coverage_correction=coverage,
                scale_family=scale_family,
            )

        # mode == "r_of_p_only" (guaranteed by the validation above).
        if scale_family_path.exists():
            warnings.warn(
                f"{_SCALE_FAMILY_FILENAME} is present in {directory} but is "
                f"being ignored because mode='r_of_p_only' was requested "
                f"explicitly. Past-ceiling deployment will not be available "
                f"from this correction.",
                UserWarning,
                # Attribute the warning to the caller of from_directory.
                stacklevel=2,
            )
        return cls(
            mode="r_of_p_only",
            coverage_correction=coverage,
            scale_family=None,
        )
[docs]
def apply_layered_calibration(
    result: _HasCateDraws,
    correction: LayeredCalibrationCorrection,
    desired_coverage: float,
) -> CalibratedIntervals:
    """Apply a fitted layered calibration to posterior CATE draws.

    Stateless: does not mutate ``result``. Reads only
    ``result.rpv_cate_samples`` (shape ``(n, S)``) — the per-test-point
    posterior draws produced by a joint hurdle BCF fit.

    Regime selection (hard boundary, no smoothing — see spec
    "hard boundary switch with no smoothing"):

    - ``desired_coverage <= ceiling`` → R(p) regime: invert the isotonic
      curve to find nominal ``n*`` such that ``R(n*) == desired_coverage``,
      then return the symmetric ``(1-n*)/2`` / ``1-(1-n*)/2`` quantile
      pair on the per-test-point draws.
    - ``desired_coverage > ceiling`` → scale-family regime: anchor on the
      empirical ``(q_low, q_high)`` at ``max_recorded_nominal`` and inflate
      both half-widths around the per-point posterior mean by the Gaussian
      z-score ratio ``s``. This preserves per-point empirical asymmetry that
      a fully-Gaussian-σ inflation would average away.

    The ceiling is ``correction.scale_family.max_recorded_actual`` for
    layered mode and ``max(correction.coverage_correction.y_values)`` for
    r_of_p_only mode.

    All array work is done with NumPy in both regimes, so the returned
    endpoints are ``np.ndarray`` (as :class:`CalibratedIntervals` declares)
    and keep the precision of the input draws.

    Parameters
    ----------
    result
        Fit result providing ``rpv_cate_samples: ndarray[(n, S)]``.
    correction
        Fitted layered calibration loaded via
        :meth:`LayeredCalibrationCorrection.from_directory`.
    desired_coverage
        Requested coverage level in the open interval ``(0, 1)``.

    Returns
    -------
    CalibratedIntervals
        Calibrated per-test-point endpoints plus provenance.
        ``applied_corrections`` holds ``nominal_star`` in the R(p) regime and
        ``multiplier_s``, ``max_recorded_nominal``, ``max_recorded_actual``,
        ``q_low``, ``q_high`` in the scale-family regime.

    Raises
    ------
    ValueError
        If ``desired_coverage`` is outside ``(0, 1)``.
        If ``correction.mode == "r_of_p_only"`` and ``desired_coverage``
        exceeds the R(p) ceiling — past-ceiling calibration requires a
        scale-family correction.
        If the scale-family regime fires but ``max_recorded_actual`` is not
        positive, which would make the z-score ratio undefined.
    """
    if not (0.0 < desired_coverage < 1.0):
        raise ValueError(
            f"desired_coverage must lie in the open interval (0, 1); "
            f"got {desired_coverage}."
        )

    # One array library for both regimes. np.asarray does not modify the
    # caller's data and also accepts array-likes such as JAX arrays.
    samples = np.asarray(result.rpv_cate_samples)

    sf = correction.scale_family  # invariant: non-None iff mode == "layered"
    if sf is not None:
        ceiling = sf.max_recorded_actual
    else:
        ceiling = float(max(correction.coverage_correction.y_values))

    if desired_coverage <= ceiling:
        # R(p) regime: invert isotonic to find nominal n*, read quantiles.
        cov = correction.coverage_correction
        nominal_star = float(
            invert_correction(
                np.asarray(cov.x_thresholds),
                np.asarray(cov.y_values),
                desired_coverage,
            ).item()
        )
        tail = (1.0 - nominal_star) / 2.0
        # Result has shape (2, n); unpacking splits it into the two endpoints.
        lower, upper = np.quantile(samples, [tail, 1.0 - tail], axis=1)
        return CalibratedIntervals(
            lower=lower,
            upper=upper,
            desired_coverage=desired_coverage,
            achieved_actual_coverage=None,
            method="r_of_p",
            applied_corrections={"nominal_star": nominal_star},
        )

    # desired_coverage > ceiling → scale-family regime.
    if sf is None:
        raise ValueError(
            f"desired_coverage={desired_coverage} exceeds the R(p) ceiling "
            f"({ceiling}); past-ceiling calibration is not available in "
            f"mode='r_of_p_only'. Re-fit the SBC sweep to produce a "
            f"{_SCALE_FAMILY_FILENAME}, or request a lower desired_coverage."
        )

    max_recorded_nominal = sf.max_recorded_nominal
    max_recorded_actual = sf.max_recorded_actual
    # The anchor z-score below is norm.ppf((1 + actual) / 2): zero at
    # actual == 0 and NaN for NaN input. Fail loudly instead of returning
    # infinite / NaN intervals. `not (> 0)` also catches NaN.
    if not (max_recorded_actual > 0.0):
        raise ValueError(
            f"scale_family.max_recorded_actual must be positive to form the "
            f"Gaussian z-score ratio; got {max_recorded_actual}."
        )

    anchor_tail = (1.0 - max_recorded_nominal) / 2.0
    q_low, q_high = np.quantile(
        samples,
        [anchor_tail, 1.0 - anchor_tail],
        axis=1,
    )
    mean = np.mean(samples, axis=1)
    # Ratio of the Gaussian half-width at the desired coverage to the
    # half-width at the coverage actually achieved at the anchor.
    multiplier_s = float(
        scipy.stats.norm.ppf((1.0 + desired_coverage) / 2.0)
        / scipy.stats.norm.ppf((1.0 + max_recorded_actual) / 2.0)
    )
    # Scale each side separately so per-point asymmetry is preserved.
    lower = mean - multiplier_s * (mean - q_low)
    upper = mean + multiplier_s * (q_high - mean)
    return CalibratedIntervals(
        lower=lower,
        upper=upper,
        desired_coverage=desired_coverage,
        achieved_actual_coverage=None,
        method="scale_family",
        applied_corrections={
            "multiplier_s": multiplier_s,
            "max_recorded_nominal": float(max_recorded_nominal),
            "max_recorded_actual": float(max_recorded_actual),
            "q_low": q_low,
            "q_high": q_high,
        },
    )