"""Policy-tree rendering."""
from __future__ import annotations
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import sklearn.tree
from matplotlib.axes import Axes
from matplotlib.patches import FancyBboxPatch, Patch, Rectangle
from pytyche._internal.extraction import extract_fit_arrays
if TYPE_CHECKING:
from pytyche.analysis import PolicyTreeResult
[docs]
def plot_policy_tree(
    tree_policy: PolicyTreeResult, ax: Axes | None = None
) -> Axes:
    """Render the fitted policy tree as a policy, not a classifier.

    Split nodes show only their condition (extraction-adapter feature
    names — the encoding the tree was fit on). Leaves show the policy
    content: the leading treatment with its Thompson allocation weight,
    the remaining allocation, and the segment's population share and
    stability score. Each leaf box is filled as a stacked bar — one
    band per arm, width proportional to that arm's Thompson share,
    one color per arm consistent across leaves (legend on the axes);
    sklearn's classifier-style ``samples``/``value`` arrays never
    appear.

    Args:
        tree_policy: A fitted :class:`~pytyche.analysis.PolicyTreeResult`.
        ax: Existing axes to draw into; ``None`` creates a new figure.

    Returns:
        The axes containing the rendered tree-node annotations.

    Raises:
        RuntimeError: If ``plot_tree`` yields a different number of node
            boxes than the fitted tree has nodes, so boxes can no longer
            be matched to nodes.
    """
    if ax is None:
        # Leaf labels are multi-line boxes; sklearn spaces leaves evenly,
        # so readable width scales with the leaf count.
        width = max(10.0, 2.6 * len(tree_policy.allocation_map))
        _, ax = plt.subplots(figsize=(width, 6))

    observed = tree_policy.observed
    feature_names = list(extract_fit_arrays(observed).feature_names)
    variant_names = (observed.control_name, *observed.treatment_names)
    # One fixed color-cycle entry per arm so bands match across leaves.
    arm_colors = {name: f"C{i}" for i, name in enumerate(variant_names)}

    annotations = sklearn.tree.plot_tree(
        tree_policy.tree,
        feature_names=feature_names,
        impurity=False,
        label="none",
        filled=False,
        rounded=True,
        fontsize=8,
        ax=ax,
    )
    internals = tree_policy.tree.tree_

    # plot_tree's return mixes node boxes with the root's plain
    # True/False arrow labels; only node annotations carry a bbox.
    node_annotations = [
        annotation
        for annotation in annotations
        if annotation.get_bbox_patch() is not None
    ]
    if len(node_annotations) != internals.node_count:
        raise RuntimeError(
            f"plot_tree yielded {len(node_annotations)} node boxes for "
            f"{internals.node_count} nodes — the node-order assumption "
            "this renderer relabels by no longer holds."
        )
    # plot_tree emits node boxes in a left-first preorder walk of the
    # tree structure. Node ids coincide with that walk only for
    # depth-first growth; best-first growth (``max_leaf_nodes``) numbers
    # nodes in expansion order. Replay the walk to recover each box's id
    # instead of assuming box index == node id.
    node_ids = _preorder_node_ids(internals)

    segment_by_leaf = {seg.id: seg for seg in tree_policy.segments}
    leaf_fills: list[tuple[FancyBboxPatch, list[tuple[str, float]]]] = []
    for node_id, annotation in zip(node_ids, node_annotations):
        if internals.children_left[node_id] >= 0:
            # Split node: keep the condition line, drop the class noise.
            annotation.set_text(annotation.get_text().split("\n")[0])
            continue
        allocation = tree_policy.allocation_map[node_id]
        # Heaviest arm first; sorted() is stable, so ties keep map order.
        ordered = sorted(allocation.items(), key=lambda kv: -kv[1])
        annotation.set_text(_leaf_label(ordered, segment_by_leaf.get(node_id)))
        bbox = annotation.get_bbox_patch()
        if bbox is not None:
            # Transparent box so the allocation bands show through it.
            bbox.set(facecolor="none")
            leaf_fills.append((bbox, ordered))

    _add_allocation_bands(ax, leaf_fills, arm_colors)
    ax.legend(
        handles=[
            Patch(facecolor=color, alpha=0.3, label=name)
            for name, color in arm_colors.items()
        ],
        loc="upper right",
        frameon=False,
        fontsize=8,
    )
    return ax


def _preorder_node_ids(internals) -> list[int]:
    """Return the tree's node ids in left-first preorder.

    This is the order in which ``plot_tree`` emits node boxes.
    ``internals`` is a fitted estimator's ``tree_``. A negative
    ``children_left`` entry marks a leaf.
    """
    order: list[int] = []
    pending = [0]
    while pending:
        node_id = pending.pop()
        order.append(node_id)
        left = int(internals.children_left[node_id])
        if left >= 0:
            # Push right first so the left subtree is visited first.
            pending.append(int(internals.children_right[node_id]))
            pending.append(left)
    return order


def _leaf_label(ordered: list[tuple[str, float]], segment) -> str:
    """Build a leaf's text: leader, remaining allocation, segment stats.

    ``ordered`` holds the leaf's ``(arm, weight)`` pairs, heaviest first.
    ``segment`` is the matching segment record; when it is ``None``, the
    population/stability line is omitted.
    """
    leader, leader_weight = ordered[0]
    lines = [f"{leader} {leader_weight:.0%}"]
    rest = " · ".join(f"{arm} {weight:.0%}" for arm, weight in ordered[1:])
    if rest:
        lines.append(rest)
    if segment is not None:
        lines.append(
            f"{segment.population_share:.0%} of visitors"
            f" · stability {segment.stability_score:.2f}"
        )
    return "\n".join(lines)


def _add_allocation_bands(
    ax: Axes,
    leaf_fills: list[tuple[FancyBboxPatch, list[tuple[str, float]]]],
    arm_colors: dict[str, str],
) -> None:
    """Fill each leaf box with stacked, arm-colored allocation bands.

    Bands run left to right in ``ordered`` order (leader first). Each
    band's width is proportional to that arm's share. Zero-weight arms
    get no band.
    """
    # Band placement needs the realized text-box geometry, which only
    # exists after a draw lays the labels out. The rounded-box clip is
    # frozen into data coordinates: clipping by the live patch breaks
    # under ``savefig(bbox_inches="tight")``, where the axes' pixel
    # origin shifts after the bands (zorder 1) have already been clipped
    # against the patch's pre-shift display geometry.
    ax.figure.canvas.draw()
    to_data = ax.transData.inverted()
    for bbox, ordered in leaf_fills:
        box_path = bbox.get_path().transformed(bbox.get_transform())
        clip_path = box_path.transformed(to_data)
        extent = clip_path.get_extents()
        left = extent.xmin
        for arm, weight in ordered:
            if weight <= 0:
                continue
            band_width = extent.width * weight
            band = Rectangle(
                (left, extent.ymin),
                band_width,
                extent.height,
                facecolor=arm_colors[arm],
                alpha=0.3,
                linewidth=0,
            )
            band.set_clip_path(clip_path, ax.transData)
            ax.add_patch(band)
            left += band_width