Source code for pytyche.viz._segments
"""Segment-interval forest plot."""
from __future__ import annotations
from collections.abc import Sequence
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from pytyche.contracts import DiscoveredSegment
[docs]
def plot_segment_intervals(
    segments: Sequence[DiscoveredSegment], ax: Axes | None = None
) -> Axes:
    """Forest plot of per-segment gate estimates with 80% intervals.

    The interval bar spans ``gate_ci`` exactly; the marker sits at
    ``gate_estimate`` (drawn separately because a skewed posterior can
    place the mean outside the percentile interval, which an
    errorbar-around-the-mean cannot represent).

    Segments are drawn one per row, top-to-bottom in input order. Each
    row's marker and interval bar share a single color taken from the
    axes' property cycle, and a dotted vertical line marks zero effect.

    Args:
        segments: Discovered segments to plot, one row per segment,
            labeled by their rule descriptions. Each segment's
            ``gate_ci`` is read as a ``(low, high)`` pair and is
            expected to satisfy ``low <= high``.
        ax: Existing axes to draw into; ``None`` creates a new figure.
            Reusing the same axes across calls keeps the first segment
            at the top instead of toggling the y-direction.

    Returns:
        The axes containing the error-bar and marker artists.
    """
    if ax is None:
        _, ax = plt.subplots()

    positions = list(range(len(segments)))
    estimates = [s.gate_estimate for s in segments]
    lows = [s.gate_ci[0] for s in segments]
    highs = [s.gate_ci[1] for s in segments]
    # errorbar only takes symmetric/asymmetric offsets around a centre, so
    # express the percentile interval as its midpoint +/- half its width.
    midpoints = [(lo + hi) / 2.0 for lo, hi in zip(lows, highs)]
    halfwidths = [(hi - lo) / 2.0 for lo, hi in zip(lows, highs)]

    # Draw the markers first and reuse their cycle color for the bars:
    # calling errorbar and plot independently advances the color cycle
    # twice, leaving each row's bar and dot in different colors. The raised
    # zorder keeps the marker visible on top of its bar.
    (marker_line,) = ax.plot(estimates, positions, "o", zorder=3)
    ax.errorbar(
        midpoints,
        positions,
        xerr=halfwidths,
        fmt="none",
        ecolor=marker_line.get_color(),
        capsize=3,
    )
    ax.set_yticks(positions, [s.rule.description for s in segments])
    # invert_yaxis() toggles the direction, so only flip when the axis is
    # not already inverted; this keeps segment 0 at the top even when the
    # caller's axes were inverted earlier (e.g. by a previous call).
    if not ax.yaxis_inverted():
        ax.invert_yaxis()
    ax.axvline(0.0, color="grey", linewidth=0.8, linestyle=":")
    ax.set_xlabel("gate effect (80% credible interval)")
    return ax