# Source code for pytyche.viz._policy_tree

"""Policy-tree rendering."""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import sklearn.tree
from matplotlib.axes import Axes
from matplotlib.patches import FancyBboxPatch, Patch, Rectangle

from pytyche._internal.extraction import extract_fit_arrays

if TYPE_CHECKING:
    from pytyche.analysis import PolicyTreeResult


def plot_policy_tree(
    tree_policy: PolicyTreeResult, ax: Axes | None = None
) -> Axes:
    """Render the fitted policy tree as a policy, not a classifier.

    Split nodes show only their condition (extraction-adapter feature
    names — the encoding the tree was fit on). Leaves show the policy
    content: the leading treatment with its Thompson allocation weight,
    the remaining allocation, and the segment's population share and
    stability score. Each leaf box is filled as a stacked bar — one band
    per arm, width proportional to that arm's Thompson share, one color
    per arm consistent across leaves (legend on the axes); sklearn's
    classifier-style ``samples``/``value`` arrays never appear.

    Args:
        tree_policy: A fitted :class:`~pytyche.analysis.PolicyTreeResult`.
        ax: Existing axes to draw into; ``None`` creates a new figure.

    Returns:
        The axes containing the rendered tree-node annotations.

    Raises:
        RuntimeError: If ``plot_tree`` yields a different number of node
            boxes than the tree has nodes, so boxes cannot be matched to
            node ids.
    """
    if ax is None:
        # Leaf labels are multi-line boxes; sklearn spaces leaves evenly,
        # so readable width scales with the leaf count.
        width = max(10.0, 2.6 * len(tree_policy.allocation_map))
        _, ax = plt.subplots(figsize=(width, 6))

    observed = tree_policy.observed
    feature_names = list(extract_fit_arrays(observed).feature_names)
    variant_names = (observed.control_name, *observed.treatment_names)
    # One fixed color-cycle entry per arm so every leaf and the legend agree.
    arm_colors = {name: f"C{i}" for i, name in enumerate(variant_names)}

    annotations = sklearn.tree.plot_tree(
        tree_policy.tree,
        feature_names=feature_names,
        impurity=False,
        label="none",
        filled=False,
        rounded=True,
        fontsize=8,
        ax=ax,
    )
    internals = tree_policy.tree.tree_

    # plot_tree's return mixes node boxes with the root's plain
    # True/False arrow labels; only node annotations carry a bbox.
    node_annotations = [
        annotation
        for annotation in annotations
        if annotation.get_bbox_patch() is not None
    ]
    if len(node_annotations) != internals.node_count:
        raise RuntimeError(
            f"plot_tree yielded {len(node_annotations)} node boxes for "
            f"{internals.node_count} nodes — the node-order assumption "
            "this renderer relabels by no longer holds."
        )

    # plot_tree emits node boxes in a root-first, left-before-right walk.
    # Node ids are not guaranteed to follow that order (trees grown
    # best-first, e.g. with ``max_leaf_nodes``, number children as they
    # are created), so pair each box with the id reached by the same walk
    # rather than with its list position.
    node_ids = _preorder_node_ids(internals)

    segment_by_leaf = {seg.id: seg for seg in tree_policy.segments}
    leaf_fills: list[tuple[FancyBboxPatch, list[tuple[str, float]]]] = []
    for node_id, annotation in zip(node_ids, node_annotations):
        if internals.children_left[node_id] >= 0:
            # Split node: keep the condition line, drop the class noise.
            condition = annotation.get_text().split("\n")[0]
            annotation.set_text(condition)
            continue
        allocation = tree_policy.allocation_map[node_id]
        # Largest Thompson share first: it leads both the label and the bar.
        ordered = sorted(allocation.items(), key=lambda kv: -kv[1])
        annotation.set_text(_leaf_label(ordered, segment_by_leaf.get(node_id)))
        # Every node annotation has a bbox (filtered above). Make it
        # transparent so the allocation bands show through.
        bbox = annotation.get_bbox_patch()
        bbox.set(facecolor="none")
        leaf_fills.append((bbox, ordered))

    _fill_leaf_bands(ax, leaf_fills, arm_colors)

    ax.legend(
        handles=[
            Patch(facecolor=color, alpha=0.3, label=name)
            for name, color in arm_colors.items()
        ],
        loc="upper right",
        frameon=False,
        fontsize=8,
    )
    return ax


def _preorder_node_ids(tree_structure) -> list[int]:
    """Return node ids root-first, left child before right child.

    ``tree_structure`` is a fitted estimator's ``tree_``; a negative
    ``children_left`` entry marks a leaf.
    """
    order: list[int] = []
    pending = [0]
    while pending:
        node_id = pending.pop()
        order.append(node_id)
        left = tree_structure.children_left[node_id]
        if left >= 0:
            # Push right first so the left subtree is visited first.
            pending.append(int(tree_structure.children_right[node_id]))
            pending.append(int(left))
    return order


def _leaf_label(ordered: list[tuple[str, float]], segment) -> str:
    """Compose a leaf's policy text from its sorted allocation and segment.

    ``ordered`` is the leaf's ``(arm, weight)`` pairs, largest weight first;
    ``segment`` is the matching segment record, or ``None`` if absent.
    """
    leader, leader_weight = ordered[0]
    lines = [f"{leader} {leader_weight:.0%}"]
    rest = " · ".join(f"{arm} {weight:.0%}" for arm, weight in ordered[1:])
    if rest:
        lines.append(rest)
    if segment is not None:
        lines.append(
            f"{segment.population_share:.0%} of visitors"
            f" · stability {segment.stability_score:.2f}"
        )
    return "\n".join(lines)


def _fill_leaf_bands(
    ax: Axes,
    leaf_fills: list[tuple[FancyBboxPatch, list[tuple[str, float]]]],
    arm_colors: dict[str, str],
) -> None:
    """Fill each leaf box with stacked per-arm bands, clipped to the box.

    Bands run left to right in ``ordered`` order, each as wide as its share
    of the box; arms with non-positive weight get no band.
    """
    # Band placement needs the realized text-box geometry, which only
    # exists after a draw lays the labels out.
    ax.figure.canvas.draw()
    # The rounded-box clip is frozen into axes-fraction coordinates, the
    # frame plot_tree positions its node boxes in. Clipping by the live
    # patch breaks under ``savefig(bbox_inches="tight")``: there the
    # axes' pixel origin shifts after the bands (zorder 1) have already
    # been clipped against the patch's pre-shift display geometry. Axes
    # coordinates also keep the bands on their boxes if the view limits
    # later change (pan/zoom, autoscale), and keep them out of the data
    # limits.
    to_axes = ax.transAxes.inverted()
    for bbox, ordered in leaf_fills:
        box_path = bbox.get_path().transformed(bbox.get_transform())
        clip_path = box_path.transformed(to_axes)
        extent = clip_path.get_extents()
        left = extent.xmin
        for arm, weight in ordered:
            if weight <= 0:
                continue
            band_width = extent.width * weight
            band = Rectangle(
                (left, extent.ymin),
                band_width,
                extent.height,
                transform=ax.transAxes,
                facecolor=arm_colors[arm],
                alpha=0.3,
                linewidth=0,
            )
            band.set_clip_path(clip_path, ax.transAxes)
            ax.add_patch(band)
            left += band_width