"""Per-round experiment snapshot: ``Experiment``, ``CellObservation``,
and the per-cell observed-performance computation.
An :class:`Experiment` is one completed round of a sequential experiment —
the fitted posterior, its analysis, the cells that shipped, what each cell
observed, and the recommendation engine's plan for the next round.
:func:`compute_cell_observations` produces the per-cell scoreboard from the
joint posterior conditioned on the actual cell-assignment indicators.
"""
from __future__ import annotations
import dataclasses
import html
from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
from pytyche._internal.extraction import extract_fit_arrays
from pytyche.analysis._intervals import credible_interval
from pytyche.experiment.cells import Cell
if TYPE_CHECKING:
from pytyche.analysis import TruthComparison
from pytyche.bcf.config import (
BinaryBCFResult,
ContinuousBCFResult,
HurdleBCFResult,
)
from pytyche.contracts import AnalysisResult, ObservedExperimentData
from pytyche.experiment.recommendation import NextRoundPlan
# Public surface of this module: the per-round snapshot types and the
# per-cell scoreboard computation.
__all__ = ["CellObservation", "Experiment", "compute_cell_observations"]
#: The cell id every round's reference cell carries — ``lift_vs_control``
#: is computed against it.
_CONTROL_CELL_ID = "control"


@dataclasses.dataclass(frozen=True)
class CellObservation:
    """Observed per-cell performance for one round.

    The scoreboard for the cells-as-blend story: when an operator ships a
    hypothesis cell alongside the recommendation engine's cell, this is
    where per-cell RPV and cross-cell lift appear. All intervals follow
    the library-wide 80% credible convention.

    Fields:
        id: Matches the cell's id. The Control (reference) cell carries
            ``'control'``.
        n_visitors: Visitors actually assigned to this cell this round.
        observed_rpv: Empirical revenue per visitor over the cell's
            members.
        observed_rpv_ci: 80% credible interval on the cell's mean RPV
            from the joint posterior conditioned on the member rows.
        lift_vs_control: Point estimate of ``observed_rpv`` minus the
            Control cell's; exactly ``0.0`` on the Control cell itself.
        lift_vs_control_ci: 80% credible interval on the lift;
            ``(0.0, 0.0)`` on the Control cell.
        lift_vs_other_cells: Point estimates of lift vs each other
            non-Control cell, keyed by cell id (no self key, no control
            key). Empty on the Control cell — Control IS the reference,
            so per-cell lifts vs Control live on those cells'
            ``lift_vs_control``, not redundantly reversed here.
        policy_summary: ``cell.policy.describe()``.
    """

    id: str
    n_visitors: int
    observed_rpv: float
    observed_rpv_ci: tuple[float, float]
    lift_vs_control: float
    lift_vs_control_ci: tuple[float, float]
    lift_vs_other_cells: dict[str, float]
    policy_summary: str

    def _is_reference(self) -> bool:
        """Whether this is the Control cell (the lift reference).

        Keyed on the id, not inferred from zero lifts. In a round with
        Control plus one other cell, that cell's ``lift_vs_other_cells``
        is also empty. If its lift were exactly ``0.0`` with a degenerate
        ``(0.0, 0.0)`` interval, a value-based test would misrender it as
        the reference and hide its lift.
        """
        return self.id == _CONTROL_CELL_ID

    def _row(self) -> str:
        """One scoreboard line; shared by this repr and Experiment's.

        Point estimates are empirical means; the intervals are the joint
        posterior's — two different sources, so the rendering labels the
        interval (an uncalibrated model interval can sit away from the
        empirical mean; unlabeled juxtaposition would read as a bug).
        """
        text = (
            f"{self.id:<12} n={self.n_visitors:>7,} "
            f"RPV {self.observed_rpv:.4f} "
            f"(model [{self.observed_rpv_ci[0]:.4f}, "
            f"{self.observed_rpv_ci[1]:.4f}])"
        )
        if not self._is_reference():
            # The reference cell's lift is zero by construction; only
            # non-reference cells show it.
            text += (
                f" lift vs control {self.lift_vs_control:+.4f} "
                f"(model [{self.lift_vs_control_ci[0]:+.4f}, "
                f"{self.lift_vs_control_ci[1]:+.4f}])"
            )
        return text

    def _frame_row(self) -> dict[str, str]:
        """One pre-formatted scoreboard row; shared with Experiment's table.

        Point and interval share a column as one non-breaking token —
        split columns wrapped mid-bracket at narrow viewports.
        """
        rpv = (
            f"{self.observed_rpv:.4f} "
            f"[{self.observed_rpv_ci[0]:.4f}, {self.observed_rpv_ci[1]:.4f}]"
        )
        lift = (
            "—"
            if self._is_reference()
            else (
                f"{self.lift_vs_control:+.4f} "
                f"[{self.lift_vs_control_ci[0]:+.4f}, "
                f"{self.lift_vs_control_ci[1]:+.4f}]"
            )
        )
        return {
            "cell": self.id,
            "n": f"{self.n_visitors:,}",
            # U+00A0 (no-break space) keeps point + interval on one line.
            "RPV (model 80% CI)": rpv.replace(" ", "\u00a0"),
            "lift vs control": lift.replace(" ", "\u00a0"),
            "policy": self.policy_summary,
        }

    def __repr__(self) -> str:
        return f"CellObservation({self._row()})"

    def _repr_html_(self) -> str:
        """Jupyter rich display: a one-row scoreboard table."""
        return _frame_html([self._frame_row()])
def _frame_html(rows: list[dict[str, str]]) -> str:
    """Render pre-formatted scoreboard rows as an HTML table via pandas.

    Going through ``DataFrame.to_html`` keeps the ``dataframe`` CSS class
    that Jupyter, Colab, and the doc theme all style — never hand-rolled
    ``<table>`` markup (owner presentation ruling, 2026-06-12). The index
    column is omitted and the table carries ``border=0``.
    """
    table = pd.DataFrame(rows)
    return table.to_html(border=0, index=False)
@dataclasses.dataclass(frozen=True)
class Experiment:
    """One round of a sequential experiment.

    Bundles the public single-shot results — the fitted posterior and its
    :class:`~pytyche.contracts.AnalysisResult` — with the context only a
    sequential run has: the cells that served traffic, how each cell
    performed, the engine's plan for the next round, and the sim-only
    comparison against ground truth.

    Fields:
        round_idx: Zero-based round index.
        posterior: The round's fitted posterior (fit on cumulative data);
            identical to ``analysis.posterior``.
        analysis: The round's analysis summary.
        cells_shipped: The cell structure that actually served traffic.
        cell_observations: Per-cell observed performance, matching
            ``cells_shipped`` order.
        next_recommendation: The engine's plan for the next round.
            ``None`` only on the engine's provisional in-flight view;
            yielded experiments always carry a plan.
        truth_comparison: Sim-mode truth metrics; ``None`` in real-data
            mode.
    """

    round_idx: int
    posterior: HurdleBCFResult | ContinuousBCFResult | BinaryBCFResult
    analysis: AnalysisResult
    cells_shipped: list[Cell]
    cell_observations: list[CellObservation]
    next_recommendation: NextRoundPlan | None
    truth_comparison: TruthComparison | None

    @property
    def observed(self) -> ObservedExperimentData | None:
        """Read-only alias of ``posterior.observed``.

        This is the observed data the round's posterior was fit on.
        """
        return self.posterior.observed

    def summary_one_line(self) -> str:
        """Template-built, deterministic single-line summary of the round."""
        analysis = self.analysis
        rec = analysis.recommendation
        n_segments = len(analysis.segments)
        return (
            f"round {self.round_idx}: {analysis.metric}"
            f" | {n_segments} segment(s)"
            f" | {rec.decision.name} {rec.treatment!r}"
            f" (P(lift>0)={rec.probability_positive:.2f})"
        )

    def _header(self) -> str:
        """The summary line, suffixed ``[calibrated]`` when applicable."""
        tag = " [calibrated]" if self.analysis.is_calibrated else ""
        return self.summary_one_line() + tag

    def _truth_line(self) -> str | None:
        """Sim-mode truth metrics line, or ``None`` in real-data mode."""
        tc = self.truth_comparison
        if tc is None:
            return None
        return (
            f"truth: cate_rmse={tc.cate_rmse:.4f}"
            f" policy_accuracy={tc.policy_accuracy:.1%}"
            f" oracle_gap={tc.oracle_gap_rpv:.4f}/visitor"
        )

    def _next_line(self) -> str | None:
        """Next-round plan line, or ``None`` when no plan is attached."""
        plan = self.next_recommendation
        if plan is None:
            return None
        if plan.n_visitors is None:
            size = "no next round (schedule exhausted)"
        else:
            size = f"{plan.n_visitors:,} visitors"
        return f"next: {len(plan.cells)} cell(s), {size}"

    def _footer_lines(self) -> list[str]:
        """The truth and next-round lines that apply, in that order."""
        candidates = (self._truth_line(), self._next_line())
        return [line for line in candidates if line is not None]

    def __repr__(self) -> str:
        header = self._header()
        cell_rows = [f" {obs._row()}" for obs in self.cell_observations]
        trailer = [f" {line}" for line in self._footer_lines()]
        return "\n".join([header, " cells:", *cell_rows, *trailer])

    def _repr_html_(self) -> str:
        """Jupyter rich display: bold header, scoreboard, footer paragraphs."""
        title = html.escape(self._header())
        table = _frame_html(
            [obs._frame_row() for obs in self.cell_observations]
        )
        notes = "".join(
            f"<p>{html.escape(line)}</p>" for line in self._footer_lines()
        )
        return f"<div><b>{title}</b>{table}{notes}</div>"
def compute_cell_observations(
    posterior: HurdleBCFResult,
    cells: list[Cell],
    cell_assignment: np.ndarray,
) -> list[CellObservation]:
    """Per-cell observed performance from the joint posterior.

    ``cell_assignment`` holds one cell id per concatenated visitor row
    (the extraction-adapter row order: ``observed.variants`` frames
    concatenated in variant-list order). Point estimates are empirical
    means of the members' ``revenue`` column.

    Intervals condition the joint posterior on the membership
    indicators. Each visitor contributes the level draws of its REALIZED
    arm:

    - K = 2: the ``p0/p1`` × ``sev0/sev1`` channel selected by the row's
      treatment indicator.
    - K >= 3: the ``p_samples``/``sev_samples`` ``(n, S, K)`` slice at
      the realized arm.

    Each cell's per-draw member mean is then reduced through the
    calibration-aware 80% interval path
    (``pytyche.analysis._intervals.credible_interval``). Rows whose id
    matches no cell in ``cells`` contribute to no cell.

    Args:
        posterior: The round's hurdle posterior, carrying observed data
            and retained per-visitor level draws.
        cells: The round's cells; the output matches this order. Ids must
            be unique and must include the ``'control'`` cell (the lift
            reference).
        cell_assignment: Array of cell ids, one per concatenated visitor
            row.

    Returns:
        One :class:`CellObservation` per cell, in ``cells`` order.

    Raises:
        ValueError: Raised in any of these cases:

            - the posterior carries no observed data;
            - ``cell_assignment`` length mismatches the visitor rows;
            - ``cells`` repeats an id;
            - no cell has id ``'control'``;
            - a cell has no assigned visitors;
            - the posterior carries no per-visitor level draws (neither
              the K = 2 channel arrays nor the K >= 3 ``p_samples`` /
              ``sev_samples``);
            - the level draws do not have one row per visitor row.
    """
    observed = posterior.observed
    if observed is None:
        raise ValueError(
            "posterior carries no observed data (observed is None); cell "
            "observations are computed over the concatenated visitor rows"
        )
    visitors = pd.concat(
        [v.visitors for v in observed.variants], ignore_index=True
    )
    revenue = visitors["revenue"].to_numpy(dtype=np.float64)
    assignment = np.asarray(cell_assignment)
    if len(assignment) != len(visitors):
        raise ValueError(
            f"cell_assignment has {len(assignment)} entries for "
            f"{len(visitors)} concatenated visitor rows"
        )
    # Structural checks are cheap; run them before extracting the
    # (potentially large) per-visitor draw arrays.
    member_masks = _cell_member_masks(cells, assignment)
    level = _realized_level_draws(posterior, observed)
    if level.shape[0] != len(visitors):
        # Without this, the boolean masks below fail with an opaque
        # IndexError rather than the documented ValueError.
        raise ValueError(
            f"posterior level draws cover {level.shape[0]} rows for "
            f"{len(visitors)} concatenated visitor rows"
        )

    observed_rpv = {
        cell_id: float(revenue[mask].mean())
        for cell_id, mask in member_masks.items()
    }
    # Per-draw mean RPV over each cell's members, shape (n_draws,).
    level_means = {
        cell_id: level[mask].mean(axis=0)
        for cell_id, mask in member_masks.items()
    }
    control_rpv = observed_rpv[_CONTROL_CELL_ID]
    control_level = level_means[_CONTROL_CELL_ID]

    observations: list[CellObservation] = []
    for cell in cells:
        cell_rpv = observed_rpv[cell.id]
        cell_level = level_means[cell.id]
        rpv_ci = credible_interval(cell_level, posterior.calibration)
        if cell.id == _CONTROL_CELL_ID:
            # Control is the reference: zero lift, no cross-cell lifts.
            lift = 0.0
            lift_ci = (0.0, 0.0)
            lift_others: dict[str, float] = {}
        else:
            lift = float(cell_rpv - control_rpv)
            # Lift interval from the per-draw difference of member means.
            lift_ci = credible_interval(
                cell_level - control_level, posterior.calibration
            )
            lift_others = {
                other.id: float(cell_rpv - observed_rpv[other.id])
                for other in cells
                if other.id not in (cell.id, _CONTROL_CELL_ID)
            }
        observations.append(
            CellObservation(
                id=cell.id,
                n_visitors=int(member_masks[cell.id].sum()),
                observed_rpv=cell_rpv,
                observed_rpv_ci=rpv_ci,
                lift_vs_control=lift,
                lift_vs_control_ci=lift_ci,
                lift_vs_other_cells=lift_others,
                policy_summary=cell.policy.describe(),
            )
        )
    return observations


def _cell_member_masks(
    cells: list[Cell], assignment: np.ndarray
) -> dict[str, np.ndarray]:
    """Validate ``cells`` against ``assignment``; return per-cell row masks.

    Masks are keyed by cell id, in ``cells`` order. Rows whose id matches
    no cell are left out of every mask rather than rejected.
    """
    seen: set[str] = set()
    for cell in cells:
        # Ids are the only link between cells and rows; a repeated id
        # would silently give two cells the same members.
        if cell.id in seen:
            raise ValueError(
                f"cells contain duplicate id {cell.id!r}; cell ids must be "
                "unique (cell_assignment cannot tell them apart)"
            )
        seen.add(cell.id)
    if _CONTROL_CELL_ID not in seen:
        raise ValueError(
            f"cells must include a {_CONTROL_CELL_ID!r} cell — "
            "lift_vs_control is computed against it"
        )
    masks: dict[str, np.ndarray] = {}
    for cell in cells:
        members = assignment == cell.id
        if not members.any():
            raise ValueError(f"cell {cell.id!r} has no assigned visitors")
        masks[cell.id] = members
    return masks
def _realized_level_draws(
    posterior: HurdleBCFResult, observed: ObservedExperimentData
) -> np.ndarray:
    """Per-visitor RPV level draws at each row's REALIZED arm, ``(n, S)``.

    A cell's RPV posterior conditions on what its members actually
    received, so each row contributes the level draws of its realized
    treatment:

    - K >= 3: the ``(n, S, K)`` ``p_samples`` × ``sev_samples`` slice at
      the row's arm.
    - K = 2: the ``p0/p1`` × ``sev0/sev1`` channel selected by the row's
      treatment indicator. Arm ``0`` takes the ``0`` channel; any other
      arm takes the ``1`` channel.

    At K >= 3 the realized arm is selected from each array *before*
    multiplying. The full ``(n, S, K)`` product, of which only one slice
    per row is used, is therefore never allocated. Elementwise
    multiplication commutes with selection, so the result is unchanged.

    Raises:
        ValueError: When the K >= 3 arrays cannot be broadcast together
            or are not 3-D after broadcasting. Also raised when the
            posterior carries neither both K >= 3 arrays nor all four
            K = 2 channel arrays.
    """
    arms = extract_fit_arrays(observed).Z.astype(int)
    if posterior.p_samples is not None and posterior.sev_samples is not None:
        # broadcast_arrays returns views. This keeps the elementwise-product
        # broadcasting semantics, including its ValueError on incompatible
        # shapes, without allocating the product itself.
        p_by_arm, sev_by_arm = np.broadcast_arrays(
            np.asarray(posterior.p_samples, dtype=np.float64),
            np.asarray(posterior.sev_samples, dtype=np.float64),
        )
        if p_by_arm.ndim != 3:
            raise ValueError(
                "p_samples / sev_samples must be (n_rows, n_draws, K); got "
                f"shape {p_by_arm.shape}"
            )
        idx = arms[:, None, None]
        p_realized = np.take_along_axis(p_by_arm, idx, axis=2)[:, :, 0]
        sev_realized = np.take_along_axis(sev_by_arm, idx, axis=2)[:, :, 0]
        return p_realized * sev_realized
    channels = (
        posterior.p0_samples,
        posterior.p1_samples,
        posterior.sev0_samples,
        posterior.sev1_samples,
    )
    if any(channel is None for channel in channels):
        raise ValueError(
            "posterior carries no per-visitor level draws (neither the "
            "K = 2 channel arrays p0/p1_samples × sev0/sev1_samples nor the "
            "K >= 3 p_samples / sev_samples); cell CIs condition the joint "
            "posterior on the cell-assignment indicators and require them"
        )
    p0, p1, sev0, sev1 = (
        np.asarray(channel, dtype=np.float64) for channel in channels
    )
    return np.where(arms[:, None] == 0, p0 * sev0, p1 * sev1)