# Source code for pytyche.experiment.cells

"""Cell (assignment cohort) + Policy protocol with shipped routing variants.

A round's traffic is split across cells by weight; each cell routes its
visitors through a ``Policy``. ``assign`` returns the treatment's integer
index in the experiment's treatments universe — shipped policies are
constructed from treatment NAMES and resolve indices via the optional
``treatments`` universe (when omitted, ``UniformPolicy`` treats ``over``
as the universe and ``TreePolicy`` derives it from allocation-map key
order, which the recommendation engine emits in universe order). Custom
policies implement the two protocol members and handle their own mapping.
"""

from __future__ import annotations

import dataclasses
from collections.abc import Sequence
from typing import TYPE_CHECKING, Protocol, runtime_checkable

import numpy as np

if TYPE_CHECKING:
    from sklearn.tree import DecisionTreeClassifier

# Public API exposed by ``from pytyche.experiment.cells import *``; kept
# alphabetically sorted. ``WEIGHT_SUM_TOLERANCE`` is intentionally not listed.
__all__ = [
    "BaselinePolicy",
    "Cell",
    "Policy",
    "TreePolicy",
    "UniformPolicy",
    "validate_cell_weights",
]

#: Cell-weight sums must hit 1.0 within this tolerance (spec-pinned).
#: The same tolerance is applied to each leaf's allocation weights in
#: ``TreePolicy`` and to a round's cell weights in ``validate_cell_weights``.
WEIGHT_SUM_TOLERANCE = 1e-6


@runtime_checkable
class Policy(Protocol):
    """Per-visitor treatment routing rule carried by a :class:`Cell`.

    Any object that exposes ``assign`` and ``describe`` satisfies this
    protocol structurally; no inheritance is required. Because the protocol
    is ``runtime_checkable``, ``isinstance(obj, Policy)`` works at runtime,
    but it only verifies that both members exist, not their signatures.
    """

    def assign(self, features: np.ndarray, rng: np.random.Generator) -> int:
        """Return the integer treatment index chosen for one visitor."""

    def describe(self) -> str:
        """Return a short human-readable summary (embedded in ``Cell`` reprs)."""
@dataclasses.dataclass(frozen=True)
class Cell:
    """One assignment cohort inside a single round.

    Attributes:
        id: Cohort identifier, shown in reprs and validation errors.
        policy: Routing rule applied to every visitor landing in this cell.
        weight: This cell's share of the round's traffic, in ``[0.0, 1.0]``.
            The weights of all cells in a round must add up to 1.0; that
            cross-cell rule is checked separately by
            :func:`validate_cell_weights` whenever a cell list is consumed.

    Raises:
        ValueError: at construction, if ``weight`` lies outside
            ``[0.0, 1.0]`` (NaN also fails the range test).
    """

    id: str
    policy: Policy
    weight: float

    def __post_init__(self) -> None:
        # A chained comparison is False for NaN, so NaN falls through to the raise.
        if 0.0 <= self.weight <= 1.0:
            return
        raise ValueError(
            f"Cell weight must be in [0.0, 1.0], got {self.weight} "
            f"for cell {self.id!r}"
        )

    def __repr__(self) -> str:
        summary = self.policy.describe()
        return f"Cell({self.id!r}, {summary}, weight={self.weight:.3f})"
def validate_cell_weights(cells: Sequence[Cell]) -> None:
    """Ensure a round's cells split its traffic completely.

    Args:
        cells: Every cell of one round.

    Raises:
        ValueError: if the weights' sum is farther than
            ``WEIGHT_SUM_TOLERANCE`` from 1.0. An empty ``cells`` sums to 0
            and is therefore rejected. The message lists the offending sum
            and every cell id.
    """
    total = sum(cell.weight for cell in cells)
    deviation = abs(total - 1.0)
    if deviation > WEIGHT_SUM_TOLERANCE:
        cell_ids = [cell.id for cell in cells]
        raise ValueError(
            f"Cell weights must sum to 1.0 (±{WEIGHT_SUM_TOLERANCE}); got "
            f"sum={total:.6f} across {cell_ids}"
        )
@dataclasses.dataclass(frozen=True)
class BaselinePolicy:
    """Control-only routing: every visitor receives treatment index 0.

    Neither the feature vector nor the random generator is consulted, so
    the result is deterministic.
    """

    def assign(self, features: np.ndarray, rng: np.random.Generator) -> int:
        del features, rng  # control routing ignores its inputs
        return 0

    def describe(self) -> str:
        return "baseline: always control"
@dataclasses.dataclass(frozen=True)
class UniformPolicy:
    """Uniform-random over a set of treatment names.

    ``treatments`` is the experiment's full universe for index resolution;
    when ``None``, ``over`` itself is the universe (the common case — the
    engine's Explore cell and the round-1 default both pass the full active
    list as ``over``).

    Both name collections are frozen to tuples during construction, so any
    finite iterable of names is accepted (list, tuple, generator, NumPy
    string array). A bare ``str`` is rejected, because iterating it would
    silently turn each character into a "treatment name". Every position in
    ``over`` is equally likely, so a name listed twice is drawn twice as
    often.

    Raises:
        ValueError: ``over`` is empty, or names something absent from
            ``treatments``.
        TypeError: ``over`` is a non-empty ``str``, or ``treatments`` is a
            ``str``.
    """

    over: Sequence[str]
    treatments: Sequence[str] | None = None

    def __post_init__(self) -> None:
        # Freeze to a tuple *before* the emptiness test. Truth-testing a
        # multi-element ndarray raises, and a generator is truthy even when
        # it yields nothing.
        over = tuple(self.over)
        if not over:
            raise ValueError("UniformPolicy requires a non-empty `over` list")
        if isinstance(self.over, str):
            raise TypeError(
                "`over` must be a sequence of treatment names, not a bare "
                f"string ({self.over!r})"
            )
        object.__setattr__(self, "over", over)
        if self.treatments is not None:
            if isinstance(self.treatments, str):
                raise TypeError(
                    "`treatments` must be a sequence of treatment names, not "
                    f"a bare string ({self.treatments!r})"
                )
            object.__setattr__(self, "treatments", tuple(self.treatments))
            missing = [name for name in self.over if name not in self.treatments]
            if missing:
                raise ValueError(
                    f"`over` names {missing} are not in the treatments "
                    f"universe {list(self.treatments)}"
                )

    def assign(self, features: np.ndarray, rng: np.random.Generator) -> int:
        """Draw one name from ``over`` uniformly and return its universe index.

        ``features`` is ignored. The draw uses a single
        ``rng.integers(0, len(over))`` call, so a seeded generator reproduces
        the same assignment sequence.
        """
        universe = self.treatments if self.treatments is not None else self.over
        name = self.over[int(rng.integers(0, len(self.over)))]
        return universe.index(name)  # type: ignore[union-attr]

    def describe(self) -> str:
        return f"uniform over {list(self.over)}"
@dataclasses.dataclass(frozen=True)
class TreePolicy:
    """Routes features through a fitted decision tree to a per-leaf allocation.

    ``allocation_map`` maps sklearn leaf ids to per-treatment-name weight
    dicts (each non-negative and summing to 1.0). ``treatments`` is the
    index-resolution universe; when ``None`` it is derived from
    allocation-map key order (first-seen across leaves — the order the
    recommendation engine emits).

    Every weight problem is rejected at construction, so ``assign`` never
    discovers it mid-round inside ``rng.choice``.

    Raises:
        ValueError: empty ``allocation_map``; a leaf whose weights do not
            sum to 1.0 within ``WEIGHT_SUM_TOLERANCE`` (including NaN/inf
            sums); a leaf with a negative weight; or a leaf naming a
            treatment outside ``treatments``.
        TypeError: ``treatments`` is a bare ``str``.
    """

    tree: DecisionTreeClassifier
    allocation_map: dict[int, dict[str, float]]
    treatments: Sequence[str] | None = None

    def __post_init__(self) -> None:
        if not self.allocation_map:
            raise ValueError("TreePolicy requires a non-empty allocation_map")
        for leaf_id, weights in self.allocation_map.items():
            total = sum(weights.values())
            # Written as `not (... <= ...)` so that a NaN total (a NaN weight,
            # or inf + -inf) is rejected here rather than slipping through.
            if not abs(total - 1.0) <= WEIGHT_SUM_TOLERANCE:
                raise ValueError(
                    f"allocation_map[{leaf_id}] weights must sum to 1.0 "
                    f"(±{WEIGHT_SUM_TOLERANCE}); got sum={total:.6f}"
                )
            # Weights such as {a: 1.5, b: -0.5} sum to 1 but are not a
            # probability distribution; rng.choice would reject them later.
            negative = {name: w for name, w in weights.items() if w < 0}
            if negative:
                raise ValueError(
                    f"allocation_map[{leaf_id}] weights must be non-negative; "
                    f"got {negative}"
                )
        if self.treatments is not None:
            if isinstance(self.treatments, str):
                raise TypeError(
                    "`treatments` must be a sequence of treatment names, not "
                    f"a bare string ({self.treatments!r})"
                )
            object.__setattr__(self, "treatments", tuple(self.treatments))
        universe = self._universe()
        for leaf_id, weights in self.allocation_map.items():
            unknown = [name for name in weights if name not in universe]
            if unknown:
                raise ValueError(
                    f"allocation_map[{leaf_id}] names {unknown} are not in "
                    f"the treatments universe {list(universe)}"
                )

    def _universe(self) -> tuple[str, ...]:
        """Return ``treatments``, or every allocation key in first-seen order."""
        if self.treatments is not None:
            return tuple(self.treatments)
        return tuple(
            dict.fromkeys(
                name for weights in self.allocation_map.values() for name in weights
            )
        )

    def assign(self, features: np.ndarray, rng: np.random.Generator) -> int:
        """Route one visitor and return the chosen treatment's universe index.

        ``features`` may be any array-like. It is flattened into a single row
        before ``tree.apply``, whose first output is used as the leaf id. The
        leaf's weights are renormalised before ``rng.choice``, because they
        are only guaranteed to sum to 1.0 within ``WEIGHT_SUM_TOLERANCE``.

        Raises:
            KeyError: the tree routed to a leaf id with no allocation entry.
        """
        row = np.asarray(features).reshape(1, -1)
        leaf_id = int(self.tree.apply(row)[0])
        try:
            weights = self.allocation_map[leaf_id]
        except KeyError:
            raise KeyError(
                f"tree routed to leaf {leaf_id}, which has no allocation_map "
                f"entry (known leaves: {sorted(self.allocation_map)})"
            ) from None
        names = list(weights)
        probabilities = np.array([weights[name] for name in names], dtype=float)
        chosen = names[int(rng.choice(len(names), p=probabilities / probabilities.sum()))]
        return self._universe().index(chosen)

    def describe(self) -> str:
        return (
            f"policy tree routing over {len(self.allocation_map)} segments "
            f"with per-leaf Thompson allocation"
        )