"""Policy-tree segmentation: ``fit_policy_tree`` and its ``PolicyTreeResult``.
The tree data model is sklearn-direct: ``PolicyTreeResult.tree`` is a
``sklearn.tree.DecisionTreeClassifier``. A future change may swap this for
a pytyche wrapper preserving the same predict / decision_path surface;
until tree persistence is needed, sklearn-direct holds.
:func:`fit_policy_tree` is the implementation behind
``posterior.fit_policy_tree(...)`` on the three posterior result types.
Labels are the shared best-arm rule applied to per-visitor posterior-MEAN
contrast vectors; each leaf becomes a :class:`~pytyche.contracts.DiscoveredSegment`
whose rule reproduces the leaf's membership EXACTLY through
:func:`pytyche.summarize.apply_rule` (a mismatch raises ``RuntimeError`` —
a silently-wrong rule must never ship).
"""
from __future__ import annotations
import dataclasses
import warnings
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from pytyche._internal.extraction import extract_fit_arrays
from pytyche.analysis._contrasts import (
AnyBCFResult,
best_arms,
contrast_samples,
require_observed,
)
from pytyche.analysis._intervals import credible_interval
from pytyche.analysis._thompson import arm_win_frequencies, iterative_floor_clip
from pytyche.contracts import (
VISITOR_SCHEMA,
ComparisonRule,
DiscoveredSegment,
EqRule,
InRule,
ObservedExperimentData,
RuleClause,
SegmentRule,
is_reserved_propensity_column,
)
from pytyche.summarize import apply_rule
#: The within-Thompson safety-net floor applied to each leaf's allocation
#: (passed to ``iterative_floor_clip`` when ``fit_policy_tree`` builds the
#: per-leaf ``allocation_map``) — the same default ``thompson_allocation``
#: exposes as its ``epsilon`` kwarg.
_ALLOCATION_EPSILON = 0.02
[docs]
@dataclasses.dataclass(frozen=True)
class PolicyTreeResult:
    """Policy-tree segmentation output.

    Produced by ``posterior.fit_policy_tree(...)``. The tree partitions
    feature space; treatment metadata (allocations, stability) is per-leaf
    data layered on top, keyed by sklearn leaf id.

    Fields:
        tree: The fitted policy tree partitioning feature space (an
            sklearn ``DecisionTreeClassifier``).
        segments: One ``DiscoveredSegment`` per leaf, ordered by leaf id —
            ``segments[i].id == sorted(allocation_map.keys())[i]``.
        allocation_map: Leaf id → per-variant Thompson weight dict; each
            leaf's weights sum to 1.0 within 1e-6.
        stability_scores: Leaf id → bootstrap-replicability score (the
            fraction of bootstrap tree refits in which some leaf has
            Jaccard overlap >= 0.5 with the leaf's member set). Keyed
            exactly by the allocation_map leaves.
        observed: The observed data the underlying posterior was fit on
            (shared reference to ``posterior.observed``; no re-cloning).

    Raises:
        ValueError: At construction, when a leaf's weights do not sum to
            1.0 within 1e-6 (a NaN total counts as a violation), when the
            segment ids are not exactly the sorted ``allocation_map`` leaf
            ids, or when ``stability_scores`` is keyed by a different set
            of leaves than ``allocation_map``.
    """

    tree: DecisionTreeClassifier
    segments: list[DiscoveredSegment]
    allocation_map: dict[int, dict[str, float]]
    stability_scores: dict[int, float]
    observed: ObservedExperimentData

    def __repr__(self) -> str:
        """Header line (segment count and experiment id) plus one row per segment."""
        lines = [
            f"PolicyTreeResult — {len(self.segments)} segments"
            f" over {self.observed.experiment_id}"
        ]
        lines.extend(f" {seg._row()}" for seg in self.segments)
        return "\n".join(lines)

    def __post_init__(self) -> None:
        """Validate the cross-field invariants documented on the class."""
        for leaf_id, weights in self.allocation_map.items():
            total = sum(weights.values())
            # Written as a negated "<=" on purpose: every comparison with
            # NaN is False, so ``abs(total - 1.0) > 1e-6`` would let a NaN
            # total slip through; the negated form rejects it.
            if not (abs(total - 1.0) <= 1e-6):
                raise ValueError(
                    f"PolicyTreeResult: allocation_map[{leaf_id}] weights "
                    f"must sum to 1.0 within 1e-6, got {total}"
                )
        segment_ids = [s.id for s in self.segments]
        leaf_ids = sorted(self.allocation_map.keys())
        if segment_ids != leaf_ids:
            raise ValueError(
                f"PolicyTreeResult: segments ids {segment_ids} must equal "
                f"the sorted allocation_map leaf ids {leaf_ids}"
            )
        if set(self.stability_scores.keys()) != set(self.allocation_map.keys()):
            raise ValueError(
                f"PolicyTreeResult: stability_scores keys "
                f"{sorted(self.stability_scores.keys())} must equal the "
                f"allocation_map leaf ids {leaf_ids}"
            )
[docs]
def fit_policy_tree(
    posterior: AnyBCFResult,
    *,
    max_depth: int = 3,
    min_segment_share: float = 0.10,
    n_bootstrap: int = 50,
    bootstrap_seed: int = 0,
) -> PolicyTreeResult:
    """Fit a shallow decision tree over per-visitor best arms and turn its
    leaves into interpretable segments.

    Every visitor gets a label: the shared best-arm rule applied to their
    posterior-mean contrast vector (largest-lift treatment, or control when
    no lift is positive). A multiclass ``DecisionTreeClassifier`` is fit on
    the extracted feature matrix against those labels, and each leaf is
    reported as a :class:`~pytyche.contracts.DiscoveredSegment`:

    - ``rule`` is decoded from the leaf's root path and is checked to
      reproduce the leaf's membership exactly through
      :func:`pytyche.summarize.apply_rule`; any mismatch raises
      ``RuntimeError``. A root-only tree yields a single catch-all segment
      with ``clauses == ()``.
    - ``arm_best_probabilities`` are best-arm win frequencies over the
      leaf's per-draw mean contrasts (``arm_win_frequencies``); the leaf's
      ``allocation_map`` entry is those frequencies after
      ``iterative_floor_clip`` with the module's allocation floor.
    - ``gate_estimate`` / ``gate_ci`` (mean and the 80% interval) use the
      in-segment contrast draws of the leaf's majority-label arm; for a
      control-recommended leaf they use the contrast with the largest mean
      (the strongest challenger). The interval comes from
      ``credible_interval`` with the posterior's ``calibration`` — raw
      10th/90th percentiles, or the calibrated remap when a ``Calibration``
      artifact is attached.
    - ``stability_score`` is bootstrap-replicability: the share of
      ``n_bootstrap`` row-resampled refits in which some leaf has Jaccard
      overlap >= 0.5 with the segment's member set.

    Args:
        posterior: One of the three posterior result types, carrying
            observed data (raises otherwise).
        max_depth: Maximum tree depth.
        min_segment_share: Minimum fraction of visitors per leaf (passed
            as sklearn ``min_weight_fraction_leaf``).
        n_bootstrap: Number of bootstrap refits behind ``stability_score``;
            ``0`` skips stability — every score is NaN and a
            ``UserWarning`` is emitted.
        bootstrap_seed: Seed for the bootstrap resampling RNG.

    Returns:
        :class:`PolicyTreeResult` with one segment per leaf (ordered by
        sklearn leaf id); ``result.observed`` is ``posterior.observed`` by
        identity.

    Raises:
        ValueError: When ``posterior.observed`` is ``None``, or when
            ``n_bootstrap`` is negative.
        TypeError: When *posterior* is not an accepted result type.
        RuntimeError: When an extracted segment rule fails to reproduce its
            leaf's membership (rule extraction is wrong — fail loud).
    """
    if n_bootstrap < 0:
        raise ValueError(
            f"fit_policy_tree: n_bootstrap must be >= 0, got {n_bootstrap} "
            "(0 skips stability; positive values set the bootstrap count)"
        )

    observed = require_observed(posterior)
    fit_arrays = extract_fit_arrays(observed)
    draws = contrast_samples(posterior)  # (n, S, K-1)
    arm_names = (observed.control_name, *observed.treatment_names)
    arm_count = len(arm_names)
    n_visitors = draws.shape[0]

    # One label per visitor: best arm under the posterior-mean contrasts.
    best_arm_labels = best_arms(draws.mean(axis=1))  # (n,)

    # random_state is fixed so sklearn's tie-breaking is deterministic.
    tree = DecisionTreeClassifier(
        max_depth=max_depth,
        min_weight_fraction_leaf=min_segment_share,
        random_state=0,
    )
    tree.fit(fit_arrays.X, best_arm_labels)
    row_leaves = tree.apply(fit_arrays.X)  # (n,) sklearn leaf node ids
    leaf_ids = list(map(int, np.unique(row_leaves)))  # ascending

    # Visitors stacked in variant-list order — the pinned extraction order.
    frames = [variant.visitors for variant in observed.variants]
    visitors = pd.concat(frames, ignore_index=True)
    rules = _extract_leaf_rules(
        tree, fit_arrays.feature_names, visitors, leaf_ids
    )

    # Fail loud unless each decoded rule selects exactly its leaf's rows.
    for leaf_id in leaf_ids:
        selected = apply_rule(visitors, rules[leaf_id]).to_numpy()
        in_leaf = row_leaves == leaf_id
        if np.array_equal(selected, in_leaf):
            continue
        differing = int(np.logical_xor(selected, in_leaf).sum())
        raise RuntimeError(
            f"fit_policy_tree: extracted rule for leaf {leaf_id} selects "
            f"{int(selected.sum())} rows vs the leaf's "
            f"{int(in_leaf.sum())} ({differing} rows differ) — rule "
            f"extraction is wrong (rule: {rules[leaf_id]!r})"
        )

    if n_bootstrap != 0:
        stability_scores = _bootstrap_stability(
            fit_arrays.X,
            best_arm_labels,
            row_leaves,
            leaf_ids,
            max_depth=max_depth,
            min_segment_share=min_segment_share,
            n_bootstrap=n_bootstrap,
            bootstrap_seed=bootstrap_seed,
        )
    else:
        warnings.warn(
            "fit_policy_tree: n_bootstrap=0 — bootstrap-replicability "
            "stability was not computed; every stability_score is NaN.",
            UserWarning,
            stacklevel=2,
        )
        stability_scores = dict.fromkeys(leaf_ids, float("nan"))

    segments: list[DiscoveredSegment] = []
    allocation_map: dict[int, dict[str, float]] = {}
    for leaf_id in leaf_ids:
        in_leaf = row_leaves == leaf_id
        leaf_draws = draws[in_leaf].mean(axis=0)  # (S, K-1)

        win_freqs = arm_win_frequencies(leaf_draws, arm_count)
        win_probabilities = {
            arm: float(freq)
            for arm, freq in zip(arm_names, win_freqs, strict=True)
        }
        floored = iterative_floor_clip(win_freqs, _ALLOCATION_EPSILON)
        allocation_map[leaf_id] = {
            arm: float(weight)
            for arm, weight in zip(arm_names, floored, strict=True)
        }

        recommended = _leaf_majority_label(tree, leaf_id)
        # Label 0 is control: gate on the strongest challenger instead.
        # Any other label k maps to contrast column k - 1.
        gate_column = (
            recommended - 1
            if recommended != 0
            else int(np.argmax(leaf_draws.mean(axis=0)))
        )
        gate_draws = leaf_draws[:, gate_column]

        segment = DiscoveredSegment(
            id=leaf_id,
            rule=rules[leaf_id],
            gate_estimate=float(gate_draws.mean()),
            gate_ci=credible_interval(gate_draws, posterior.calibration),
            population_share=float(in_leaf.sum() / n_visitors),
            stability_score=stability_scores[leaf_id],
            arm_best_probabilities=win_probabilities,
        )
        segments.append(segment)

    return PolicyTreeResult(
        tree=tree,
        segments=segments,
        allocation_map=allocation_map,
        stability_scores=stability_scores,
        observed=observed,
    )
# ---------------------------------------------------------------------------
# Rule extraction (tree paths -> contracts rule clauses)
# ---------------------------------------------------------------------------
@dataclasses.dataclass(frozen=True)
class _FeatureResolution:
    """Lookup tables that translate encoded feature names back to the
    original visitor columns.

    Fields:
        numeric_columns: Names of the original numeric/bool feature
            columns. These pass through ``pd.get_dummies`` unchanged, so
            the encoded name and the column name coincide.
        onehot_map: ``f"{column}_{level}"`` -> ``(column, level)`` for each
            categorical column and each of its observed levels. It is
            built from the actual columns and levels rather than by
            splitting on underscores, since column names and levels may
            both contain underscores.
        ambiguous: One-hot names produced by more than one
            ``(column, level)`` pair; a split on such a name cannot be
            attributed to a single column.
        levels_by_column: Sorted observed levels of each categorical
            column — index 0 is the level
            ``pd.get_dummies(drop_first=True)`` drops.
    """

    # Numeric/bool columns whose encoded name is the column name itself.
    numeric_columns: frozenset[str]
    # Encoded one-hot name -> (original column, level).
    onehot_map: dict[str, tuple[str, str]]
    # One-hot names claimed by several (column, level) pairs.
    ambiguous: frozenset[str]
    # Categorical column -> its sorted observed levels.
    levels_by_column: dict[str, tuple[str, ...]]
def _resolve_features(visitors: pd.DataFrame) -> _FeatureResolution:
    """Classify the visitors frame's feature columns and index their encodings.

    Columns named in ``VISITOR_SCHEMA`` and reserved propensity columns are
    not features and are skipped. Each remaining column is either numeric
    (bool included) or categorical; categorical columns contribute one
    ``f"{column}_{level}"`` entry per observed non-null level.
    """
    reserved = set(VISITOR_SCHEMA.keys())
    numeric: set[str] = set()
    encoded_to_source: dict[str, tuple[str, str]] = {}
    clashing: set[str] = set()
    levels_of: dict[str, tuple[str, ...]] = {}

    for column in visitors.columns:
        if column in reserved or is_reserved_propensity_column(column):
            continue
        series = visitors[column]
        if pd.api.types.is_numeric_dtype(series):
            # bool is numeric here too — consistent with the extraction
            # adapter, where numeric/bool columns pass through get_dummies
            # unchanged.
            numeric.add(column)
            continue

        observed_levels = tuple(sorted(map(str, series.dropna().unique())))
        levels_of[column] = observed_levels
        for level in observed_levels:
            encoded = f"{column}_{level}"
            if encoded not in encoded_to_source:
                encoded_to_source[encoded] = (column, level)
            else:
                # A second (column, level) pair produced the same encoded
                # name; the first mapping is kept and the name is flagged.
                clashing.add(encoded)

    return _FeatureResolution(
        numeric_columns=frozenset(numeric),
        onehot_map=encoded_to_source,
        ambiguous=frozenset(clashing),
        levels_by_column=levels_of,
    )
def _leaf_paths(
    tree: DecisionTreeClassifier,
) -> dict[int, list[tuple[int, float, bool]]]:
    """Map every leaf id to its root-to-leaf path of
    ``(feature, threshold, went_right)`` steps.

    The walk is an explicit-stack depth-first traversal of ``tree.tree_``
    starting at node 0; a node whose ``children_left`` entry is ``-1`` is
    treated as a leaf. The root of a root-only tree maps to ``[]``.
    """
    structure = tree.tree_
    found: dict[int, list[tuple[int, float, bool]]] = {}
    pending: list[tuple[int, list[tuple[int, float, bool]]]] = [(0, [])]
    while pending:
        node_id, steps = pending.pop()
        left_child = int(structure.children_left[node_id])
        if left_child != -1:
            right_child = int(structure.children_right[node_id])
            split = (
                int(structure.feature[node_id]),
                float(structure.threshold[node_id]),
            )
            # Left is pushed first, so the right subtree is expanded first.
            pending.append((left_child, steps + [(*split, False)]))
            pending.append((right_child, steps + [(*split, True)]))
        else:
            found[node_id] = steps
    return found
def _extract_leaf_rules(
    tree: DecisionTreeClassifier,
    feature_names: tuple[str, ...],
    visitors: pd.DataFrame,
    leaf_ids: list[int],
) -> dict[int, SegmentRule]:
    """Decode a ``SegmentRule`` for each requested leaf from its root path.

    The tree's paths and the visitors frame's feature resolution are each
    computed once and shared across all leaves.
    """
    paths_by_leaf = _leaf_paths(tree)
    resolution = _resolve_features(visitors)
    rules: dict[int, SegmentRule] = {}
    for leaf_id in leaf_ids:
        rules[leaf_id] = _path_to_rule(
            paths_by_leaf[leaf_id], feature_names, resolution
        )
    return rules
def _path_to_rule(
    path: list[tuple[int, float, bool]],
    feature_names: tuple[str, ...],
    resolution: _FeatureResolution,
) -> SegmentRule:
    """Collapse one root-to-leaf path into a ``SegmentRule``.

    Numeric splits are reduced to the tightest bound per direction: the
    smallest ``lte`` upper bound and the largest ``gt`` lower bound. One-hot
    splits are grouped by their original categorical column: a right
    branch pins ``EqRule(column, level)``; left branches collect excluded
    levels and become an ``InRule`` over the remaining observed levels
    (the ``get_dummies``-dropped first level included). When a column has
    both, only the ``EqRule`` is emitted.

    Raises:
        RuntimeError: When a split feature is an ambiguous one-hot name,
            matches no known numeric or one-hot feature, or the path pins
            one column to two different levels.
    """
    if not path:  # root-only tree: the catch-all rule
        return SegmentRule(description="all visitors", clauses=())

    upper: dict[str, float] = {}  # feature -> tightest "lte" bound
    lower: dict[str, float] = {}  # feature -> tightest "gt" bound
    pinned: dict[str, str] = {}  # column -> level required by a right branch
    ruled_out: dict[str, set[str]] = {}  # column -> levels excluded by left branches

    for feature_idx, threshold, went_right in path:
        name = feature_names[feature_idx]

        if name in resolution.numeric_columns:
            if went_right:  # name > threshold: keep the largest lower bound
                lower[name] = max(lower.get(name, threshold), threshold)
            else:  # name <= threshold: keep the smallest upper bound
                upper[name] = min(upper.get(name, threshold), threshold)
            continue

        if name in resolution.ambiguous:
            raise RuntimeError(
                f"fit_policy_tree: split feature {name!r} resolves to more "
                "than one (column, level) pair of the visitors frame; the "
                "segment rule cannot be extracted unambiguously."
            )
        if name not in resolution.onehot_map:
            raise RuntimeError(
                f"fit_policy_tree: split feature {name!r} matches neither a "
                "numeric visitor column nor a one-hot encoding of a "
                "categorical visitor column."
            )

        column, level = resolution.onehot_map[name]
        if not went_right:  # one-hot == 0 -> column != level
            ruled_out.setdefault(column, set()).add(level)
            continue
        # one-hot == 1 -> column == level
        if pinned.get(column, level) != level:
            raise RuntimeError(
                f"fit_policy_tree: contradictory path requires "
                f"{column!r} == {pinned[column]!r} and == {level!r}."
            )
        pinned[column] = level

    clauses: list[RuleClause] = [
        ComparisonRule(name, "lte", bound) for name, bound in upper.items()
    ]
    clauses += [
        ComparisonRule(name, "gt", bound) for name, bound in lower.items()
    ]
    clauses += [EqRule(column, level) for column, level in pinned.items()]
    # A pinned column gets no InRule: the EqRule wins over any InRule it
    # satisfies. (A path that pins a level it also excludes cannot produce
    # a non-empty sklearn leaf; the membership invariant would catch it.)
    clauses += [
        InRule(
            column,
            tuple(sorted(set(resolution.levels_by_column[column]) - excluded)),
        )
        for column, excluded in ruled_out.items()
        if column not in pinned
    ]

    def canonical_key(clause: RuleClause) -> tuple[str, str, str]:
        return (clause.feature, type(clause).__name__, repr(clause))

    # SegmentRule canonical-sorts clauses itself; sorting here with the same
    # key makes the description read in that canonical order.
    ordered = sorted(clauses, key=canonical_key)
    description = " AND ".join(map(_clause_text, ordered))
    return SegmentRule(description=description, clauses=tuple(ordered))
_COMPARISON_TEXT = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}
def _clause_text(clause: RuleClause) -> str:
"""Short human-readable rendering of one rule clause."""
if isinstance(clause, EqRule):
return f"{clause.feature} == {clause.value}"
if isinstance(clause, InRule):
return f"{clause.feature} in {{{', '.join(clause.values)}}}"
if isinstance(clause, ComparisonRule):
operator = _COMPARISON_TEXT[clause.operator]
return f"{clause.feature} {operator} {clause.threshold:g}"
raise TypeError(
f"fit_policy_tree never produces {type(clause).__name__} clauses"
)
# ---------------------------------------------------------------------------
# Per-leaf recommendation and bootstrap stability
# ---------------------------------------------------------------------------
def _leaf_majority_label(tree: DecisionTreeClassifier, leaf_id: int) -> int:
    """Return the class label with the largest entry in the leaf's
    ``tree_.value`` row (the tree's class prediction for that leaf)."""
    per_class = tree.tree_.value[leaf_id][0]
    winner = int(np.argmax(per_class))
    # argmax yields a position; classes_ maps it back to the actual label.
    return int(tree.classes_[winner])
def _bootstrap_stability(
    X: np.ndarray,
    labels: np.ndarray,
    leaf_of_row: np.ndarray,
    leaf_ids: list[int],
    *,
    max_depth: int,
    min_segment_share: float,
    n_bootstrap: int,
    bootstrap_seed: int,
) -> dict[int, float]:
    """Score how often each original leaf reappears across bootstrap refits.

    Each of the *n_bootstrap* rounds draws row indices with replacement,
    refits a tree with the SAME hyperparameters on the resampled rows, and
    routes the ORIGINAL rows through that refit. An original leaf scores a
    hit for the round when SOME refit leaf's member set has Jaccard
    overlap >= 0.5 with it. The returned score is ``hits / n_bootstrap``.
    """
    rng = np.random.default_rng(bootstrap_seed)
    n_rows = X.shape[0]
    original_members = {
        leaf_id: leaf_of_row == leaf_id for leaf_id in leaf_ids
    }
    hits = dict.fromkeys(leaf_ids, 0)

    for _ in range(n_bootstrap):
        sample = rng.integers(0, n_rows, size=n_rows)
        refit = DecisionTreeClassifier(
            max_depth=max_depth,
            min_weight_fraction_leaf=min_segment_share,
            random_state=0,
        )
        refit.fit(X[sample], labels[sample])
        routed = refit.apply(X)  # original rows through the bootstrap tree
        candidates = [routed == leaf for leaf in np.unique(routed)]
        for leaf_id, members in original_members.items():
            for candidate in candidates:
                if _jaccard(members, candidate) >= 0.5:
                    hits[leaf_id] += 1
                    break  # one matching refit leaf is enough this round

    return {leaf_id: count / n_bootstrap for leaf_id, count in hits.items()}
def _jaccard(a: np.ndarray, b: np.ndarray) -> float:
    """Jaccard overlap ``|a AND b| / |a OR b|`` of two boolean member masks.

    Args:
        a: Boolean membership mask.
        b: Boolean membership mask, combined element-wise with *a*.

    Returns:
        The overlap in ``[0.0, 1.0]``. When both masks are empty the union
        is empty and the overlap is defined as ``0.0`` (no evidence of
        overlap) instead of raising ``ZeroDivisionError``. Tree leaves are
        non-empty, so that case only arises for degenerate inputs.
    """
    union = int(np.logical_or(a, b).sum())
    if union == 0:
        return 0.0
    intersection = int(np.logical_and(a, b).sum())
    return intersection / union