# Source code for pytyche.viz._segments

"""Segment-interval forest plot."""

from __future__ import annotations

from collections.abc import Sequence

import matplotlib.pyplot as plt
from matplotlib.axes import Axes

from pytyche.contracts import DiscoveredSegment


def plot_segment_intervals(
    segments: Sequence[DiscoveredSegment], ax: Axes | None = None
) -> Axes:
    """Forest plot of per-segment gate estimates with 80% intervals.

    The interval bar spans ``gate_ci`` exactly; the marker sits at
    ``gate_estimate`` (drawn separately because a skewed posterior can
    place the mean outside the percentile interval, which an
    errorbar-around-the-mean cannot represent).

    Segments are drawn top-to-bottom in the order given, including when
    drawing into axes whose y-axis is already inverted.

    Args:
        segments: Discovered segments to plot, one row per segment,
            labeled by their rule descriptions.
        ax: Existing axes to draw into; ``None`` creates a new figure.

    Returns:
        The axes containing the error-bar and marker artists.
    """
    if ax is None:
        _, ax = plt.subplots()

    positions = range(len(segments))

    # errorbar only draws bars symmetric about a centre, so re-express each
    # (lower, upper) interval as its midpoint plus a half-width; the bar then
    # covers gate_ci exactly regardless of where gate_estimate falls.
    midpoints: list[float] = []
    halfwidths: list[float] = []
    for segment in segments:
        lower, upper = segment.gate_ci[0], segment.gate_ci[1]
        midpoints.append((lower + upper) / 2.0)
        halfwidths.append((upper - lower) / 2.0)

    ax.errorbar(midpoints, positions, xerr=halfwidths, fmt="none", capsize=3)
    ax.plot([s.gate_estimate for s in segments], positions, "o")
    ax.set_yticks(positions, [s.rule.description for s in segments])

    # invert_yaxis() toggles the current state, so calling it unconditionally
    # would flip the rows back (first segment at the bottom) when the caller
    # hands in axes that are already inverted, e.g. a second call on the same
    # ax. Invert only when needed so row 0 is always at the top.
    if not ax.yaxis_inverted():
        ax.invert_yaxis()

    # Dotted reference line at zero effect.
    ax.axvline(0.0, color="grey", linewidth=0.8, linestyle=":")
    ax.set_xlabel("gate effect (80% credible interval)")
    return ax