# Source code for pytyche.generators.api

"""v2 data generator — typed output via CalibrationBundle.

Adapts the v1 segment experiment generator (``simulation.core``) to produce
v2 typed contracts.  Key differences from v1:

- ``converted`` column is ``bool`` (not ``int``).
- Returns ``CalibrationBundle(observed, truth)`` instead of ``ExperimentData``
  with an untyped ``ground_truth`` dict.
- ``CalibrationTruth.metric_family`` is a typed ``MetricFamily`` enum — no
  dependency on v1 ``get_metric_spec()``.
- ``validate_observed_data()`` runs before return — fail-closed.
"""

from __future__ import annotations

import math

import numpy as np
import pandas as pd

from pytyche.contracts import (
    AlignedVisitorArray,
    CalibrationBundle,
    CalibrationTruth,
    MetricFamily,
    ObservedExperimentData,
    VariantData,
)
from pytyche.validation import validate_observed_data

# Metric id -> statistical family of the observed outcome.  This mapping is
# the single source of truth for which metrics the generator supports.
_METRIC_FAMILIES: dict[str, MetricFamily] = {
    "conversion_rate": MetricFamily.BINARY,
    "revenue_per_visitor": MetricFamily.HURDLE_REAL,
}

# Derived from _METRIC_FAMILIES (insertion order preserved) so the accepted
# metrics cannot drift from the metrics that have a family mapping.
_VALID_METRICS = tuple(_METRIC_FAMILIES)


def _metric_to_family(metric: str) -> MetricFamily:
    """Look up the ``MetricFamily`` registered for ``metric``.

    Raises
    ------
    ValueError
        If ``metric`` has no entry in ``_METRIC_FAMILIES``.
    """
    family = _METRIC_FAMILIES.get(metric)
    if family is None:
        known = sorted(_METRIC_FAMILIES)
        raise ValueError(f"Unknown metric {metric!r}. Expected one of {known}")
    return family


def _make_visitors_df(
    n: int,
    experiment_id: str,
    variant_name: str,
    converted: np.ndarray,
    revenue: np.ndarray,
    extra_cols: dict[str, np.ndarray] | None = None,
) -> pd.DataFrame:
    """Assemble the per-visitor frame for one variant (VISITOR_SCHEMA layout).

    ``converted`` is stored as ``bool`` (v2 contract) and also drives
    ``orders_count`` (one order per converter).  Every visitor has exactly
    one session.  Columns in ``extra_cols`` are merged last; a key that
    collides with a base column replaces its values in place.
    """
    ids = [f"{variant_name}-v{idx}" for idx in range(n)]
    base_columns: dict[str, object] = {
        "visitor_id": ids,
        "experiment_id": experiment_id,
        "variant": variant_name,
        "converted": converted.astype(bool),
        "revenue": revenue.astype(float),
        "orders_count": converted.astype(int),
        "sessions_count": np.ones(n, dtype=int),
    }
    return pd.DataFrame({**base_columns, **(extra_cols or {})})


def _make_variant(
    name: str,
    experiment_id: str,
    converted: np.ndarray,
    revenue: np.ndarray,
    extra_cols: dict[str, np.ndarray] | None = None,
) -> VariantData:
    """Wrap one variant's raw outcome arrays in a ``VariantData``.

    The visitor count comes from ``len(converted)``; conversion count and
    total revenue are aggregated from the same arrays used for the frame.
    """
    visitor_count = len(converted)
    return VariantData(
        name=name,
        visitors=_make_visitors_df(
            visitor_count, experiment_id, name, converted, revenue, extra_cols
        ),
        n_visitors=visitor_count,
        n_conversions=int(np.count_nonzero(converted.astype(bool))),
        total_revenue=float(revenue.sum()),
    )


# Keys every segment config must provide (immutable so no caller can
# accidentally widen or narrow the contract at runtime).
_REQUIRED_SEGMENT_KEYS = frozenset({"pct", "base_conv", "treatment_effect"})
# Extra keys required when metric == "revenue_per_visitor".
_RPV_SEGMENT_KEYS = frozenset({"aov_mu", "aov_sigma"})


def _validate_segments(segments: dict[str, dict], metric: str) -> None:
    """Validate segment config before sampling.

    Raises ValueError with a clear message for:
    - Empty segments dict
    - Missing required keys per segment
    - Non-positive or non-finite pct values
    - base_conv outside [0, 1] (including NaN)
    - Zero-sum pct (can't normalize to a probability distribution)
    - Missing RPV-specific keys when metric is revenue_per_visitor
    - Negative or non-finite aov_sigma when metric is revenue_per_visitor
    """
    if not segments:
        raise ValueError("segments must be non-empty")

    for name, cfg in segments.items():
        missing = _REQUIRED_SEGMENT_KEYS - set(cfg)
        if missing:
            raise ValueError(
                f"Segment '{name}': missing required keys {sorted(missing)}"
            )

        pct = cfg["pct"]
        if pct <= 0:
            raise ValueError(
                f"Segment '{name}': pct must be positive, got {cfg['pct']}"
            )
        # NaN slips past ``<= 0`` and inf normalises to NaN weights; both
        # would otherwise only fail later, opaquely, inside rng.choice.
        if not math.isfinite(pct):
            raise ValueError(f"Segment '{name}': pct must be finite, got {pct}")

        # base_conv is used unclipped as the control Bernoulli probability.
        # The chained comparison is False for NaN, so NaN is rejected too.
        base_conv = cfg["base_conv"]
        if not 0.0 <= base_conv <= 1.0:
            raise ValueError(
                f"Segment '{name}': base_conv must be in [0, 1], got {base_conv}"
            )

        if metric == "revenue_per_visitor":
            missing_rpv = _RPV_SEGMENT_KEYS - set(cfg)
            if missing_rpv:
                raise ValueError(
                    f"Segment '{name}': metric 'revenue_per_visitor' requires "
                    f"keys {sorted(missing_rpv)}"
                )
            # LogNormal sampling rejects sigma < 0, while the analytic truth
            # only uses sigma**2; reject up front so both always agree.
            aov_sigma = cfg["aov_sigma"]
            if not (math.isfinite(aov_sigma) and aov_sigma >= 0):
                raise ValueError(
                    f"Segment '{name}': aov_sigma must be finite and "
                    f"non-negative, got {aov_sigma}"
                )

    # Defensive: unreachable while every pct is checked positive above.
    total_pct = sum(cfg["pct"] for cfg in segments.values())
    if total_pct <= 0:
        raise ValueError(
            f"Segment pct values sum to {total_pct}; must be positive"
        )


def _treatment_conv_rate(cfg: dict) -> float:
    """Return a segment's planted treatment conversion probability.

    The rate is ``base_conv + treatment_effect`` clipped to ``[0, 1]``.
    Sampling and every ground-truth quantity use this single helper, so the
    simulated data and the reported truth cannot disagree.
    """
    return float(np.clip(cfg["base_conv"] + cfg["treatment_effect"], 0.0, 1.0))


def _sample_variant(
    variant_segments: np.ndarray,
    segments: dict[str, dict],
    metric: str,
    apply_treatment: bool,
    variant_seg_seeds: dict[str, int],
) -> tuple[np.ndarray, np.ndarray]:
    """Draw conversions (and, for RPV, revenue) for one variant.

    Each segment uses its own RNG seeded from ``variant_seg_seeds``, so one
    segment's draws do not depend on how many visitors other segments got.
    Conversions are Bernoulli with the control rate (``base_conv``) or the
    clipped treatment rate.  For ``revenue_per_visitor``, converters also get
    LogNormal revenue with ``mean=aov_mu`` (shifted by
    ``treatment_aov_mu_shift`` under treatment) and ``sigma=aov_sigma``.

    Returns
    -------
    (converted, revenue):
        An ``int`` 0/1 array and a ``float`` array, both aligned with
        ``variant_segments``.  Non-converters have zero revenue.
    """
    n = len(variant_segments)
    converted = np.zeros(n, dtype=int)
    revenue = np.zeros(n, dtype=float)

    for seg_name, cfg in segments.items():
        mask = variant_segments == seg_name
        n_seg = int(mask.sum())
        if n_seg == 0:
            # No RNG is created and no draws are taken for empty segments.
            continue

        seg_rng = np.random.default_rng(variant_seg_seeds[seg_name])
        if apply_treatment:
            p = _treatment_conv_rate(cfg)
        else:
            p = float(cfg["base_conv"])
        seg_converted = seg_rng.binomial(1, p, size=n_seg)
        converted[mask] = seg_converted

        if metric != "revenue_per_visitor":
            continue
        n_conv = int(seg_converted.sum())
        if n_conv == 0:
            continue

        aov_mu = cfg["aov_mu"]
        if apply_treatment:
            aov_mu += cfg.get("treatment_aov_mu_shift", 0.0)
        seg_revenue = np.zeros(n_seg, dtype=float)
        # Revenue is drawn from the same segment RNG, after the conversions,
        # so the stream order is: n_seg Bernoulli draws, then n_conv LogNormals.
        seg_revenue[seg_converted == 1] = seg_rng.lognormal(
            mean=aov_mu, sigma=cfg["aov_sigma"], size=n_conv
        )
        revenue[mask] = seg_revenue

    return converted, revenue


def _planted_truth(
    segments: dict[str, dict],
    weights: np.ndarray,
    metric: str,
) -> tuple[float, dict[str, float], dict[str, float]]:
    """Compute the ground truth analytically from the planted parameters.

    Parameters
    ----------
    segments:
        Validated segment configs.
    weights:
        Normalised segment shares, in the iteration order of ``segments``.
    metric:
        ``"conversion_rate"`` or ``"revenue_per_visitor"``.

    Returns
    -------
    (effect, effect_components, segment_cate):
        The share-weighted population effect, its decomposition, and the
        per-segment treatment effect used to build the per-visitor CATE.
    """
    shares = [(name, float(w)) for name, w in zip(segments, weights)]

    def weighted(per_segment: dict[str, float]) -> float:
        # Same left-to-right summation order as a per-index loop.
        return sum(w * per_segment[name] for name, w in shares)

    conv_effects = {
        name: _treatment_conv_rate(cfg) - cfg["base_conv"]
        for name, cfg in segments.items()
    }

    if metric != "revenue_per_visitor":
        effect = weighted(conv_effects)
        return effect, {"conv_effect": float(effect)}, conv_effects

    segment_cate: dict[str, float] = {}
    for name, cfg in segments.items():
        aov_mu = cfg["aov_mu"]
        aov_sigma = cfg["aov_sigma"]
        # E[LogNormal(mu, sigma)] = exp(mu + sigma**2 / 2); RPV = rate * E[AOV].
        ctrl_rpv = cfg["base_conv"] * math.exp(aov_mu + aov_sigma**2 / 2)
        treat_mu = aov_mu + cfg.get("treatment_aov_mu_shift", 0.0)
        treat_rpv = _treatment_conv_rate(cfg) * math.exp(
            treat_mu + aov_sigma**2 / 2
        )
        segment_cate[name] = treat_rpv - ctrl_rpv

    effect = weighted(segment_cate)
    effect_components = {
        "conv_effect": weighted(conv_effects),
        # Share-weighted shift of the LogNormal mean parameter (mu units).
        "aov_effect": sum(
            w * segments[name].get("treatment_aov_mu_shift", 0.0)
            for name, w in shares
        ),
    }
    return effect, effect_components, segment_cate


def generate(
    n_visitors: int,
    segments: dict[str, dict],
    metric: str = "conversion_rate",
    seed: int = 42,
    experiment_id: str = "sim-exp",
) -> CalibrationBundle:
    """Generate segment-aware experiment data with typed v2 output.

    Statistical model is identical to v1 ``generate_segment_experiment``:
    segment assignment via weighted ``rng.choice``, per-segment Bernoulli
    conversion, per-segment LogNormal revenue for converters. Ground truth
    is computed analytically from planted parameters.

    Parameters
    ----------
    n_visitors:
        Total visitors. Control receives ``n_visitors // 2`` and treatment
        the remainder, so treatment gets one extra visitor when
        ``n_visitors`` is odd.
    segments:
        Mapping of segment name to config dict. Each config must have:
        ``pct``, ``base_conv``, ``treatment_effect``.
        For ``revenue_per_visitor``: also ``aov_mu``, ``aov_sigma``,
        and optionally ``treatment_aov_mu_shift``.
    metric:
        ``"conversion_rate"`` or ``"revenue_per_visitor"``.
    seed:
        Random seed for reproducibility.
    experiment_id:
        Identifier stored on the returned data.

    Returns
    -------
    CalibrationBundle
        ``(observed, truth)`` — observed data validated via
        ``validate_observed_data`` before return.

    Raises
    ------
    ValueError
        For an unknown ``metric`` or an invalid ``segments`` config.
    """
    if metric not in _VALID_METRICS:
        raise ValueError(
            f"Unknown metric {metric!r}. Expected one of {_VALID_METRICS}"
        )
    _validate_segments(segments, metric)

    rng = np.random.default_rng(seed)
    n_control = n_visitors // 2
    n_treat = n_visitors - n_control

    seg_names = list(segments.keys())
    pcts = np.array([segments[s]["pct"] for s in seg_names])
    pcts = pcts / pcts.sum()

    # Assign segments using a dedicated child RNG.
    seg_rng = np.random.default_rng(rng.integers(0, 2**32))
    ctrl_segments = seg_rng.choice(seg_names, size=n_control, p=pcts)
    treat_segments = seg_rng.choice(seg_names, size=n_treat, p=pcts)

    # Pre-generate per-variant, per-segment seeds for RNG isolation.  The
    # order of these draws from ``rng`` is part of the reproducibility
    # contract for a given ``seed``.
    ctrl_seg_seeds = {s: int(rng.integers(0, 2**32)) for s in seg_names}
    treat_seg_seeds = {s: int(rng.integers(0, 2**32)) for s in seg_names}

    ctrl_converted, ctrl_revenue = _sample_variant(
        ctrl_segments, segments, metric, False, ctrl_seg_seeds
    )
    treat_converted, treat_revenue = _sample_variant(
        treat_segments, segments, metric, True, treat_seg_seeds
    )

    control = _make_variant(
        "control",
        experiment_id,
        ctrl_converted,
        ctrl_revenue,
        extra_cols={"segment": ctrl_segments},
    )
    treatment = _make_variant(
        "treatment",
        experiment_id,
        treat_converted,
        treat_revenue,
        extra_cols={"segment": treat_segments},
    )

    # --- Ground truth (analytical, not empirical) ---
    truth_lift, effect_components, seg_cate = _planted_truth(
        segments, pcts, metric
    )

    # Per-visitor segment-constant CATE aligned with [control, treatment].
    all_segments = np.concatenate([ctrl_segments, treat_segments])
    cate_values = np.array([seg_cate[s] for s in all_segments])
    cate_per_visitor = AlignedVisitorArray(
        values=cate_values, n_visitors=n_control + n_treat
    )

    truth = CalibrationTruth(
        effect=truth_lift,
        metric_id=metric,
        metric_family=_metric_to_family(metric),
        effect_components=effect_components,
        cate_per_visitor=cate_per_visitor,
    )

    observed = ObservedExperimentData(
        experiment_id=experiment_id,
        metric=metric,
        variants=[control, treatment],
    )

    # Fail-closed: validate before returning.
    validate_observed_data(observed)
    return CalibrationBundle(observed=observed, truth=truth)