Source code for pytyche.viz._calibration

"""R(p) calibration-curve plot."""

from __future__ import annotations

from typing import Any

import matplotlib.pyplot as plt
from matplotlib.axes import Axes


def plot_calibration(
    calibrated_posterior: Any,
    reference: Any = None,
    ax: Axes | None = None,
) -> Axes:
    """Draw the R(p) coverage-remapping curve attached to a posterior.

    The posterior's calibration curve is drawn as a post-step line with
    point markers. An optional *reference* posterior is overlaid after it.
    If the reference has a calibration, its own step curve is drawn.
    Otherwise the nominal identity diagonal from (0, 0) to (1, 1) is drawn
    as a dashed line.

    Args:
        calibrated_posterior: A posterior carrying a
            :class:`~pytyche.calibrate.artifact.Calibration` artifact on its
            ``calibration`` attribute.
        reference: Optional posterior to overlay for comparison. An object
            without a ``calibration`` attribute, or whose ``calibration`` is
            ``None``, is shown as the nominal identity diagonal.
        ax: Axes to draw into. When ``None``, a fresh figure and axes are
            created via ``plt.subplots()``.

    Returns:
        The axes that were drawn into, with axis labels and a legend set.

    Raises:
        ValueError: If *calibrated_posterior* has no calibration artifact.
            This is checked before any figure is created.
    """
    own_calibration = getattr(calibrated_posterior, "calibration", None)
    if own_calibration is None:
        raise ValueError(
            "plot_calibration requires a calibrated posterior "
            "(calibration is None) — attach an artifact with "
            "apply_calibration first."
        )

    # Only open a new figure once the input is known to be plottable.
    target = plt.subplots()[1] if ax is None else ax

    step_curve = own_calibration.correction.coverage_correction
    target.plot(
        step_curve.x_thresholds,
        step_curve.y_values,
        drawstyle="steps-post",
        marker="o",
        label="calibrated R(p)",
    )

    if reference is not None:
        reference_calibration = getattr(reference, "calibration", None)
        if reference_calibration is not None:
            reference_curve = reference_calibration.correction.coverage_correction
            target.plot(
                reference_curve.x_thresholds,
                reference_curve.y_values,
                drawstyle="steps-post",
                label="reference R(p)",
            )
        else:
            # Uncalibrated reference: its claimed coverage is taken at face
            # value, i.e. R(p) = p.
            target.plot(
                [0.0, 1.0],
                [0.0, 1.0],
                linestyle="--",
                label="reference (nominal)",
            )

    target.set_xlabel("nominal coverage p")
    target.set_ylabel("actual coverage R(p)")
    target.legend()
    return target