"""v2 typed contracts for the pytyche analysis pipeline.
This module IS the API reference. Each frozen dataclass defines a contract
between pipeline stages, with docstrings documenting fields, invariants,
and the containment level it operates at.
Containment chain::
Visitor → Variant → Experiment → Program
Orthogonal axis — Segment::
Visitor → Variant → Experiment → Program
↑
Segment ─┘ (cross-cutting rule/group lens)
Key boundaries enforced by types:
- **Observed ↔ truth**: ``ObservedExperimentData`` has NO truth field.
Analysis code structurally cannot peek at ground truth.
- **Analysis ↔ diagnostics**: ``AnalysisResult`` carries core results.
``DiagnosticsBundle`` carries PyMC internals, returned separately.
- **Discovery ↔ internals**: ``DiscoveredSegment`` exposes segment
outputs. The fitted estimator object is never exposed downstream.
"""
from __future__ import annotations
import dataclasses
import enum
import math
from datetime import datetime
from typing import TYPE_CHECKING, ClassVar, Literal, NamedTuple
import numpy as np
import pandas as pd
if TYPE_CHECKING:
from arviz import InferenceData
# bcf.config imports contracts at runtime — the reverse import MUST stay
# type-checking-only to avoid a runtime cycle.
from pytyche.bcf.config import (
BinaryBCFResult,
ContinuousBCFResult,
HurdleBCFResult,
)
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
[docs]
class MetricFamily(enum.StrEnum):
"""Abstract metric family taxonomy.
Determines model structure and decomposition availability.
"""
BINARY = "binary"
HURDLE_REAL = "hurdle_real"
REAL = "real"
[docs]
class Decision(enum.StrEnum):
"""Recommendation decision outcome.
``SHIP``: deploy the treatment.
``CONTINUE``: keep collecting data.
``STOP``: abandon the treatment (harmful or futile).
"""
SHIP = "ship"
CONTINUE = "continue"
STOP = "stop"
[docs]
class ClaimLevel(enum.StrEnum):
"""What the operator can claim from the analysis.
Describes the evidentiary strength, not the splitting mechanism.
Stable across estimator changes (e.g. BCF makes splitting optional).
``EXPLORATORY``: data-driven discovery, not pre-registered.
``HONEST_ESTIMATE``: sample-split or honest-forest estimates.
``CONFIRMED``: replicated in a hold-out experiment.
"""
EXPLORATORY = "exploratory"
HONEST_ESTIMATE = "honest_estimate"
CONFIRMED = "confirmed"
# ---------------------------------------------------------------------------
# Visitor-level contract (Level 1)
# ---------------------------------------------------------------------------
#: Minimum column set (column name -> numpy dtype string) for visitor
#: DataFrames.
#:
#: Generators and production loaders alike MUST emit at least these columns.
#: Extra feature columns (segment assignments, device, country, ...) are
#: allowed; HTE discovery uses them.
#:
#: Invariants:
#: - Exactly one row per visitor (``visitor_id`` is unique).
#: - ``revenue >= 0``.
#: - Generators guarantee that ``converted`` implies ``revenue > 0``.
#:   Production data may legitimately contain ``converted=True, revenue=0``
#:   (free trials, coupons); analysis copes with both.
VISITOR_SCHEMA: dict[str, str] = {
    # Identifiers — plain strings, stored with object dtype.
    "visitor_id": "object",
    "experiment_id": "object",
    "variant": "object",  # variant name
    # Outcomes.
    "converted": "bool",  # whether the visitor converted
    "revenue": "float64",  # total revenue; 0.0 when not converted
    "orders_count": "int64",  # orders placed
    "sessions_count": "int64",  # sessions observed
}
#: Reserved per-visitor column holding the binary-arm propensity score.
#:
#: At K=2 the column stores ``P(Z=1 | x)``, the probability that a visitor
#: with covariates ``x`` is assigned to the treatment variant. At K≥3 the
#: multi-arm counterparts are ``propensity_1 … propensity_{K-1}``
#: (``P(Z=k | x)``), named with ``RESERVED_PROPENSITY_PREFIX``.
#:
#: Propensity columns are NEVER features. At every K the fit-boundary adapter
#: drops them from the feature matrix ``X``. Call
#: ``is_reserved_propensity_column`` to check a name against the whole
#: reserved set.
RESERVED_PROPENSITY_COLUMN: str = "propensity"

#: Name prefix of the multi-arm propensity columns (``propensity_1``,
#: ``propensity_2``, …).
#:
#: A column is reserved, and kept out of the feature matrix, when its name is
#: exactly ``RESERVED_PROPENSITY_COLUMN`` or is this prefix followed by
#: digits. ``propensity_0`` therefore counts as reserved too. That is not
#: because it is a standard K≥3 column; it is a deliberate fail-safe against
#: propensity-like columns leaking into HTE discovery.
RESERVED_PROPENSITY_PREFIX: str = "propensity_"

#: Reserved per-visitor column recording sequential-experiment cell
#: membership.
#:
#: Holds the id of the cell (Control / Explore / Optimized / operator
#: hypothesis cell) that allocated the visitor. It is recorded at
#: data-generation time because membership cannot be recovered from the
#: treatment received; for example, an Explore-cell visitor may draw control.
#: It is never a feature: the fit-boundary adapter drops it from ``X``, and
#: the single-shot fit path otherwise ignores it. ``pt.sequential_experiment``
#: reads it to compute per-cell observations.
RESERVED_CELL_COLUMN: str = "cell"
[docs]
def is_reserved_propensity_column(name: str) -> bool:
    """Tell whether *name* belongs to the reserved propensity-column set.

    The reserved set is:

    - ``"propensity"`` itself (K=2: ``P(Z=1 | x)``);
    - ``"propensity_"`` followed by one or more digits per ``str.isdigit``
      (K≥3: ``P(Z=k | x)``). This deliberately includes ``propensity_0``
      as a fail-safe superset.

    The fit-boundary extraction adapter drops every column for which this
    returns True from the feature matrix.

    Args:
        name: Column name to check.

    Returns:
        True for a reserved name, False otherwise.
    """
    if name == RESERVED_PROPENSITY_COLUMN:
        return True
    if not name.startswith(RESERVED_PROPENSITY_PREFIX):
        return False
    # ``"".isdigit()`` is False, so a bare "propensity_" is not reserved.
    return name.removeprefix(RESERVED_PROPENSITY_PREFIX).isdigit()
# ---------------------------------------------------------------------------
# Segment contracts (cross-cutting axis)
# ---------------------------------------------------------------------------
# --- Rule algebra ---
# Discriminated union: each operator gets its own dataclass.
# Invalid combos (e.g. "eq" with a float threshold) are impossible to
# construct.
[docs]
@dataclasses.dataclass(frozen=True)
class EqRule:
    """Categorical equality clause: ``feature == value``.

    Example: ``EqRule("lifecycle_stage", "new_visitor")``.

    Fields:
        feature: Column the clause tests.
        value: Category the column must equal.
    """

    feature: str
    value: str
[docs]
@dataclasses.dataclass(frozen=True)
class InRule:
    """Categorical set membership: ``feature in values``.

    Example: ``InRule("device", ("mobile", "tablet"))``.

    Fields:
        feature: Column the clause tests.
        values: Admitted category values. Any non-tuple iterable (e.g. a
            list) is normalised to a tuple, so the frozen instance, and any
            ``SegmentRule`` containing it, stays hashable and compares equal
            to the tuple-built equivalent.

    Raises:
        TypeError: if ``values`` is a bare ``str``. Iterating a string yields
            single characters, so this is almost certainly a mistake for a
            one-element tuple.
    """

    feature: str
    values: tuple[str, ...]

    def __post_init__(self) -> None:
        if isinstance(self.values, str):
            raise TypeError(
                f"InRule: values must be a tuple of strings, not a bare str "
                f"({self.values!r}); did you mean ({self.values!r},)?"
            )
        if not isinstance(self.values, tuple):
            # Frozen dataclass: bypass the generated __setattr__ guard.
            object.__setattr__(self, "values", tuple(self.values))
[docs]
@dataclasses.dataclass(frozen=True)
class ComparisonRule:
    """Numeric threshold clause: ``feature <op> threshold``.

    Example: ``ComparisonRule("age", "gt", 35.0)`` means ``age > 35``.

    Fields:
        feature: Column the clause tests.
        operator: One of ``"gt"``, ``"gte"``, ``"lt"``, ``"lte"``.
        threshold: Value the column is compared against.

    Raises:
        ValueError: if ``operator`` is not one of the four supported codes.
    """

    feature: str
    operator: Literal["gt", "gte", "lt", "lte"]
    threshold: float

    def __post_init__(self) -> None:
        # ``Literal`` is only enforced by static type checkers. Check it at
        # runtime too, so a typo such as "ge" fails here at construction
        # instead of wherever the rule is later applied. Keep this tuple in
        # sync with the annotation above.
        if self.operator not in ("gt", "gte", "lt", "lte"):
            raise ValueError(
                "ComparisonRule: operator must be one of "
                f"'gt', 'gte', 'lt', 'lte', got {self.operator!r}"
            )
[docs]
@dataclasses.dataclass(frozen=True)
class BetweenRule:
    """Numeric range clause: ``low <= feature <= high``.

    Inclusive on both ends.

    Example: ``BetweenRule("spend", 10.0, 100.0)``.

    Fields:
        feature: Column the clause tests.
        low: Inclusive lower bound.
        high: Inclusive upper bound.

    Raises:
        ValueError: if ``low > high`` or either bound is NaN.
    """

    feature: str
    low: float
    high: float

    def __post_init__(self) -> None:
        # Written as ``not (low <= high)`` rather than ``low > high`` so NaN
        # bounds are rejected too. Every comparison with NaN is False, so a
        # NaN bound would slip past ``low > high`` and yield a rule that can
        # never match ``low <= feature <= high``.
        if not (self.low <= self.high):
            raise ValueError(
                f"BetweenRule: low ({self.low}) must be <= high ({self.high})"
            )
#: Any single rule clause (one of the four clause dataclasses). A
#: ``SegmentRule`` AND-combines its clauses.
RuleClause = EqRule | InRule | ComparisonRule | BetweenRule
[docs]
@dataclasses.dataclass(frozen=True)
class SegmentRule:
    """A rule that selects a group of visitors.

    One type serves every segment context (manual, discovered, registered).
    Clauses are AND-combined and stored in a canonical order, so equality,
    hashing and serialization do not depend on the order they were passed in.

    An empty ``clauses=()`` is the catch-all rule matching every visitor:
    ``apply_rule``'s AND-fold over no clauses is vacuously all-True.
    ``fit_policy_tree`` emits it for a root-only (single-leaf) tree.

    Level: cross-cutting (applied over visitor sets).

    Fields:
        description: Human-readable label for the rule.
        clauses: Clauses to AND together, normalised to a sorted tuple.
    """

    description: str
    clauses: tuple[RuleClause, ...]

    def __post_init__(self) -> None:
        # Canonical order: feature name first, then clause type name, then
        # repr as a final tie-break. Several clauses on one feature therefore
        # order deterministically; e.g. ComparisonRule("age", "gt", 18) sorts
        # before EqRule("age", "senior").
        canonical = tuple(
            sorted(
                self.clauses,
                key=lambda clause: (
                    clause.feature,
                    type(clause).__name__,
                    repr(clause),
                ),
            )
        )
        if canonical != self.clauses:
            # Frozen dataclass: bypass the generated __setattr__ guard.
            object.__setattr__(self, "clauses", canonical)
[docs]
@dataclasses.dataclass(frozen=True)
class DiscoveredSegment:
    """Segment produced by HTE discovery; the only discovery output exposed.

    Emitted by the HTE estimation pipeline (the embedded policy-tree fit over
    posterior CATEs). The fitted estimator itself never leaves the pipeline,
    so downstream code sees only instances of this class. Every field is
    required.

    Level: cross-cutting (segment × experiment).

    Fields:
        id: Leaf id inside the parent policy tree, so that
            ``PolicyTreeResult.allocation_map[id]`` lookups resolve.
        rule: Rule that defines the segment's membership.
        gate_estimate: Segment-level treatment-effect estimate, in
            metric-native units.
        gate_ci: 80% credible/confidence interval ``(low, high)`` for
            ``gate_estimate``.
        population_share: Fraction of the population in this segment,
            within ``[0, 1]``.
        stability_score: Bootstrap replicability in ``[0, 1]``. It is the
            share of bootstrap tree refits containing some leaf whose member
            set has Jaccard overlap >= 0.5 with this segment's. 0.80 is the
            documented default "credible enough to act on" cutoff. ``NaN``
            means "not computed" (e.g. ``fit_policy_tree(n_bootstrap=0)``).
        arm_best_probabilities: Posterior probability, per arm, that the arm
            is best within this segment under the shared best-arm rule. Keys
            cover EVERY variant, control included (control wins a draw when
            all contrasts are non-positive). Values sum to 1.0 within 1e-6.

    Raises:
        ValueError: from ``__post_init__`` when ``gate_ci`` is reversed, when
            ``population_share``, a non-NaN ``stability_score`` or any arm
            probability falls outside ``[0, 1]``, when
            ``arm_best_probabilities`` is empty, or when its values do not
            sum to 1.0 within 1e-6.
    """

    CONTRACT_VERSION: ClassVar[int] = 2

    id: int
    rule: SegmentRule
    gate_estimate: float
    gate_ci: tuple[float, float]
    population_share: float
    stability_score: float
    arm_best_probabilities: dict[str, float]

    def __post_init__(self) -> None:
        if self.gate_ci[1] < self.gate_ci[0]:
            raise ValueError(
                f"DiscoveredSegment: gate_ci must be ordered (low, high), "
                f"got {self.gate_ci}"
            )

        share = self.population_share
        if not (0.0 <= share <= 1.0):
            raise ValueError(
                f"DiscoveredSegment: population_share must be in [0, 1], "
                f"got {share}"
            )

        stability = self.stability_score
        # NaN is the explicit "not computed" sentinel and skips the range check.
        if not (math.isnan(stability) or 0.0 <= stability <= 1.0):
            raise ValueError(
                f"DiscoveredSegment: stability_score must be in [0, 1] or "
                f"NaN (not computed), got {stability}"
            )

        probabilities = self.arm_best_probabilities
        if not probabilities:
            raise ValueError(
                "DiscoveredSegment: arm_best_probabilities must be non-empty"
            )
        # Report the first out-of-range arm in insertion order.
        offender = next(
            (
                (arm, p)
                for arm, p in probabilities.items()
                if not (0.0 <= p <= 1.0)
            ),
            None,
        )
        if offender is not None:
            arm, prob = offender
            raise ValueError(
                f"DiscoveredSegment: arm_best_probabilities[{arm!r}] "
                f"must be in [0, 1], got {prob}"
            )
        total = sum(probabilities.values())
        if abs(total - 1.0) > 1e-6:
            raise ValueError(
                f"DiscoveredSegment: arm_best_probabilities values must sum "
                f"to 1.0 within 1e-6, got {total}"
            )

    def _row(self) -> str:
        """One-line summary embedded in the AnalysisResult / PolicyTreeResult reprs."""
        probabilities = self.arm_best_probabilities
        # First arm holding the highest probability (ties keep insertion order).
        leader = max(probabilities, key=probabilities.__getitem__)
        low, high = self.gate_ci[0], self.gate_ci[1]
        return " | ".join(
            (
                f"{self.rule.description}",
                f"gate {self.gate_estimate:+.4f} [{low:+.4f}, {high:+.4f}]",
                f"share {self.population_share:.0%}",
                f"stability {self.stability_score:.2f}",
                f"leader {leader} P={probabilities[leader]:.2f}",
            )
        )

    def __repr__(self) -> str:
        return "DiscoveredSegment(" + self._row() + ")"
[docs]
@dataclasses.dataclass(frozen=True)
class DiscoveryProvenance:
    """Compact snapshot recording where a segment came from in HTE discovery.

    ``RegisteredSegment`` carries this instead of a full
    ``DiscoveredSegment``, so a registry entry neither bloats nor duplicates
    the discovery output.

    Fields:
        gate_estimate: Segment treatment-effect estimate in the snapshot.
        stability_score: Bootstrap stability score in the snapshot.
        population_share: Population share in the snapshot.
        discovered_at: When the segment was discovered.
    """

    gate_estimate: float
    stability_score: float
    population_share: float
    discovered_at: datetime
[docs]
@dataclasses.dataclass(frozen=True)
class RegisteredSegment:
    """Segment that an operator reviewed and entered into the registry.

    Level: cross-cutting (segment in registry).

    Lifecycle::

        HTE discovery -> DiscoveredSegment
            -> operator review -> RegisteredSegment(lifecycle="registered")
            -> dbt SQL -> Redis -> RegisteredSegment(lifecycle="deployed")
            -> experiment targeting

    Fields:
        key: Snake_case registry identifier (e.g. ``"high_value_mobile"``).
        rule: Rule that defines the segment.
        provenance: Where discovery produced the segment; ``None`` for
            segments defined by hand.
        lifecycle: Current lifecycle stage (``"registered"`` or
            ``"deployed"``).
    """

    CONTRACT_VERSION: ClassVar[int] = 1

    key: str
    rule: SegmentRule
    provenance: DiscoveryProvenance | None
    lifecycle: Literal["registered", "deployed"]
# ---------------------------------------------------------------------------
# Variant-level contract (Level 2)
# ---------------------------------------------------------------------------
[docs]
@dataclasses.dataclass(frozen=True)
class VariantData:
    """Per-visitor observations for one variant of an experiment.

    Level: variant.

    ``visitors`` MUST contain at least the ``VISITOR_SCHEMA`` columns; extra
    feature columns are allowed.

    Fields:
        name: Variant name (e.g. ``"control"``, ``"treatment_a"``).
        visitors: DataFrame holding one row per visitor.
        n_visitors: Row count. It duplicates ``len(visitors)`` on purpose:
            construction fails closed when the two disagree.
        n_conversions: Number of rows with ``converted == True``.
        total_revenue: Sum of the ``revenue`` column.

    Raises:
        ValueError: if ``len(visitors) != n_visitors``. Only this count is
            checked here; ``n_conversions`` and ``total_revenue`` are not
            re-derived.
    """

    name: str
    visitors: pd.DataFrame
    n_visitors: int
    n_conversions: int
    total_revenue: float

    def __post_init__(self) -> None:
        row_count = len(self.visitors)
        if row_count != self.n_visitors:
            raise ValueError(
                f"VariantData '{self.name}': len(visitors)={row_count} "
                f"!= n_visitors={self.n_visitors}"
            )
# ---------------------------------------------------------------------------
# Experiment-level contract (Level 3) — observed data
# ---------------------------------------------------------------------------
[docs]
@dataclasses.dataclass(frozen=True)
class ObservedExperimentData:
    """Everything one experiment-analysis run is allowed to see.

    Level: experiment.

    There is deliberately NO truth field: ground truth is excluded by
    construction, so analysis code has no way to peek at it. Generators
    build a ``CalibrationBundle`` that pairs observed data with truth. The
    calibration runner unpacks it and hands only the observed part to
    ``analyze()``.

    - Production path: ``load_experiment()`` returns this type directly.
    - Simulation path: ``generate()`` -> ``CalibrationBundle`` -> the runner
      unpacks it.

    Fields:
        experiment_id: Unique experiment identifier.
        metric: Canonical metric name (e.g. ``"revenue_per_visitor"``).
        variants: At least two variants with distinct names. By convention
            the first one is the control/baseline.

    Derived read-only properties (not dataclass fields):
        ``control_name``: name of the first (control/reference) variant.
        ``treatment_names``: names of every other variant, in list order.

    Raises:
        ValueError: on fewer than two variants or duplicate variant names.
    """

    CONTRACT_VERSION: ClassVar[int] = 1

    experiment_id: str
    metric: str
    variants: list[VariantData]

    def __post_init__(self) -> None:
        variant_count = len(self.variants)
        if variant_count < 2:
            raise ValueError(
                "ObservedExperimentData requires at least 2 variants, "
                f"got {variant_count}"
            )
        names = [variant.name for variant in self.variants]
        # A set can only shrink the list, so a smaller set means duplicates.
        if len(set(names)) < len(names):
            raise ValueError(f"Duplicate variant names: {names}")

    @property
    def control_name(self) -> str:
        """Name of the reference/control variant, i.e. ``variants[0].name``."""
        reference = self.variants[0]
        return reference.name

    @property
    def treatment_names(self) -> tuple[str, ...]:
        """Names of the non-control variants, in variant-list order.

        This is ``variants[1:]``'s names as a tuple. It has one element at
        K=2 and every treatment variant's name at K≥3.
        """
        return tuple([variant.name for variant in self.variants[1:]])
# ---------------------------------------------------------------------------
# Alignment contract (cross-cutting)
# ---------------------------------------------------------------------------
[docs]
@dataclasses.dataclass(frozen=True)
class AlignedVisitorArray:
    """1-D array aligned row-for-row with the concatenated visitor rows.

    Every per-visitor array (CATE predictions, for example) MUST travel in
    this wrapper, which makes its alignment with the concatenated visitor
    rows explicit::

        visitors = pd.concat([v.visitors for v in data.variants],
                             ignore_index=True)
        assert len(array.values) == len(visitors)
        # array.values[i] belongs to visitors.iloc[i]

    The type name carries the contract: a field annotated
    ``cate_per_visitor: AlignedVisitorArray`` documents the alignment on
    its own.

    Fields:
        values: The per-visitor array, one scalar per visitor.
        n_visitors: Expected length. It duplicates ``len(values)`` on
            purpose so construction fails closed.

    Raises:
        ValueError: if ``values`` is not 1-D, or its length differs from
            ``n_visitors``.
    """

    values: np.ndarray
    n_visitors: int

    def __post_init__(self) -> None:
        arr = self.values
        if arr.ndim != 1:
            raise ValueError(
                f"AlignedVisitorArray: values must be 1-D (scalar per visitor), "
                f"got {arr.ndim}-D with shape {arr.shape}"
            )
        length = len(arr)
        if length != self.n_visitors:
            raise ValueError(
                f"AlignedVisitorArray: len(values)={length} "
                f"!= n_visitors={self.n_visitors}"
            )
# ---------------------------------------------------------------------------
# Experiment-level contracts (Level 3) — analysis results
# ---------------------------------------------------------------------------
[docs]
@dataclasses.dataclass(frozen=True)
class DecompositionSamples:
    """Posterior samples of the frequency and severity parts of a lift.

    Meaningful only for hurdle metrics (``MetricFamily.HURDLE_REAL``).
    Frequency is the lift in conversion probability. Severity is the lift in
    AOV given conversion.

    Fields:
        frequency_lift_samples: Frequency-component lift, one value per
            posterior sample.
        severity_lift_samples: Severity-component lift, one value per
            posterior sample.
    """

    frequency_lift_samples: np.ndarray
    severity_lift_samples: np.ndarray
[docs]
@dataclasses.dataclass(frozen=True)
class ComparisonResult:
    """Posterior comparison between two variants.

    Level: experiment.

    Naming is **role-based**. ``baseline`` and ``comparison`` are the roles
    the two variants play in THIS comparison, not properties of the variants
    themselves; one variant can take different roles in different
    comparisons.

    Lift semantics: ``lift_samples`` always holds **absolute** lift
    (``comparison - baseline``) in metric-native units. ``lift_unit`` names
    the metric's natural presentation unit (``"pct"`` for binary metrics,
    ``"dollar"`` for revenue metrics), so display layers can derive a
    percentage lift when rendering.

    Fields:
        baseline: Variant name playing the baseline role here.
        comparison: Variant name compared against the baseline.
        method: ``"compare_to_control"`` or ``"best_of_rest"``.
        probability_positive: P(comparison > baseline).
        probability_better: P(comparison > baseline + threshold).
        probability_harmful: P(baseline > comparison + threshold).
        expected_loss_baseline: E[max(comparison - baseline, 0)], the cost
            of choosing baseline when comparison is better.
        expected_loss_comparison: E[max(baseline - comparison, 0)], the cost
            of choosing comparison when baseline is better.
        expected_loss_samples_baseline: Per-sample loss array for baseline.
        expected_loss_samples_comparison: Per-sample loss array for
            comparison.
        lift_samples: Absolute lift samples (``comparison - baseline``) in
            metric-native units, whatever the metric family.
        lift_unit: Natural presentation unit (``"pct"`` for binary,
            ``"dollar"`` for revenue). This is only a display hint;
            ``lift_samples`` stays absolute.
        lift_ci: ``(low, high)`` credible interval for the lift.
        lift_ci_level: CI level (default 0.80).
        decomposition: Frequency/severity decomposition (hurdle metrics
            only).

    Raises:
        ValueError: if ``lift_ci`` is not ordered ``(low, high)``. This is
            the same fail-closed check ``Comparison`` and ``ChannelLift``
            apply to their intervals.
    """

    baseline: str
    comparison: str
    method: Literal["compare_to_control", "best_of_rest"]
    probability_positive: float
    probability_better: float
    probability_harmful: float
    expected_loss_baseline: float
    expected_loss_comparison: float
    expected_loss_samples_baseline: np.ndarray
    expected_loss_samples_comparison: np.ndarray
    lift_samples: np.ndarray
    lift_unit: str
    lift_ci: tuple[float, float]
    lift_ci_level: float = 0.80
    decomposition: DecompositionSamples | None = None

    def __post_init__(self) -> None:
        if self.lift_ci[0] > self.lift_ci[1]:
            raise ValueError(
                f"ComparisonResult: lift_ci must be ordered (low, high), "
                f"got {self.lift_ci}"
            )
[docs]
@dataclasses.dataclass(frozen=True)
class ChannelLift:
    """Lift of a single hurdle channel: point estimate plus interval.

    Level: experiment.

    Fields:
        point_estimate: Posterior mean of the channel-specific lift.
        ci: 80% credible interval ``(low, high)`` for that lift.

    Raises:
        ValueError: if ``ci`` is reversed (``low > high``).
    """

    point_estimate: float
    ci: tuple[float, float]

    def __post_init__(self) -> None:
        low, high = self.ci[0], self.ci[1]
        if high < low:
            raise ValueError(
                f"ChannelLift: ci must be ordered (low, high), got {self.ci}"
            )
[docs]
@dataclasses.dataclass(frozen=True)
class Decomposition:
    """Conversion/severity split of a hurdle-metric lift.

    This is the lean summary counterpart of ``DecompositionSamples``: point
    estimates and intervals, no posterior samples. ``Comparison`` carries it
    for hurdle posteriors (``posterior.has_decomposition() == True``).

    Level: experiment.

    Fields:
        conversion_lift: Change in conversion probability attributable to
            the treatment.
        severity_lift: Change in basket size, given conversion,
            attributable to the treatment.
    """

    conversion_lift: ChannelLift
    severity_lift: ChannelLift

    def __repr__(self) -> str:
        def describe(label: str, lift: ChannelLift) -> str:
            low, high = lift.ci[0], lift.ci[1]
            return f"{label}: {lift.point_estimate:+.4f} [{low:+.4f}, {high:+.4f}]"

        return " ".join(
            (
                describe("conversion", self.conversion_lift),
                describe("severity", self.severity_lift),
            )
        )
[docs]
@dataclasses.dataclass(frozen=True)
class Comparison:
    """Lean global contrast of one treatment against the reference arm.

    This is the v0.2 summary entry held in ``AnalysisResult.comparisons``:
    point estimates and probabilities only. The sample-carrying
    ``ComparisonResult`` remains the ``compare.variants`` output. Anything
    that needs posterior samples goes through ``AnalysisResult.posterior``.

    Level: experiment.

    Fields:
        treatment: Treatment variant name (one of
            ``posterior.observed.treatment_names``).
        probability_positive: P(lift > 0) at the global level.
        lift_estimate: Posterior mean of the CATE for this contrast.
        lift_ci: 80% credible interval on the lift (10th/90th percentile of
            ``rpv_cate_samples`` for this contrast).
        decomposition: Conversion/severity split for hurdle posteriors;
            ``None`` otherwise.

    Raises:
        ValueError: if ``lift_ci`` is reversed (``low > high``).
    """

    treatment: str
    probability_positive: float
    lift_estimate: float
    lift_ci: tuple[float, float]
    decomposition: Decomposition | None = None

    def __post_init__(self) -> None:
        if self.lift_ci[1] < self.lift_ci[0]:
            raise ValueError(
                f"Comparison: lift_ci must be ordered (low, high), "
                f"got {self.lift_ci}"
            )

    def __repr__(self) -> str:
        low, high = self.lift_ci[0], self.lift_ci[1]
        pieces = [
            f"Comparison — {self.treatment} vs control",
            f" lift: {self.lift_estimate:+.4f}"
            f" 80% CI [{low:+.4f}, {high:+.4f}]"
            f" P(lift > 0) = {self.probability_positive:.2f}",
        ]
        if self.decomposition is not None:
            pieces.append(f" decomposition — {self.decomposition!r}")
        return "\n".join(pieces)
[docs]
@dataclasses.dataclass(frozen=True)
class DecisionThresholds:
    """Decision thresholds used for recommendation summaries.

    Every field except ``expected_loss_tolerance`` is a probability and must
    lie strictly inside (0, 1). ``expected_loss_tolerance`` is a strictly
    positive value in metric-native units.

    Raises:
        ValueError: if a probability field is outside (0, 1) or NaN, or if
            ``expected_loss_tolerance`` is not strictly positive (NaN
            included).
    """

    expected_loss_tolerance: float = 0.01
    p_positive_threshold: float = 0.95
    p_better_threshold: float = 0.80
    futility_threshold: float = 0.05
    harm_threshold: float = 0.90

    def __post_init__(self) -> None:
        for name in (
            "p_positive_threshold",
            "p_better_threshold",
            "futility_threshold",
            "harm_threshold",
        ):
            val = getattr(self, name)
            # ``not (a < x < b)`` also rejects NaN, since NaN comparisons
            # are all False.
            if not (0.0 < val < 1.0):
                raise ValueError(
                    f"{name} must be in (0, 1), got {val}"
                )
        # Use ``not (x > 0)`` rather than ``x <= 0`` so NaN is rejected too,
        # matching the probability checks above.
        if not (self.expected_loss_tolerance > 0.0):
            raise ValueError(
                f"expected_loss_tolerance must be positive, "
                f"got {self.expected_loss_tolerance}"
            )
[docs]
@dataclasses.dataclass(frozen=True)
class RecommendationSummary:
"""Recommended decision with its decision-theoretic evidence.
The act-now risk assessment for one treatment-vs-control contrast:
what committing to either side costs in expectation, how confident
the posterior is, what one more round of data is worth — and the
default rule's resulting SHIP / CONTINUE / STOP call. A pure summary
of the posterior (no sample arrays); recomputable from any posterior,
globally or per-segment.
Level: experiment.
Fields:
treatment: The treatment variant this summary is for (the
contrast's non-control side).
decision: Ship, continue, or stop.
expected_loss_baseline: Expected loss of choosing baseline.
expected_loss_comparison: Expected loss of choosing comparison.
probability_positive: P(comparison > baseline).
probability_better: P(comparison meaningfully better).
probability_harmful: P(comparison meaningfully harmful).
thresholds: Decision thresholds used (e.g.
``{"expected_loss_tolerance": 0.001, ...}``).
expected_value_of_one_more_round: Information-theoretic value of
running one more round of data at the same per-round n, in
expected-loss-reduction units (loss/visitor). ``NaN`` means
the producer did not compute it (the legacy
``compare.variants`` path cannot — a ``ComparisonResult``
carries no sample-size information). Formula documented in
``docs/concepts/decision-theoretic-inputs.md``.
"""
treatment: str
decision: Decision
expected_loss_baseline: float
expected_loss_comparison: float
probability_positive: float
probability_better: float
probability_harmful: float
thresholds: dict[str, float]
expected_value_of_one_more_round: float = dataclasses.field(
default=float("nan"), kw_only=True
)
def __repr__(self) -> str:
return (
f"RecommendationSummary — {self.treatment} vs control\n"
f" decision: {self.decision.name}\n"
f" expected loss — ship now: {self.expected_loss_comparison:.4f}"
f" keep control: {self.expected_loss_baseline:.4f}\n"
f" P(lift > 0) = {self.probability_positive:.2f}"
f" P(better) = {self.probability_better:.2f}"
f" P(harmful) = {self.probability_harmful:.2f}\n"
f" value of one more round: "
f"{self.expected_value_of_one_more_round:.4f}"
)
[docs]
@dataclasses.dataclass(frozen=True)
class AnalysisResult:
    """Summary surface returned by ``posterior.analyze()``.

    Level: experiment.

    This is the SUMMARY layer: lean point estimates and probabilities
    (``Comparison`` entries, discovered segments, the global
    ``RecommendationSummary``). Anything that needs posterior samples goes
    through ``posterior`` (e.g. ``analysis.posterior.rpv_cate_samples``), and
    the observed data is reachable as ``analysis.posterior.observed``.

    Fields:
        experiment_id: Experiment identifier.
        metric: Metric that was analyzed.
        comparisons: One lean ``Comparison`` per non-reference treatment.
        segments: Segments found by the embedded policy-tree fit. Never
            ``None``: an empty list when no segment cleared the
            min_segment_share threshold.
        recommendation: Global ``RecommendationSummary`` (the extended shape
            carrying ``expected_value_of_one_more_round``). At K ≥ 3 it is
            computed for the best challenger, i.e. the largest global
            posterior-mean contrast.
        cate_per_visitor: Posterior-mean CATE per visitor, aligned with the
            concatenated visitor rows. Shape ``(n,)`` at K = 2 and
            ``(n, K − 1)`` (per-arm contrasts vs the reference) at K ≥ 3.
        analyzed_at: When the analysis completed.
        posterior: Fitted posterior the analysis derives from. It is
            excluded from the dataclass repr (``repr=False``) because of its
            large sample arrays.
    """

    CONTRACT_VERSION: ClassVar[int] = 2

    experiment_id: str
    metric: str
    comparisons: list[Comparison]
    segments: list[DiscoveredSegment]
    recommendation: RecommendationSummary
    cate_per_visitor: np.ndarray
    analyzed_at: datetime
    posterior: HurdleBCFResult | ContinuousBCFResult | BinaryBCFResult = (
        dataclasses.field(repr=False)
    )

    @property
    def is_calibrated(self) -> bool:
        """Whether a calibration has been applied to the underlying posterior.

        Simply forwards ``posterior.is_calibrated``.
        """
        return self.posterior.is_calibrated

    def __repr__(self) -> str:
        title = (
            f"AnalysisResult — {self.experiment_id} · {self.metric}"
            f" · analyzed {self.analyzed_at:%Y-%m-%d %H:%M}"
        )
        if self.is_calibrated:
            title += " [calibrated]"

        out = [title, " comparisons:"]
        out.extend(
            f" {c.treatment} vs control: {c.lift_estimate:+.4f}"
            f" [{c.lift_ci[0]:+.4f}, {c.lift_ci[1]:+.4f}]"
            f" P(lift > 0) = {c.probability_positive:.2f}"
            for c in self.comparisons
        )
        out.append(f" segments ({len(self.segments)}):")
        out.extend(f" {segment._row()}" for segment in self.segments)

        rec = self.recommendation
        out.append(f" recommendation: {rec.decision.name} — {rec.treatment}")
        out.append(
            " cate_per_visitor: posterior-mean CATE,"
            f" shape {self.cate_per_visitor.shape}"
        )
        return "\n".join(out)

    def _repr_html_(self) -> str:
        import html as _html

        nbsp = "\u00a0"

        def unbreakable(text: str) -> str:
            # Point estimate and interval share one column as a single
            # non-breaking token; split columns wrap mid-bracket on narrow
            # viewports.
            return text.replace(" ", nbsp)

        comparison_rows = [
            {
                "comparison": f"{c.treatment} vs control",
                "lift (80% CI)": unbreakable(
                    f"{c.lift_estimate:+.4f} "
                    f"[{c.lift_ci[0]:+.4f}, {c.lift_ci[1]:+.4f}]"
                ),
                "P(lift > 0)": f"{c.probability_positive:.2f}",
            }
            for c in self.comparisons
        ]

        segment_rows = []
        for seg in self.segments:
            probs = seg.arm_best_probabilities
            # First arm with the highest probability (ties keep dict order).
            leader = max(probs, key=probs.__getitem__)
            segment_rows.append(
                {
                    "segment": seg.rule.description,
                    "share": f"{seg.population_share:.0%}",
                    "GATE (80% CI)": unbreakable(
                        f"{seg.gate_estimate:+.4f} "
                        f"[{seg.gate_ci[0]:+.4f}, {seg.gate_ci[1]:+.4f}]"
                    ),
                    "stability": f"{seg.stability_score:.2f}",
                    "leader": unbreakable(f"{leader} P={probs[leader]:.2f}"),
                }
            )

        title = _html.escape(
            f"AnalysisResult — {self.experiment_id} · {self.metric}"
        )
        if self.is_calibrated:
            title += " <i>[calibrated]</i>"

        # Rendering through pandas keeps the ``dataframe`` CSS class, which
        # Jupyter, Colab and the doc theme all style; ``to_html`` escapes cell
        # text by default.
        tables = "".join(
            pd.DataFrame(rows).to_html(index=False, border=0)
            for rows in (comparison_rows, segment_rows)
        )
        rec = self.recommendation
        return (
            f"<div><b>{title}</b>{tables}"
            f"<p>recommendation: {rec.decision.name}"
            f" — {_html.escape(rec.treatment)}</p></div>"
        )
[docs]
class DiagnosticsBundle(NamedTuple):
    """Layer 3: PyMC internals, kept outside the analysis contract.

    It is a plain tuple container, so callers with no use for diagnostics
    just discard the second element::

        result, _ = analyze(data)

    Following the ArviZ opinionated Bayesian workflow, diagnostics are
    mandatory rather than optional. Every analysis produces traces, and
    ``analyze()`` always returns ``tuple[AnalysisResult, DiagnosticsBundle]``.

    Fields:
        inference_data: ArviZ ``InferenceData`` object for the run.
    """

    inference_data: InferenceData
# ---------------------------------------------------------------------------
# Truth boundary contracts (simulation only)
# ---------------------------------------------------------------------------
[docs]
@dataclasses.dataclass(frozen=True)
class CalibrationTruth:
"""Ground truth for a single calibration/simulation run.
This type exists ONLY in the simulation/calibration path. Production
analysis never sees it. The type boundary enforces this::
analyze(ObservedExperimentData) -> AnalysisResult # no truth
calibrate(AnalysisResult, CalibrationTruth) -> CalibrationRecord
**K=2 dispatch:** the legacy 1-D fields (``cate_per_visitor``,
``conv_cate_per_visitor``, ``aov_cate_per_visitor``,
``p0_per_visitor``, ``p1_per_visitor``, ``m0_per_visitor``,
``m1_per_visitor``) are populated and the three new list fields
(``contrast_cate_per_visitor``, ``p_per_visitor``, ``m_per_visitor``)
are ``None``.
**K≥3 dispatch:** ``cate_per_visitor`` is ``None``; the legacy paired
fields (``p0/p1/m0/m1_per_visitor``) are ``None``.
``contrast_cate_per_visitor`` (length K−1) carries the per-treatment
effects (each treatment level vs. control, the heterogeneous CATEs);
``p_per_visitor`` and ``m_per_visitor`` (each length K) carry the
per-visitor potential-outcome truth under each treatment level (index
0 = control).
Fields:
effect: Absolute metric-native treatment effect (e.g. +$0.12 RPV).
metric_id: Canonical metric name.
metric_family: Abstract family taxonomy value.
effect_components: Decomposition by named component
(e.g. ``{"conv_effect": 0.02, "aov_effect": 0.10}``).
cate_per_visitor: Per-visitor true CATEs, aligned with concatenated
visitor rows. Populated at K=2; ``None`` at K≥3.
conv_cate_per_visitor: Per-visitor conversion CATE (p1 - p0) * m0.
Hurdle K=2 only; ``None`` for binary or K≥3.
aov_cate_per_visitor: Per-visitor AOV CATE p1 * (m1 - m0).
Hurdle K=2 only; ``None`` for binary or K≥3.
p0_per_visitor: Per-visitor control conversion probabilities.
Hurdle K=2 only; ``None`` for binary or K≥3.
p1_per_visitor: Per-visitor treatment conversion probabilities.
Hurdle K=2 only; ``None`` for binary or K≥3.
m0_per_visitor: Per-visitor control severity means.
Hurdle K=2 only; ``None`` for binary or K≥3.
m1_per_visitor: Per-visitor treatment severity means.
Hurdle K=2 only; ``None`` for binary or K≥3.
contrast_cate_per_visitor: Per-treatment-effect per-visitor CATEs (K≥3).
Length K−1 list (one entry per treatment level vs. control); each
is the heterogeneous treatment effect realized on the visitor rows.
``None`` at K=2.
p_per_visitor: Per-visitor conversion potential outcomes under each
treatment level (K≥3). Length K, index 0 = control. ``None`` at K=2.
m_per_visitor: Per-visitor severity potential outcomes under each
treatment level (K≥3). Length K, index 0 = control. ``None`` at K=2.
"""
CONTRACT_VERSION: ClassVar[int] = 2
effect: float
metric_id: str
metric_family: MetricFamily
effect_components: dict[str, float]
cate_per_visitor: AlignedVisitorArray | None
conv_cate_per_visitor: AlignedVisitorArray | None = None
aov_cate_per_visitor: AlignedVisitorArray | None = None
p0_per_visitor: AlignedVisitorArray | None = None
p1_per_visitor: AlignedVisitorArray | None = None
m0_per_visitor: AlignedVisitorArray | None = None
m1_per_visitor: AlignedVisitorArray | None = None
contrast_cate_per_visitor: list[AlignedVisitorArray] | None = dataclasses.field(default=None, kw_only=True)
p_per_visitor: list[AlignedVisitorArray] | None = dataclasses.field(default=None, kw_only=True)
m_per_visitor: list[AlignedVisitorArray] | None = dataclasses.field(default=None, kw_only=True)
[docs]
class CalibrationBundle(NamedTuple):
    """Observed experiment data paired with its simulated ground truth.

    Generators emit this pair, and it unpacks like a plain tuple::

        observed, truth = bundle

    The calibration runner splits the pair and feeds only ``observed`` into
    ``analyze()``, which has no way to reach the truth. It then scores the
    resulting analysis against ``truth`` as a separate step.
    """

    # The only half that analysis is allowed to see.
    observed: ObservedExperimentData
    # Simulation-only ground truth, used afterwards for evaluation.
    truth: CalibrationTruth
# ---------------------------------------------------------------------------
# Program-level contract (Level 4)
# ---------------------------------------------------------------------------
[docs]
@dataclasses.dataclass(frozen=True)
class CalibrationRecord:
    """Evaluation record for one seed of a calibration run.

    Level: program.

    This is what ``calibrate(AnalysisResult, CalibrationTruth, oracle_config)``
    returns. Every field is JSON-serializable: no numpy arrays and no
    callables.

    Naming is deliberately agent-proof:

    - ``analysis_mode`` is typed as ``ClaimLevel``, not a bare string.
    - ``metric_family`` is typed as ``MetricFamily``, not a bare string.
    - ``decision`` is typed as ``Decision``, not a bare string.
    - The former ``est_lift_mean`` is spelled ``estimated_lift`` for clarity.

    Fields:
        scenario_id: Identifier of the simulation scenario.
        seed: Random seed used for this run.
        analysis_mode: Evidentiary claim level the analysis supports.
        effect: True planted treatment effect (taken from truth).
        metric_id: Canonical metric name.
        metric_family: Abstract family taxonomy.
        effect_components: True effect decomposition (taken from truth).
        estimator_id: Model/estimator that produced the estimate
            (e.g. ``"hurdle_lognormal"``).
        estimated_lift: Posterior mean of the absolute lift estimate, in
            metric-native units so it is directly comparable to ``effect``.
        ci_low: Lower bound of the credible interval.
        ci_high: Upper bound of the credible interval.
        ci_level: Credible-interval level (e.g. 0.80).
        probability_positive: P(treatment > baseline).
        probability_better: P(comparison > baseline + threshold).
        probability_harmful: P(baseline > comparison + threshold).
        expected_loss_baseline: Expected loss incurred by choosing baseline.
        expected_loss_comparison: Expected loss incurred by choosing
            treatment.
        decision: The decision that was recommended.
        oracle_decision: The decision the oracle would have made knowing the
            true effect. Always a concrete ``Decision``, never ``None``. It
            is stored exactly as ``_oracle_decision()`` returned it, so
            downstream consumers (scorecard, notebooks) never have to
            reconstruct it from ``decision`` + ``decision_correct``.
        decision_correct: Whether ``decision`` was right given the truth.
            ``None`` when correctness is ambiguous, e.g. a true effect near
            zero paired with a ``CONTINUE`` decision.
        regret: Size of the decision error in metric-native units
            (``None`` when not applicable).
    """

    CONTRACT_VERSION: ClassVar[int] = 1

    # Run identity.
    scenario_id: str
    seed: int
    analysis_mode: ClaimLevel
    # Ground truth copied from the simulation.
    effect: float
    metric_id: str
    metric_family: MetricFamily
    effect_components: dict[str, float]
    # Estimate and its uncertainty.
    estimator_id: str
    estimated_lift: float
    ci_low: float
    ci_high: float
    ci_level: float
    # Posterior decision quantities.
    probability_positive: float
    probability_better: float
    probability_harmful: float
    expected_loss_baseline: float
    expected_loss_comparison: float
    # Decision outcome and how it scored against the oracle.
    decision: Decision
    oracle_decision: Decision
    decision_correct: bool | None
    regret: float | None
# ---------------------------------------------------------------------------
# Revenue model contracts (simulation only)
# ---------------------------------------------------------------------------
[docs]
@dataclasses.dataclass(frozen=True)
class ProductCategory:
    """A single product category in a cart-based revenue model.

    Fields:
        name: Category identifier (e.g. ``"budget"``, ``"mid"``, ``"premium"``).
        base_price: Mean price for this category (in dollars). Must be finite
            and > 0.
        price_std: Standard deviation of within-category price variation.
            Actual price is drawn from ``Normal(base_price, price_std)``,
            clipped to ``base_price / 2`` minimum (no near-zero prices).
            Use ``0.0`` for deterministic prices. Must be finite and >= 0.
        base_purchase_prob: Baseline Bernoulli purchase probability for this
            category (before visitor-level affinity and treatment adjustments).
            Must be in ``(0, 1]``.

    Raises:
        ValueError: If any numeric field violates its stated range. NaN and
            infinite values are rejected.
    """

    name: str
    base_price: float
    price_std: float
    base_purchase_prob: float

    def __post_init__(self) -> None:
        # Every comparison with NaN is False, so a bare ``x <= 0`` / ``x < 0``
        # rejection test lets NaN (and +inf) through. Require finiteness
        # explicitly and phrase each test as "must satisfy" rather than
        # "must not violate".
        if not (math.isfinite(self.base_price) and self.base_price > 0):
            raise ValueError(
                f"ProductCategory({self.name!r}): base_price must be > 0 "
                f"and finite, got {self.base_price}"
            )
        if not (math.isfinite(self.price_std) and self.price_std >= 0):
            raise ValueError(
                f"ProductCategory({self.name!r}): price_std must be >= 0 "
                f"and finite, got {self.price_std}"
            )
        # The chained comparison is already False for NaN, so NaN is rejected.
        if not (0 < self.base_purchase_prob <= 1):
            raise ValueError(
                f"ProductCategory({self.name!r}): base_purchase_prob must be "
                f"in (0, 1], got {self.base_purchase_prob}"
            )
[docs]
@dataclasses.dataclass(frozen=True)
class CartRevenueConfig:
    """Cart-based revenue model configuration.

    Revenue for a converter is computed as the sum of prices for categories
    where a per-visitor Bernoulli event fires. The purchase probability for
    category ``j`` and visitor ``i`` is::

        purchase_prob_j(i) = sigmoid(
            logit(base_purchase_prob_j)
            + visitor_affinity_j(i)
            + effect_scale * treatment_delta_j(i)
        )

    The cart sampler distributes the severity surface scalar shift across
    categories proportionally to each category's ``base_purchase_prob``
    (see design doc D9).

    When all Bernoulli events fail (empty cart), a minimum-purchase fallback
    forces the cheapest category.

    Fields:
        categories: Ordered list of product categories. Must be non-empty.
        base_quantity_mu: Mean of per-converter quantity distribution. Must be
            finite and > 0.
        base_quantity_sigma: Std of per-converter quantity distribution. Must
            be finite and >= 0.

    Raises:
        ValueError: If ``categories`` is empty or a quantity parameter is out
            of range. NaN and infinite values are rejected.
    """

    categories: list[ProductCategory]
    base_quantity_mu: float = 1.0
    base_quantity_sigma: float = 0.0

    def __post_init__(self) -> None:
        if len(self.categories) < 1:
            raise ValueError(
                "CartRevenueConfig: categories must have at least 1 entry, "
                f"got {len(self.categories)}"
            )
        # NaN compares False against everything, so ``<= 0`` / ``< 0``
        # rejection tests alone would silently accept it; require finite
        # values explicitly.
        if not (math.isfinite(self.base_quantity_mu) and self.base_quantity_mu > 0):
            raise ValueError(
                "CartRevenueConfig: base_quantity_mu must be > 0 and finite, "
                f"got {self.base_quantity_mu}"
            )
        if not (
            math.isfinite(self.base_quantity_sigma) and self.base_quantity_sigma >= 0
        ):
            raise ValueError(
                "CartRevenueConfig: base_quantity_sigma must be >= 0 and finite, "
                f"got {self.base_quantity_sigma}"
            )