Source code for pytyche.experiment.cells
"""Cell (assignment cohort) + Policy protocol with shipped routing variants.
A round's traffic is split across cells by weight; each cell routes its
visitors through a ``Policy``. ``assign`` returns the treatment's integer
index in the experiment's treatments universe — shipped policies are
constructed from treatment NAMES and resolve indices via the optional
``treatments`` universe (when omitted, ``UniformPolicy`` treats ``over``
as the universe and ``TreePolicy`` derives it from allocation-map key
order, which the recommendation engine emits in universe order). Custom
policies implement the two protocol members and handle their own mapping.
"""
from __future__ import annotations
import dataclasses
from collections.abc import Sequence
from typing import TYPE_CHECKING, Protocol, runtime_checkable
import numpy as np
if TYPE_CHECKING:
from sklearn.tree import DecisionTreeClassifier
#: Names exported by ``from ... import *`` (kept alphabetical).
__all__ = [
    "BaselinePolicy",
    "Cell",
    "Policy",
    "TreePolicy",
    "UniformPolicy",
    "validate_cell_weights",
]

#: Spec-pinned absolute tolerance for every "weights sum to 1.0" check
#: (a round's cell weights and each TreePolicy leaf's allocation weights).
WEIGHT_SUM_TOLERANCE: float = 1.0e-6
[docs]
@runtime_checkable
class Policy(Protocol):
    """Structural interface for routing one visitor to a treatment.

    Every :class:`Cell` carries a policy. Any object providing the two
    members below satisfies the protocol; because it is
    ``@runtime_checkable``, ``isinstance(obj, Policy)`` checks only that
    both members are present.
    """

    def assign(self, features: np.ndarray, rng: np.random.Generator) -> int:
        """Return the chosen treatment's integer index in the treatments universe."""
        ...

    def describe(self) -> str:
        """Return a short human-readable summary of the routing rule."""
        ...
[docs]
@dataclasses.dataclass(frozen=True)
class Cell:
    """One assignment cohort inside a single round.

    ``id`` names the cohort, ``policy`` routes its visitors, and ``weight``
    is the cohort's share of the round's traffic. Each weight must lie in
    ``[0.0, 1.0]`` (checked on construction). The weights of all cells in a
    round must also sum to 1.0; that is checked separately by
    :func:`validate_cell_weights` whenever a cell list is consumed.
    """

    id: str
    policy: Policy
    weight: float

    def __post_init__(self) -> None:
        share = self.weight
        # A chained comparison is used deliberately: NaN fails both bounds,
        # so it is rejected along with out-of-range values.
        in_range = 0.0 <= share <= 1.0
        if in_range:
            return
        raise ValueError(
            f"Cell weight must be in [0.0, 1.0], got {self.weight} "
            f"for cell {self.id!r}"
        )

    def __repr__(self) -> str:
        summary = self.policy.describe()
        return f"Cell({self.id!r}, {summary}, weight={self.weight:.3f})"
[docs]
def validate_cell_weights(cells: Sequence[Cell]) -> None:
    """Check that a round's cell weights add up to 1.0.

    The total may deviate from 1.0 by at most ``WEIGHT_SUM_TOLERANCE``.

    Raises:
        ValueError: if the weights miss 1.0 by more than the tolerance. The
            message reports the sum and the ids of all cells.
    """
    weights_total = sum(member.weight for member in cells)
    off_by = abs(weights_total - 1.0)
    if off_by > WEIGHT_SUM_TOLERANCE:
        cell_ids = [member.id for member in cells]
        raise ValueError(
            f"Cell weights must sum to 1.0 (±{WEIGHT_SUM_TOLERANCE}); got "
            f"sum={weights_total:.6f} across {cell_ids}"
        )
[docs]
@dataclasses.dataclass(frozen=True)
class BaselinePolicy:
    """Control-only routing: every visitor is sent to treatment index 0.

    ``features`` and ``rng`` are accepted to satisfy the policy interface
    but are never consulted, so no random draws are consumed.
    """

    def assign(self, features: np.ndarray, rng: np.random.Generator) -> int:
        del features, rng  # unused: control routing ignores the visitor
        return 0

    def describe(self) -> str:
        label = "baseline: always control"
        return label
[docs]
@dataclasses.dataclass(frozen=True)
class TreePolicy:
    """Routes features through a fitted decision tree to a per-leaf allocation.

    ``allocation_map`` maps sklearn leaf ids (as returned by ``tree.apply``)
    to per-treatment-name weight dicts. Every weight must be finite and
    non-negative, and each dict must sum to 1.0 (±``WEIGHT_SUM_TOLERANCE``).

    ``treatments`` is the index-resolution universe and must not contain
    duplicate names. When ``None`` it is derived from allocation-map key
    order (first-seen across leaves — the order the recommendation engine
    emits).

    Raises:
        ValueError: on construction, if ``allocation_map`` is empty, a leaf
            has a negative or non-finite weight, a leaf's weights do not sum
            to 1.0, ``treatments`` contains duplicates, or a leaf names a
            treatment outside the universe.
    """

    tree: DecisionTreeClassifier
    allocation_map: dict[int, dict[str, float]]
    treatments: Sequence[str] | None = None

    def __post_init__(self) -> None:
        if not self.allocation_map:
            raise ValueError("TreePolicy requires a non-empty allocation_map")
        for leaf_id, weights in self.allocation_map.items():
            # Validate each weight first. A NaN makes the sum NaN, which
            # passes the tolerance comparison below, and negative weights can
            # still sum to 1.0. Either would otherwise surface only later, as
            # an error from ``rng.choice`` inside ``assign``.
            invalid = [
                name
                for name, weight in weights.items()
                if not (np.isfinite(weight) and weight >= 0.0)
            ]
            if invalid:
                raise ValueError(
                    f"allocation_map[{leaf_id}] weights must be finite and "
                    f"non-negative; invalid entries for {invalid}"
                )
            total = sum(weights.values())
            if abs(total - 1.0) > WEIGHT_SUM_TOLERANCE:
                raise ValueError(
                    f"allocation_map[{leaf_id}] weights must sum to 1.0 "
                    f"(±{WEIGHT_SUM_TOLERANCE}); got sum={total:.6f}"
                )
        if self.treatments is not None:
            treatments = tuple(self.treatments)
            counts: dict[str, int] = {}
            for name in treatments:
                counts[name] = counts.get(name, 0) + 1
            duplicates = [name for name, count in counts.items() if count > 1]
            if duplicates:
                # ``assign`` resolves names with ``tuple.index``. A duplicated
                # name would silently map to its first position, so reject it.
                raise ValueError(
                    f"treatments universe contains duplicate names {duplicates}"
                )
            object.__setattr__(self, "treatments", treatments)
        universe = self._universe()
        for leaf_id, weights in self.allocation_map.items():
            unknown = [name for name in weights if name not in universe]
            if unknown:
                raise ValueError(
                    f"allocation_map[{leaf_id}] names {unknown} are not in "
                    f"the treatments universe {list(universe)}"
                )

    def _universe(self) -> tuple[str, ...]:
        """Return ``treatments`` if given, else allocation-map names in first-seen order."""
        if self.treatments is not None:
            return tuple(self.treatments)
        seen: dict[str, None] = {}
        for weights in self.allocation_map.values():
            for name in weights:
                seen.setdefault(name)
        return tuple(seen)

    def assign(self, features: np.ndarray, rng: np.random.Generator) -> int:
        """Route one visitor and return the chosen treatment's universe index.

        ``features`` may be any array-like. It is reshaped into a single-row
        2-D array for ``tree.apply``. The leaf's allocation is sampled with
        one ``rng.choice`` draw.

        Raises:
            KeyError: if the tree routes to a leaf with no ``allocation_map``
                entry.
        """
        row = np.asarray(features).reshape(1, -1)
        leaf_id = int(self.tree.apply(row)[0])
        try:
            weights = self.allocation_map[leaf_id]
        except KeyError:
            raise KeyError(
                f"tree leaf {leaf_id} has no entry in allocation_map"
            ) from None
        names = list(weights)
        probabilities = np.array([weights[name] for name in names], dtype=float)
        # Renormalize so that drift within WEIGHT_SUM_TOLERANCE does not trip
        # numpy's tighter sum-to-one check in ``Generator.choice``.
        draw = rng.choice(len(names), p=probabilities / probabilities.sum())
        chosen = names[int(draw)]
        return self._universe().index(chosen)

    def describe(self) -> str:
        return (
            f"policy tree routing over {len(self.allocation_map)} segments "
            f"with per-leaf Thompson allocation"
        )