Source code for pytyche.experiment.experiment

"""Per-round experiment snapshot: ``Experiment``, ``CellObservation``,
and the per-cell observed-performance computation.

An :class:`Experiment` is one completed round of a sequential experiment —
the fitted posterior, its analysis, the cells that shipped, what each cell
observed, and the recommendation engine's plan for the next round.
:func:`compute_cell_observations` produces the per-cell scoreboard from the
joint posterior conditioned on the actual cell-assignment indicators.
"""

from __future__ import annotations

import dataclasses
import html
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd

from pytyche._internal.extraction import extract_fit_arrays
from pytyche.analysis._intervals import credible_interval
from pytyche.experiment.cells import Cell

if TYPE_CHECKING:
    from pytyche.analysis import TruthComparison
    from pytyche.bcf.config import (
        BinaryBCFResult,
        ContinuousBCFResult,
        HurdleBCFResult,
    )
    from pytyche.contracts import AnalysisResult, ObservedExperimentData
    from pytyche.experiment.recommendation import NextRoundPlan

# Public API of this module.
__all__ = [
    "CellObservation",
    "Experiment",
    "compute_cell_observations",
]

#: Id of the reference (control) cell. ``compute_cell_observations``
#: requires a cell carrying this id and computes every other cell's
#: ``lift_vs_control`` against it.
_CONTROL_CELL_ID = "control"


@dataclasses.dataclass(frozen=True)
class CellObservation:
    """Scoreboard entry describing how one cell performed in one round.

    When an operator ships a hypothesis cell next to the recommendation
    engine's cell, this record is where the per-cell RPV and the
    cross-cell lifts are reported. Every interval follows the
    library-wide 80% credible convention.

    Fields:
        id: The id of the cell this entry describes.
        n_visitors: Number of visitors assigned to the cell this round.
        observed_rpv: Empirical revenue per visitor over the cell's
            members.
        observed_rpv_ci: 80% credible interval on the cell's mean RPV,
            taken from the joint posterior conditioned on the member rows.
        lift_vs_control: ``observed_rpv`` minus the Control cell's value;
            exactly ``0.0`` on the Control cell itself.
        lift_vs_control_ci: 80% credible interval on that lift;
            ``(0.0, 0.0)`` on the Control cell.
        lift_vs_other_cells: Point-estimate lift against every other
            non-Control cell, keyed by cell id (never keyed by this cell
            or by control). Empty on the Control cell, because lifts
            against Control are already carried by the other cells'
            ``lift_vs_control``.
        policy_summary: The text returned by ``cell.policy.describe()``.
    """

    id: str
    n_visitors: int
    observed_rpv: float
    observed_rpv_ci: tuple[float, float]
    lift_vs_control: float
    lift_vs_control_ci: tuple[float, float]
    lift_vs_other_cells: dict[str, float]
    policy_summary: str

    def _is_reference(self) -> bool:
        """Return True when this entry looks like the Control cell.

        The Control cell is recognised by its sentinel values: a lift of
        exactly ``0.0``, a ``(0.0, 0.0)`` lift interval, and no lifts
        against other cells.
        """
        if self.lift_vs_control != 0.0:
            return False
        if self.lift_vs_control_ci != (0.0, 0.0):
            return False
        return not self.lift_vs_other_cells

    def _row(self) -> str:
        """Render the plain-text scoreboard line used by both reprs.

        The point estimates are empirical means while the intervals come
        from the joint posterior. Because those are two different
        sources, each interval is tagged ``model`` so that an interval
        which does not bracket the empirical mean is not mistaken for a
        rendering bug.
        """
        rpv_ci = self.observed_rpv_ci
        pieces = [
            f"{self.id:<12} n={self.n_visitors:>7,} ",
            f"RPV {self.observed_rpv:.4f} ",
            f"(model [{rpv_ci[0]:.4f}, {rpv_ci[1]:.4f}])",
        ]
        if not self._is_reference():
            lift_ci = self.lift_vs_control_ci
            pieces.append(
                f" lift vs control {self.lift_vs_control:+.4f} "
                f"(model [{lift_ci[0]:+.4f}, {lift_ci[1]:+.4f}])"
            )
        return "".join(pieces)

    def _frame_row(self) -> dict[str, str]:
        """Build the pre-formatted table row shared with ``Experiment``.

        A point estimate and its interval live in a single column and are
        glued together with non-breaking spaces, so a narrow viewport
        cannot wrap the text in the middle of the bracket.
        """
        nbsp = "\u00a0"
        rpv_ci = self.observed_rpv_ci
        rpv_text = (
            f"{self.observed_rpv:.4f} [{rpv_ci[0]:.4f}, {rpv_ci[1]:.4f}]"
        )
        if self._is_reference():
            # The reference cell has no lift against itself.
            lift_text = "—"
        else:
            lift_ci = self.lift_vs_control_ci
            lift_text = (
                f"{self.lift_vs_control:+.4f} "
                f"[{lift_ci[0]:+.4f}, {lift_ci[1]:+.4f}]"
            )
        return {
            "cell": self.id,
            "n": f"{self.n_visitors:,}",
            "RPV (model 80% CI)": rpv_text.replace(" ", nbsp),
            "lift vs control": lift_text.replace(" ", nbsp),
            "policy": self.policy_summary,
        }

    def __repr__(self) -> str:
        return "CellObservation(" + self._row() + ")"

    def _repr_html_(self) -> str:
        single_row = [self._frame_row()]
        return _frame_html(single_row)
def _frame_html(rows: list[dict[str, str]]) -> str:
    """Render pre-formatted scoreboard rows as an HTML table via pandas.

    The table is produced by pandas rather than hand-written ``<table>``
    markup so that it keeps the ``dataframe`` CSS class, which Jupyter,
    Colab, and the doc theme all style (owner presentation ruling,
    2026-06-12).

    Args:
        rows: One mapping of column name to display text per table row.

    Returns:
        The HTML markup of the table, without an index column.
    """
    table = pd.DataFrame(rows)
    return table.to_html(index=False, border=0)
@dataclasses.dataclass(frozen=True)
class Experiment:
    """Snapshot of one completed round of a sequential experiment.

    Bundles the public single-shot types (the fitted posterior and its
    :class:`~pytyche.contracts.AnalysisResult`) with the context that only
    exists in the sequential setting: the cells that shipped, what each
    cell observed, the recommendation for the next round, and the
    simulation-only comparison against ground truth.

    Fields:
        round_idx: Zero-based index of the round.
        posterior: The posterior fitted for this round (on cumulative
            data); identical to ``analysis.posterior``.
        analysis: The analysis summary of the round.
        cells_shipped: The cell structure that actually served traffic.
        cell_observations: Per-cell observed performance, in the same
            order as ``cells_shipped``.
        next_recommendation: The engine's plan for the next round.
            ``None`` only on the engine's provisional in-flight view;
            yielded experiments always carry a plan.
        truth_comparison: Truth metrics in sim mode; ``None`` in
            real-data mode.
    """

    round_idx: int
    posterior: HurdleBCFResult | ContinuousBCFResult | BinaryBCFResult
    analysis: AnalysisResult
    cells_shipped: list[Cell]
    cell_observations: list[CellObservation]
    next_recommendation: NextRoundPlan | None
    truth_comparison: TruthComparison | None

    @property
    def observed(self) -> ObservedExperimentData | None:
        """The observed data the round's posterior was fit on.

        Read-only alias of ``posterior.observed``.
        """
        return self.posterior.observed

    def summary_one_line(self) -> str:
        """Return a deterministic, template-based one-line round summary."""
        rec = self.analysis.recommendation
        fields = (
            f"round {self.round_idx}: {self.analysis.metric}",
            f"{len(self.analysis.segments)} segment(s)",
            f"{rec.decision.name} {rec.treatment!r}"
            f" (P(lift>0)={rec.probability_positive:.2f})",
        )
        return " | ".join(fields)

    def _truth_line(self) -> str | None:
        """Format the truth-comparison line, or None when there is none."""
        truth = self.truth_comparison
        if truth is None:
            return None
        return (
            f"truth: cate_rmse={truth.cate_rmse:.4f}"
            f" policy_accuracy={truth.policy_accuracy:.1%}"
            f" oracle_gap={truth.oracle_gap_rpv:.4f}/visitor"
        )

    def _next_line(self) -> str | None:
        """Format the next-round plan line, or None when there is no plan."""
        plan = self.next_recommendation
        if plan is None:
            return None
        if plan.n_visitors is None:
            size = "no next round (schedule exhausted)"
        else:
            size = f"{plan.n_visitors:,} visitors"
        return f"next: {len(plan.cells)} cell(s), {size}"

    def _header(self) -> str:
        """One-line summary, tagged when the analysis is calibrated."""
        suffix = " [calibrated]" if self.analysis.is_calibrated else ""
        return self.summary_one_line() + suffix

    def _trailer_lines(self) -> list[str]:
        """The optional truth / next-round lines that are present."""
        candidates = (self._truth_line(), self._next_line())
        return [text for text in candidates if text is not None]

    def __repr__(self) -> str:
        lines = [self._header(), " cells:"]
        for obs in self.cell_observations:
            lines.append(f" {obs._row()}")
        for text in self._trailer_lines():
            lines.append(f" {text}")
        return "\n".join(lines)

    def _repr_html_(self) -> str:
        title = html.escape(self._header())
        table = _frame_html(
            [obs._frame_row() for obs in self.cell_observations]
        )
        # The trailer text is escaped because it is embedded as raw markup.
        notes = "".join(
            f"<p>{html.escape(text)}</p>" for text in self._trailer_lines()
        )
        return f"<div><b>{title}</b>{table}{notes}</div>"
def compute_cell_observations(
    posterior: HurdleBCFResult,
    cells: list[Cell],
    cell_assignment: np.ndarray,
) -> list[CellObservation]:
    """Build the per-cell scoreboard from the joint posterior.

    ``cell_assignment`` carries one cell id per concatenated visitor row
    (the extraction-adapter row order: the ``observed.variants`` frames
    concatenated in variant-list order). Point estimates are empirical
    means of the members' ``revenue`` column. Intervals condition the
    joint posterior on the membership indicators: every visitor
    contributes the level draws of its REALIZED arm (K = 2: the
    ``p0/p1`` × ``sev0/sev1`` channel selected by the row's treatment
    indicator; K >= 3: the ``p_samples``/``sev_samples`` ``(n, S, K)``
    slice at the realized arm), and each cell's per-draw member mean is
    reduced through the calibration-aware 80% interval path
    (``pytyche.analysis._intervals.credible_interval``).

    Args:
        posterior: The round's hurdle posterior, carrying observed data
            and retained per-visitor level draws.
        cells: The round's cells; the result follows this order. Must
            contain the ``'control'`` cell (the lift reference).
        cell_assignment: Array of cell ids, one per concatenated visitor
            row.

    Returns:
        One :class:`CellObservation` per cell, in ``cells`` order.

    Raises:
        ValueError: When the posterior carries no observed data or no
            per-visitor level draws (neither the K = 2 channel arrays nor
            the K >= 3 ``p_samples`` / ``sev_samples``), when
            ``cell_assignment`` length mismatches the visitor rows, when
            a cell has no assigned visitors, or when no cell has id
            ``'control'``.
    """
    observed = posterior.observed
    if observed is None:
        raise ValueError(
            "posterior carries no observed data (observed is None); cell "
            "observations are computed over the concatenated visitor rows"
        )

    frames = [variant.visitors for variant in observed.variants]
    visitors = pd.concat(frames, ignore_index=True)
    revenue = visitors["revenue"].to_numpy(dtype=np.float64)

    assignment = np.asarray(cell_assignment)
    n_assigned = len(assignment)
    n_rows = len(visitors)
    if n_assigned != n_rows:
        raise ValueError(
            f"cell_assignment has {n_assigned} entries for "
            f"{n_rows} concatenated visitor rows"
        )

    level = _realized_level_draws(posterior, observed)

    if not any(cell.id == _CONTROL_CELL_ID for cell in cells):
        raise ValueError(
            f"cells must include a {_CONTROL_CELL_ID!r} cell — "
            "lift_vs_control is computed against it"
        )

    # Boolean membership mask per cell; every cell must own at least one row.
    masks: dict[str, np.ndarray] = {}
    for cell in cells:
        mask = assignment == cell.id
        if not mask.any():
            raise ValueError(f"cell {cell.id!r} has no assigned visitors")
        masks[cell.id] = mask

    # Empirical mean revenue over each cell's members.
    rpv_by_cell = {
        cell_id: float(revenue[mask].mean()) for cell_id, mask in masks.items()
    }
    # Per-draw mean RPV over each cell's members, shape (n_draws,).
    draws_by_cell = {
        cell_id: level[mask].mean(axis=0) for cell_id, mask in masks.items()
    }

    control_rpv = rpv_by_cell[_CONTROL_CELL_ID]
    control_draws = draws_by_cell[_CONTROL_CELL_ID]
    calibration = posterior.calibration

    results: list[CellObservation] = []
    for cell in cells:
        cell_id = cell.id
        rpv = rpv_by_cell[cell_id]
        draws = draws_by_cell[cell_id]
        rpv_ci = credible_interval(draws, calibration)

        lift_others: dict[str, float] = {}
        if cell_id == _CONTROL_CELL_ID:
            # Control is the reference: its lift is zero by definition.
            lift = 0.0
            lift_ci = (0.0, 0.0)
        else:
            lift = float(rpv - control_rpv)
            lift_ci = credible_interval(draws - control_draws, calibration)
            for other in cells:
                if other.id == cell_id or other.id == _CONTROL_CELL_ID:
                    continue
                lift_others[other.id] = float(rpv - rpv_by_cell[other.id])

        results.append(
            CellObservation(
                id=cell_id,
                n_visitors=int(masks[cell_id].sum()),
                observed_rpv=rpv,
                observed_rpv_ci=rpv_ci,
                lift_vs_control=lift,
                lift_vs_control_ci=lift_ci,
                lift_vs_other_cells=lift_others,
                policy_summary=cell.policy.describe(),
            )
        )
    return results
def _realized_level_draws(
    posterior: HurdleBCFResult, observed: ObservedExperimentData
) -> np.ndarray:
    """Return per-visitor RPV level draws at each row's REALIZED arm, ``(n, S)``.

    A cell's RPV posterior conditions on what its members actually
    received, so every row contributes the level draws of the treatment
    it was given. When both ``p_samples`` and ``sev_samples`` are present
    (K >= 3), the ``(n, S, K)`` product is sliced at the row's arm.
    Otherwise (K = 2) the ``p0 * sev0`` channel is used for rows whose
    treatment indicator is 0 and the ``p1 * sev1`` channel for all others.

    Raises:
        ValueError: When the K >= 3 product is not three-dimensional, or
            when neither the K >= 3 arrays nor all four K = 2 channel
            arrays are available.
    """
    arms = extract_fit_arrays(observed).Z.astype(int)

    multi_arm = (
        posterior.p_samples is not None and posterior.sev_samples is not None
    )
    if multi_arm:
        p_draws = np.asarray(posterior.p_samples, dtype=np.float64)
        sev_draws = np.asarray(posterior.sev_samples, dtype=np.float64)
        level_by_arm = p_draws * sev_draws
        if level_by_arm.ndim != 3:
            raise ValueError(
                "p_samples / sev_samples must be (n_rows, n_draws, K); got "
                f"shape {level_by_arm.shape}"
            )
        # Broadcast each row's arm index across the draw axis, then drop
        # the singleton arm axis left behind by the gather.
        arm_index = arms[:, None, None]
        gathered = np.take_along_axis(level_by_arm, arm_index, axis=2)
        return gathered[:, :, 0]

    p0_raw = posterior.p0_samples
    p1_raw = posterior.p1_samples
    sev0_raw = posterior.sev0_samples
    sev1_raw = posterior.sev1_samples
    if p0_raw is None or p1_raw is None or sev0_raw is None or sev1_raw is None:
        raise ValueError(
            "posterior carries no per-visitor level draws (neither the "
            "K = 2 channel arrays p0/p1_samples × sev0/sev1_samples nor the "
            "K >= 3 p_samples / sev_samples); cell CIs condition the joint "
            "posterior on the cell-assignment indicators and require them"
        )

    p0 = np.asarray(p0_raw, dtype=np.float64)
    p1 = np.asarray(p1_raw, dtype=np.float64)
    sev0 = np.asarray(sev0_raw, dtype=np.float64)
    sev1 = np.asarray(sev1_raw, dtype=np.float64)
    received_control = arms[:, None] == 0
    return np.where(received_control, p0 * sev0, p1 * sev1)