# Source code for pytyche.analysis._policy_tree

"""Policy-tree segmentation: ``fit_policy_tree`` and its ``PolicyTreeResult``.

The tree data model is sklearn-direct: ``PolicyTreeResult.tree`` is a
``sklearn.tree.DecisionTreeClassifier``.  A future change may swap this for
a pytyche wrapper preserving the same predict / decision_path surface;
until tree persistence is needed, sklearn-direct holds.

:func:`fit_policy_tree` is the implementation behind
``posterior.fit_policy_tree(...)`` on the three posterior result types.
Labels are the shared best-arm rule applied to per-visitor posterior-MEAN
contrast vectors; each leaf becomes a :class:`~pytyche.contracts.DiscoveredSegment`
whose rule reproduces the leaf's membership EXACTLY through
:func:`pytyche.summarize.apply_rule` (a mismatch raises ``RuntimeError`` —
a silently-wrong rule must never ship).
"""

from __future__ import annotations

import dataclasses
import warnings

import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

from pytyche._internal.extraction import extract_fit_arrays
from pytyche.analysis._contrasts import (
    AnyBCFResult,
    best_arms,
    contrast_samples,
    require_observed,
)
from pytyche.analysis._intervals import credible_interval
from pytyche.analysis._thompson import arm_win_frequencies, iterative_floor_clip
from pytyche.contracts import (
    VISITOR_SCHEMA,
    ComparisonRule,
    DiscoveredSegment,
    EqRule,
    InRule,
    ObservedExperimentData,
    RuleClause,
    SegmentRule,
    is_reserved_propensity_column,
)
from pytyche.summarize import apply_rule

#: Safety-net floor applied inside each leaf's Thompson allocation (passed to
#: ``iterative_floor_clip``); it mirrors the default of the ``epsilon``
#: keyword that ``thompson_allocation`` exposes.
_ALLOCATION_EPSILON: float = 0.02


@dataclasses.dataclass(frozen=True)
class PolicyTreeResult:
    """Policy-tree segmentation output.

    Produced by ``posterior.fit_policy_tree(...)``. The tree partitions
    feature space; treatment metadata (allocations, stability) is per-leaf
    data layered on top, keyed by sklearn leaf id.

    Fields:
        tree: The fitted policy tree partitioning feature space (an sklearn
            ``DecisionTreeClassifier``).
        segments: One ``DiscoveredSegment`` per leaf, ordered by leaf id —
            ``segments[i].id == sorted(allocation_map.keys())[i]``.
        allocation_map: Leaf id → per-variant Thompson weight dict; each
            leaf's weights sum to 1.0 within 1e-6 (a NaN total is rejected).
        stability_scores: Leaf id → bootstrap-replicability score (the
            fraction of bootstrap tree refits in which some leaf has Jaccard
            overlap >= 0.5 with the leaf's member set). ``fit_policy_tree``
            fills every score with NaN when stability is skipped
            (``n_bootstrap=0``). Keyed exactly by the allocation_map leaves.
        observed: The observed data the underlying posterior was fit on
            (shared reference to ``posterior.observed``; no re-cloning).

    Raises:
        ValueError: From ``__post_init__`` when any of the cross-field
            invariants above does not hold.
    """

    tree: DecisionTreeClassifier
    segments: list[DiscoveredSegment]
    allocation_map: dict[int, dict[str, float]]
    stability_scores: dict[int, float]
    observed: ObservedExperimentData

    def __repr__(self) -> str:
        lines = [
            f"PolicyTreeResult — {len(self.segments)} segments"
            f" over {self.observed.experiment_id}"
        ]
        lines.extend(f" {seg._row()}" for seg in self.segments)
        return "\n".join(lines)

    def __post_init__(self) -> None:
        """Validate the cross-field invariants documented on the class."""
        for leaf_id, weights in self.allocation_map.items():
            total = sum(weights.values())
            # Phrased as ``not (... <= tol)`` on purpose: every comparison
            # with NaN is False, so ``abs(total - 1.0) > tol`` would let NaN
            # weights slip through validation silently.
            if not abs(total - 1.0) <= 1e-6:
                raise ValueError(
                    f"PolicyTreeResult: allocation_map[{leaf_id}] weights "
                    f"must sum to 1.0 within 1e-6, got {total}"
                )
        segment_ids = [s.id for s in self.segments]
        leaf_ids = sorted(self.allocation_map.keys())
        if segment_ids != leaf_ids:
            raise ValueError(
                f"PolicyTreeResult: segments ids {segment_ids} must equal "
                f"the sorted allocation_map leaf ids {leaf_ids}"
            )
        if set(self.stability_scores.keys()) != set(self.allocation_map.keys()):
            raise ValueError(
                f"PolicyTreeResult: stability_scores keys "
                f"{sorted(self.stability_scores.keys())} must equal the "
                f"allocation_map leaf ids {leaf_ids}"
            )
def fit_policy_tree(
    posterior: AnyBCFResult,
    *,
    max_depth: int = 3,
    min_segment_share: float = 0.10,
    n_bootstrap: int = 50,
    bootstrap_seed: int = 0,
) -> PolicyTreeResult:
    """Discover interpretable segments from the posterior's per-visitor
    treatment effects, by fitting a shallow decision tree.

    Each visitor is labeled with the arm the posterior expects to be best
    for them: their posterior-mean contrast vector, reduced to the
    largest-lift treatment — or control, when no lift is positive. A
    multiclass ``DecisionTreeClassifier`` is then fit on the visitors'
    features (the same feature encoding the fit entry points use), so each
    leaf groups visitors the model treats alike.

    Each leaf becomes a :class:`~pytyche.contracts.DiscoveredSegment`:

    - ``rule`` reproduces the leaf's membership exactly through
      :func:`pytyche.summarize.apply_rule` (verified; mismatch raises
      ``RuntimeError``). A root-only tree yields one catch-all segment with
      ``clauses == ()``.
    - ``arm_best_probabilities`` are per-draw best-arm win frequencies over
      the leaf's per-draw mean contrasts — the same code path
      ``thompson_allocation`` uses. The leaf's ``allocation_map`` entry is
      those frequencies floor-clipped at ``_ALLOCATION_EPSILON``.
    - The leaf's recommended arm is its plurality training label (the
      tree's class prediction for that leaf).
    - ``gate_estimate`` / ``gate_ci`` (mean and the 80% interval) are
      computed on the recommended arm's in-segment contrast draws; for
      control-recommended leaves, on the treatment with the largest
      in-segment mean contrast (the strongest challenger). Because the
      recommendation is a vote over per-visitor labels rather than the sign
      of the leaf-mean contrast, a control leaf's challenger estimate is NOT
      guaranteed to be non-positive: a minority of visitors with large lifts
      can outweigh a plurality with small losses. The interval is the raw
      10th/90th percentiles, or the calibrated remap when the posterior
      carries an attached ``Calibration`` artifact.
    - ``stability_score`` is bootstrap-replicability: the fraction of
      ``n_bootstrap`` row-resampled tree refits in which some leaf has
      Jaccard overlap >= 0.5 with the segment's member set.

    Args:
        posterior: One of the three posterior result types, carrying
            observed data (raises otherwise).
        max_depth: Maximum tree depth.
        min_segment_share: Minimum fraction of visitors per leaf (sklearn
            ``min_weight_fraction_leaf``).
        n_bootstrap: Bootstrap refits behind ``stability_score``; ``0``
            skips stability — every score is NaN and a ``UserWarning`` is
            emitted.
        bootstrap_seed: Seed for the bootstrap resampling RNG.

    Returns:
        :class:`PolicyTreeResult` with one segment per leaf (ordered by
        sklearn leaf id); ``result.observed`` is ``posterior.observed`` by
        identity.

    Raises:
        ValueError: When ``posterior.observed`` is ``None``, or when
            ``n_bootstrap`` is negative.
        TypeError: When *posterior* is not an accepted result type.
        RuntimeError: When an extracted segment rule fails to reproduce its
            leaf's membership (rule extraction is wrong — fail loud).
    """
    if n_bootstrap < 0:
        raise ValueError(
            f"fit_policy_tree: n_bootstrap must be >= 0, got {n_bootstrap} "
            "(0 skips stability; positive values set the bootstrap count)"
        )
    observed = require_observed(posterior)
    arrays = extract_fit_arrays(observed)
    contrasts = contrast_samples(posterior)  # (n, S, K-1)
    variant_names = (observed.control_name, *observed.treatment_names)
    n_arms = len(variant_names)
    n = contrasts.shape[0]

    # Labels: shared best-arm rule on per-visitor posterior-mean contrasts.
    labels = best_arms(contrasts.mean(axis=1))  # (n,)

    # random_state pinned for determinism (sklearn tie-breaking).
    tree = DecisionTreeClassifier(
        max_depth=max_depth,
        min_weight_fraction_leaf=min_segment_share,
        random_state=0,
    )
    tree.fit(arrays.X, labels)
    leaf_of_row = tree.apply(arrays.X)  # (n,) sklearn leaf node ids
    leaf_ids = [int(leaf) for leaf in np.unique(leaf_of_row)]  # ascending
    # Each membership mask is needed by both the invariant check and the
    # per-leaf summaries below; build it once.
    leaf_masks = {leaf_id: leaf_of_row == leaf_id for leaf_id in leaf_ids}

    # The pinned extraction row order: variant-list concat.
    visitors = pd.concat(
        [v.visitors for v in observed.variants], ignore_index=True
    )
    rules = _extract_leaf_rules(tree, arrays.feature_names, visitors, leaf_ids)

    # Invariant: every rule reproduces its leaf's membership exactly.
    for leaf_id in leaf_ids:
        rule_mask = apply_rule(visitors, rules[leaf_id]).to_numpy()
        leaf_mask = leaf_masks[leaf_id]
        if not np.array_equal(rule_mask, leaf_mask):
            mismatched = int(np.logical_xor(rule_mask, leaf_mask).sum())
            raise RuntimeError(
                f"fit_policy_tree: extracted rule for leaf {leaf_id} selects "
                f"{int(rule_mask.sum())} rows vs the leaf's "
                f"{int(leaf_mask.sum())} ({mismatched} rows differ) — rule "
                f"extraction is wrong (rule: {rules[leaf_id]!r})"
            )

    if n_bootstrap == 0:
        warnings.warn(
            "fit_policy_tree: n_bootstrap=0 — bootstrap-replicability "
            "stability was not computed; every stability_score is NaN.",
            UserWarning,
            stacklevel=2,
        )
        stability_scores = dict.fromkeys(leaf_ids, float("nan"))
    else:
        stability_scores = _bootstrap_stability(
            arrays.X,
            labels,
            leaf_of_row,
            leaf_ids,
            max_depth=max_depth,
            min_segment_share=min_segment_share,
            n_bootstrap=n_bootstrap,
            bootstrap_seed=bootstrap_seed,
        )

    segments: list[DiscoveredSegment] = []
    allocation_map: dict[int, dict[str, float]] = {}
    for leaf_id in leaf_ids:
        members = leaf_masks[leaf_id]
        seg_draws = contrasts[members].mean(axis=0)  # (S, K-1)
        frequencies = arm_win_frequencies(seg_draws, n_arms)
        arm_best_probabilities = {
            name: float(f)
            for name, f in zip(variant_names, frequencies, strict=True)
        }
        weights = iterative_floor_clip(frequencies, _ALLOCATION_EPSILON)
        allocation_map[leaf_id] = {
            name: float(w)
            for name, w in zip(variant_names, weights, strict=True)
        }

        recommended = _leaf_majority_label(tree, leaf_id)
        if recommended == 0:
            # Control-recommended: gate on the strongest challenger. Its
            # in-segment mean may still be positive (plurality vote, not
            # leaf-mean sign, decides the recommendation).
            challenger = int(np.argmax(seg_draws.mean(axis=0)))
            gate_draws = seg_draws[:, challenger]
        else:
            # Label k (k >= 1) is treatment k, i.e. contrast column k - 1.
            gate_draws = seg_draws[:, recommended - 1]

        segments.append(
            DiscoveredSegment(
                id=leaf_id,
                rule=rules[leaf_id],
                gate_estimate=float(gate_draws.mean()),
                gate_ci=credible_interval(
                    gate_draws, posterior.calibration
                ),
                population_share=float(members.sum() / n),
                stability_score=stability_scores[leaf_id],
                arm_best_probabilities=arm_best_probabilities,
            )
        )

    return PolicyTreeResult(
        tree=tree,
        segments=segments,
        allocation_map=allocation_map,
        stability_scores=stability_scores,
        observed=observed,
    )
# ---------------------------------------------------------------------------
# Rule extraction (tree paths -> contracts rule clauses)
# ---------------------------------------------------------------------------


@dataclasses.dataclass(frozen=True)
class _FeatureResolution:
    """How encoded feature names map back to original visitor columns.

    Fields:
        numeric_columns: Original numeric/bool feature columns — these pass
            through ``pd.get_dummies`` unchanged, so the encoded name IS the
            column name.
        onehot_map: Encoded one-hot name (``f"{column}_{level}"``) ->
            ``(column, level)``. Built from the actual visitor columns and
            their observed levels — never by splitting on underscores, which
            both column names and levels may themselves contain. When two
            pairs collide, the first claimant keeps the entry.
        ambiguous: One-hot names claimed by more than one (column, level)
            pair; splitting on one of these cannot be resolved.
        levels_by_column: Observed (non-null, ``str``-converted) levels per
            categorical column, sorted — index 0 is the level
            ``pd.get_dummies(drop_first=True)`` drops.
            NOTE(review): the sort is over ``str(value)``, which may differ
            from get_dummies' own ordering (e.g. categorical dtypes with a
            custom category order) — confirm. ``_path_to_rule`` only uses
            these levels as a set, so rule correctness does not depend on it.
    """

    numeric_columns: frozenset[str]
    onehot_map: dict[str, tuple[str, str]]
    ambiguous: frozenset[str]
    levels_by_column: dict[str, tuple[str, ...]]


def _resolve_features(visitors: pd.DataFrame) -> _FeatureResolution:
    """Build the encoded-name resolution for the visitors frame's features.

    Feature columns are the visitors columns that are neither
    ``VISITOR_SCHEMA`` columns nor reserved propensity columns. Numeric
    (including bool) columns resolve to themselves; every other feature
    column contributes one candidate one-hot name per observed non-null
    level.

    NOTE(review): a one-hot name equal to a numeric column's name is not
    marked ambiguous here; ``_path_to_rule`` resolves such a name as
    numeric. A wrong resolution would most likely surface as the caller's
    membership-invariant ``RuntimeError`` — confirm this is acceptable.
    """
    schema_columns = set(VISITOR_SCHEMA.keys())
    feature_columns = [
        column
        for column in visitors.columns
        if column not in schema_columns
        and not is_reserved_propensity_column(column)
    ]
    numeric_columns: set[str] = set()
    onehot_map: dict[str, tuple[str, str]] = {}
    ambiguous: set[str] = set()
    levels_by_column: dict[str, tuple[str, ...]] = {}
    for column in feature_columns:
        if pd.api.types.is_numeric_dtype(visitors[column]):
            # bool counts as numeric — matches the extraction adapter, where
            # numeric/bool columns pass through get_dummies unchanged.
            numeric_columns.add(column)
            continue
        levels = tuple(
            sorted(str(v) for v in visitors[column].dropna().unique())
        )
        levels_by_column[column] = levels
        for level in levels:
            name = f"{column}_{level}"
            if name in onehot_map:
                # Collision: keep the first claimant's entry but remember the
                # name so a split on it fails loudly instead of guessing.
                ambiguous.add(name)
            else:
                onehot_map[name] = (column, level)
    return _FeatureResolution(
        numeric_columns=frozenset(numeric_columns),
        onehot_map=onehot_map,
        ambiguous=frozenset(ambiguous),
        levels_by_column=levels_by_column,
    )


def _leaf_paths(
    tree: DecisionTreeClassifier,
) -> dict[int, list[tuple[int, float, bool]]]:
    """Root-to-leaf decision paths: leaf id -> [(feature, threshold, went_right)].

    Walks the fitted ``tree_`` arrays depth-first from the root (node 0)
    with an explicit stack. A node is a leaf when its ``children_left`` entry
    is ``-1``. Per sklearn's convention the left child receives rows with
    ``X[:, feature] <= threshold`` and the right child rows with ``>``;
    ``went_right`` records which branch the path took. A root-only tree
    yields ``{0: []}``.
    """
    t = tree.tree_
    paths: dict[int, list[tuple[int, float, bool]]] = {}
    stack: list[tuple[int, list[tuple[int, float, bool]]]] = [(0, [])]
    while stack:
        node, path = stack.pop()
        left = int(t.children_left[node])
        if left == -1:  # leaf
            paths[node] = path
            continue
        right = int(t.children_right[node])
        feature = int(t.feature[node])
        threshold = float(t.threshold[node])
        # Each child gets its own copy of the path so siblings never share
        # (and mutate) the same list.
        stack.append((left, [*path, (feature, threshold, False)]))
        stack.append((right, [*path, (feature, threshold, True)]))
    return paths


def _extract_leaf_rules(
    tree: DecisionTreeClassifier,
    feature_names: tuple[str, ...],
    visitors: pd.DataFrame,
    leaf_ids: list[int],
) -> dict[int, SegmentRule]:
    """One ``SegmentRule`` per leaf, decoded from the leaf's root path.

    ``feature_names`` are indexed by the tree's ``feature`` ids; ``visitors``
    supplies the original columns used to decode one-hot names. Raises
    ``RuntimeError`` (from ``_path_to_rule``) when a split feature cannot be
    resolved unambiguously or a path is contradictory.
    """
    paths = _leaf_paths(tree)
    resolution = _resolve_features(visitors)
    return {
        leaf_id: _path_to_rule(paths[leaf_id], feature_names, resolution)
        for leaf_id in leaf_ids
    }


def _path_to_rule(
    path: list[tuple[int, float, bool]],
    feature_names: tuple[str, ...],
    resolution: _FeatureResolution,
) -> SegmentRule:
    """Fold one root-to-leaf path into a ``SegmentRule``.

    Numeric splits keep the tightest bound per direction (min of the
    ``lte`` uppers, max of the ``gt`` lowers). One-hot splits intersect per
    original categorical column: a right branch pins
    ``EqRule(column, level)`` (which wins over any ``InRule`` it satisfies);
    left branches accumulate excluded levels, yielding ``InRule`` over the
    remaining observed levels INCLUDING the ``get_dummies``-dropped first
    level.

    Raises:
        RuntimeError: When a split feature is ambiguous, matches neither a
            numeric column nor a known one-hot name, or the path pins two
            different levels of the same column.
    """
    if not path:  # root-only tree: the catch-all rule
        return SegmentRule(description="all visitors", clauses=())
    lte_bounds: dict[str, float] = {}
    gt_bounds: dict[str, float] = {}
    eq_level: dict[str, str] = {}
    excluded_levels: dict[str, set[str]] = {}
    for feature_idx, threshold, went_right in path:
        name = feature_names[feature_idx]
        # Numeric resolution is checked first: a numeric column's encoded
        # name is the column name itself.
        if name in resolution.numeric_columns:
            if went_right:
                # name > threshold — tightest lower bound is max
                previous = gt_bounds.get(name)
                gt_bounds[name] = (
                    threshold if previous is None else max(previous, threshold)
                )
            else:
                # name <= threshold — tightest upper bound is min
                previous = lte_bounds.get(name)
                lte_bounds[name] = (
                    threshold if previous is None else min(previous, threshold)
                )
            continue
        if name in resolution.ambiguous:
            raise RuntimeError(
                f"fit_policy_tree: split feature {name!r} resolves to more "
                "than one (column, level) pair of the visitors frame; the "
                "segment rule cannot be extracted unambiguously."
            )
        if name not in resolution.onehot_map:
            raise RuntimeError(
                f"fit_policy_tree: split feature {name!r} matches neither a "
                "numeric visitor column nor a one-hot encoding of a "
                "categorical visitor column."
            )
        column, level = resolution.onehot_map[name]
        if went_right:  # one-hot == 1 -> column == level
            if column in eq_level and eq_level[column] != level:
                raise RuntimeError(
                    f"fit_policy_tree: contradictory path requires "
                    f"{column!r} == {eq_level[column]!r} and == {level!r}."
                )
            eq_level[column] = level
        else:  # one-hot == 0 -> column != level
            excluded_levels.setdefault(column, set()).add(level)
    clauses: list[RuleClause] = []
    for name, bound in lte_bounds.items():
        clauses.append(ComparisonRule(name, "lte", bound))
    for name, bound in gt_bounds.items():
        clauses.append(ComparisonRule(name, "gt", bound))
    for column, level in eq_level.items():
        clauses.append(EqRule(column, level))
    for column, excluded in excluded_levels.items():
        if column in eq_level:
            # EqRule wins over any InRule it satisfies. (A contradictory
            # EqRule-on-an-excluded-level path cannot produce a non-empty
            # sklearn leaf; the membership invariant would catch it.)
            continue
        # NOTE(review): rows whose value is null in this column are not in
        # levels_by_column, so they are outside the InRule's allowed set —
        # confirm how the encoder and apply_rule treat nulls.
        allowed = tuple(
            sorted(set(resolution.levels_by_column[column]) - excluded)
        )
        clauses.append(InRule(column, allowed))
    # SegmentRule canonical-sorts clauses itself; pre-sort with the same key
    # so the description reads in the canonical clause order.
    ordered = sorted(
        clauses, key=lambda c: (c.feature, type(c).__name__, repr(c))
    )
    description = " AND ".join(_clause_text(clause) for clause in ordered)
    return SegmentRule(description=description, clauses=tuple(ordered))


# Operator symbol used when rendering ``ComparisonRule`` clauses as text.
_COMPARISON_TEXT = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}


def _clause_text(clause: RuleClause) -> str:
    """Short human-readable rendering of one rule clause.

    NOTE(review): thresholds are formatted with ``:g`` (six significant
    digits), so for large or finely-split values the description can be
    coarser than the clause's exact threshold; the clause itself is exact.

    Raises:
        TypeError: For clause types this module never produces.
    """
    if isinstance(clause, EqRule):
        return f"{clause.feature} == {clause.value}"
    if isinstance(clause, InRule):
        return f"{clause.feature} in {{{', '.join(clause.values)}}}"
    if isinstance(clause, ComparisonRule):
        operator = _COMPARISON_TEXT[clause.operator]
        return f"{clause.feature} {operator} {clause.threshold:g}"
    raise TypeError(
        f"fit_policy_tree never produces {type(clause).__name__} clauses"
    )


# ---------------------------------------------------------------------------
# Per-leaf recommendation and bootstrap stability
# ---------------------------------------------------------------------------


def _leaf_majority_label(tree: DecisionTreeClassifier, leaf_id: int) -> int:
    """The leaf's majority training label (the tree's class prediction).

    This is a plurality: ``np.argmax`` picks the class with the largest
    entry in the leaf's ``tree_.value`` row, and ties resolve to the first
    (lowest-index) class in ``tree.classes_``.
    NOTE(review): depending on the sklearn version, ``tree_.value`` holds
    raw class counts or per-node class fractions; the argmax is the same
    either way.
    """
    class_counts = tree.tree_.value[leaf_id, 0, :]
    return int(tree.classes_[int(np.argmax(class_counts))])


def _bootstrap_stability(
    X: np.ndarray,
    labels: np.ndarray,
    leaf_of_row: np.ndarray,
    leaf_ids: list[int],
    *,
    max_depth: int,
    min_segment_share: float,
    n_bootstrap: int,
    bootstrap_seed: int,
) -> dict[int, float]:
    """Bootstrap-replicability stability score per original leaf.

    For each of *n_bootstrap* row-resamples (with replacement): refit a tree
    with the SAME hyperparameters, route the ORIGINAL rows through it, and
    count a success for an original leaf when SOME bootstrap leaf's member
    set has Jaccard overlap >= 0.5 with it. The score is
    ``successes / n_bootstrap``, so *n_bootstrap* must be positive
    (``fit_policy_tree`` handles ``0`` itself and never calls this then).
    """
    rng = np.random.default_rng(bootstrap_seed)
    n = X.shape[0]
    member_masks = {leaf_id: leaf_of_row == leaf_id for leaf_id in leaf_ids}
    successes = dict.fromkeys(leaf_ids, 0)
    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)
        # Same hyperparameters and pinned random_state as the main fit, so
        # only the resampled rows differ between refits.
        refit = DecisionTreeClassifier(
            max_depth=max_depth,
            min_weight_fraction_leaf=min_segment_share,
            random_state=0,
        )
        refit.fit(X[idx], labels[idx])
        refit_leaf_of_row = refit.apply(X)  # original rows, bootstrap tree
        # Only leaves that some original row reaches, so every mask is
        # non-empty (which keeps _jaccard's union positive).
        refit_masks = [
            refit_leaf_of_row == leaf for leaf in np.unique(refit_leaf_of_row)
        ]
        for leaf_id, mask in member_masks.items():
            if any(_jaccard(mask, m) >= 0.5 for m in refit_masks):
                successes[leaf_id] += 1
    return {
        leaf_id: successes[leaf_id] / n_bootstrap for leaf_id in leaf_ids
    }


def _jaccard(a: np.ndarray, b: np.ndarray) -> float:
    """Jaccard overlap of two boolean member masks (leaves are non-empty).

    Returns ``|a & b| / |a | b|``; two all-False masks would divide by zero,
    which the callers' non-empty masks rule out.
    """
    intersection = int(np.logical_and(a, b).sum())
    union = int(np.logical_or(a, b).sum())
    return intersection / union