"""Runtime validators for v2 contracts.
These validators make the contract types executable — they enforce
invariants at construction time rather than discovering violations
downstream.
Used by both generators and loaders. Fail-closed: no silent acceptance
of malformed data.
Functions:
- ``validate_observed_data(data)`` — checks visitor schema, dtypes,
invariants.
- ``validate_alignment(array, data)`` — confirms array length matches
visitors.
- ``validate_rule(rule, data)`` — confirms rule features exist and are
type-compatible.
"""
from __future__ import annotations
from typing import Any
import numpy as np
import pandas as pd
from pytyche.contracts import (
VISITOR_SCHEMA,
AlignedVisitorArray,
BetweenRule,
ComparisonRule,
EqRule,
InRule,
ObservedExperimentData,
RuleClause,
SegmentRule,
VariantData,
)
[docs]
class SchemaViolation(Exception):
    """Signal that observed visitor data breaks the visitor schema contract.

    Raised by the observed-data checks in this module; the message names
    the offending variant (where applicable) and the specific problem.
    """
[docs]
class AlignmentViolation(Exception):
    """Signal that a per-visitor array does not line up with visitor rows.

    Raised when the array's visitor count differs from the experiment's
    total visitor count.
    """
[docs]
class RuleViolation(Exception):
    """Signal that a segment rule cannot be applied to the visitor data.

    Raised when a rule clause names a feature that is absent from the
    visitor data or whose dtype does not suit the clause type.
    """
[docs]
def validate_observed_data(
    data: ObservedExperimentData,
    *,
    strict: bool = True,
) -> None:
    """Check every variant's visitor DataFrame against ``VISITOR_SCHEMA``.

    Validation is fail-fast: the first failing check raises and nothing
    after it runs. All per-variant checks for one variant complete before
    the next variant is examined.

    Per variant, in this order:

    1. Every required schema column is present.
    2. Required columns have schema-compatible dtypes.
    3. No row has negative ``revenue``.
    4. ``visitor_id`` values are unique within the variant.
    5. The variant's summary fields agree with its DataFrame (delegated
       to ``_check_summary_consistency``).
    6. The ``variant`` column equals ``VariantData.name`` on every row.
    7. The ``experiment_id`` column equals ``data.experiment_id`` on
       every row.

    Across variants:

    8. No ``visitor_id`` is present in more than one variant.

    Only when ``strict`` is true (the default):

    9. All variants expose the same extra feature columns (columns not
       listed in ``VISITOR_SCHEMA``).
    10. Those extra columns have the same dtype in every variant.

    Pass ``strict=False`` when variants legitimately differ in their
    feature columns, e.g. a treatment arm that records a survey answer
    the control arm never collects::

        validate_observed_data(data, strict=False)

    Args:
        data: The observed experiment data to validate.
        strict: When ``True``, also enforce feature-column consistency
            across variants; when ``False``, skip checks 9 and 10.

    Raises:
        SchemaViolation: On the first violation found; the message names
            the variant (where applicable) and the specific problem.
    """
    # Checks that share the (variant name, DataFrame) signature, listed in
    # the order they must run: later checks read columns that the earlier
    # ones have confirmed to exist.
    frame_checks = (
        _check_required_columns,
        _check_dtypes,
        _check_revenue_non_negative,
        _check_unique_visitor_ids,
    )

    for arm in data.variants:
        frame = arm.visitors
        label = arm.name
        for run_check in frame_checks:
            run_check(label, frame)
        _check_summary_consistency(arm)
        _check_row_variant_name(label, frame)
        _check_row_experiment_id(label, frame, data.experiment_id)

    _check_global_visitor_id_uniqueness(data)

    if not strict:
        return
    _check_feature_column_consistency(data)
[docs]
def validate_alignment(
    array: AlignedVisitorArray,
    data: ObservedExperimentData,
) -> None:
    """Verify that a per-visitor array covers exactly the experiment's visitors.

    The expected length is the total of ``n_visitors`` over all variants,
    i.e. the number of rows in::

        pd.concat([v.visitors for v in data.variants], ignore_index=True)

    Only the counts are compared; the array contents are not inspected.

    Args:
        array: The aligned array whose ``n_visitors`` is checked.
        data: The experiment data supplying the per-variant visitor counts.

    Raises:
        AlignmentViolation: If ``array.n_visitors`` differs from the total.
    """
    expected_rows = 0
    for arm in data.variants:
        expected_rows = expected_rows + arm.n_visitors

    actual_rows = array.n_visitors
    if actual_rows == expected_rows:
        return

    raise AlignmentViolation(
        f"AlignedVisitorArray has n_visitors={actual_rows}, "
        f"but experiment has {expected_rows} total visitors "
        f"across {len(data.variants)} variants"
    )
[docs]
def validate_rule(
    rule: SegmentRule,
    data: ObservedExperimentData,
) -> None:
    """Verify that each clause of a segment rule fits the visitor data.

    A column-to-dtype map is assembled from all variants; when a column
    occurs in several variants, the dtype from the first variant that has
    it is used. Every clause is then checked against that map by
    ``_check_clause``:

    1. The clause's ``feature`` must be a column of at least one variant.
    2. ``ComparisonRule`` / ``BetweenRule`` clauses need a numeric column.
    3. ``EqRule`` / ``InRule`` clauses need a non-numeric column.

    Allowed categorical values are NOT validated (no domain registry in
    Phase 1 scope).

    Args:
        rule: The segment rule whose clauses are validated.
        data: The experiment data providing the column schema.

    Raises:
        RuleViolation: If a feature is missing or type-incompatible.
    """
    dtype_by_column: dict[str, Any] = {}
    for arm in data.variants:
        frame = arm.visitors
        for column in frame.columns:
            if column in dtype_by_column:
                # First variant to define a column decides its dtype.
                continue
            dtype_by_column[column] = frame[column].dtype

    for clause in rule.clauses:
        _check_clause(clause, dtype_by_column)
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _check_required_columns(variant_name: str, df: pd.DataFrame) -> None:
    """Ensure ``df`` contains every column named in ``VISITOR_SCHEMA``.

    Args:
        variant_name: Variant label used in the error message.
        df: The variant's visitor DataFrame.

    Raises:
        SchemaViolation: Listing, in sorted order, the schema columns that
            are absent from ``df``.
    """
    absent = set(VISITOR_SCHEMA).difference(df.columns)
    if not absent:
        return
    raise SchemaViolation(
        f"Variant '{variant_name}': missing required columns: "
        f"{sorted(absent)}"
    )
def _check_dtypes(variant_name: str, df: pd.DataFrame) -> None:
    """Check that every ``VISITOR_SCHEMA`` column has a compatible dtype.

    Compatibility is decided on the dtype *kind*, not the exact width, so
    e.g. ``float32`` satisfies ``float64`` and ``int32`` satisfies ``int64``:

    * float schema columns accept any float dtype;
    * integer schema columns (signed or unsigned) accept any signed or
      unsigned integer dtype;
    * bool schema columns accept only bool;
    * object/string schema columns accept object, unicode or bytes kinds;
    * any other schema kind (e.g. datetime) requires the column to have
      that same kind, so an unrecognised schema entry is never silently
      skipped (fail-closed).

    Assumes all schema columns are present in ``df`` (a missing column
    surfaces as ``KeyError`` from the column lookup).

    Args:
        variant_name: Variant label used in the error message.
        df: The variant's visitor DataFrame.

    Raises:
        SchemaViolation: On the first column whose dtype kind is not
            accepted for the schema's expected dtype.
    """
    integer_kinds = ("i", "u")
    text_kinds = ("O", "U", "S")
    # expected kind -> (accepted actual kinds, label used in the message)
    accepted_by_kind = {
        "f": (("f",), "float"),
        "i": (integer_kinds, "integer"),
        "u": (integer_kinds, "integer"),
        "b": (("b",), "bool"),
        "O": (text_kinds, "string/object"),
        "U": (text_kinds, "string/object"),
        "S": (text_kinds, "string/object"),
    }

    for col, expected_dtype_str in VISITOR_SCHEMA.items():
        actual = df[col].dtype
        expected = np.dtype(expected_dtype_str)
        # Kinds without an explicit entry must match exactly; previously
        # such schema entries were not checked at all.
        accepted, label = accepted_by_kind.get(
            expected.kind, ((expected.kind,), str(expected))
        )
        if actual.kind not in accepted:
            raise SchemaViolation(
                f"Variant '{variant_name}': column '{col}' should be "
                f"{label}, got {actual}"
            )
def _check_revenue_non_negative(variant_name: str, df: pd.DataFrame) -> None:
    """Check that ``revenue >= 0`` holds for every row.

    Negative values are reported first. Missing values (NaN/NA) are then
    rejected as well: they compare ``False`` against ``< 0`` and would
    otherwise pass this check even though they do not satisfy
    ``revenue >= 0``.

    Args:
        variant_name: Variant label used in the error message.
        df: The variant's visitor DataFrame; must have a ``revenue`` column.

    Raises:
        SchemaViolation: If any row has negative or missing revenue; the
            message carries the number of offending rows.
    """
    revenue = df["revenue"]

    negative = revenue < 0
    if negative.any():
        n_neg = int(negative.sum())
        raise SchemaViolation(
            f"Variant '{variant_name}': {n_neg} rows have negative revenue"
        )

    # Fail closed on missing values, which the comparison above cannot see.
    missing = revenue.isna()
    if missing.any():
        n_missing = int(missing.sum())
        raise SchemaViolation(
            f"Variant '{variant_name}': {n_missing} rows have missing "
            f"(NaN) revenue"
        )
def _check_unique_visitor_ids(variant_name: str, df: pd.DataFrame) -> None:
    """Ensure no ``visitor_id`` value occurs more than once in ``df``.

    Args:
        variant_name: Variant label used in the error message.
        df: The variant's visitor DataFrame; must have a ``visitor_id``
            column.

    Raises:
        SchemaViolation: Reporting how many rows repeat an earlier
            ``visitor_id`` (the first occurrence is not counted).
    """
    repeated = df["visitor_id"].duplicated()
    if not repeated.any():
        return
    n_dup = int(repeated.sum())
    raise SchemaViolation(
        f"Variant '{variant_name}': {n_dup} duplicate visitor_id values"
    )
def _check_summary_consistency(variant: VariantData) -> None:
    """Check that a variant's summary fields match its visitor DataFrame.

    Compared, in this order:

    * ``n_visitors`` against the number of DataFrame rows;
    * ``n_conversions`` against the sum of the ``converted`` column;
    * ``total_revenue`` against the sum of the ``revenue`` column, within
      a tolerance that absorbs floating-point summation rounding (absolute
      1e-6 for small totals, relative 1e-9 for large ones).

    A NaN ``total_revenue`` never matches and is rejected.

    Args:
        variant: The variant whose ``n_visitors``, ``n_conversions`` and
            ``total_revenue`` are checked against ``variant.visitors``.

    Raises:
        SchemaViolation: On the first summary field that disagrees with
            the DataFrame contents.
    """
    import math  # function-scope import keeps this block self-contained

    name = variant.name
    df = variant.visitors

    n_visitors = variant.n_visitors
    actual_visitors = len(df)
    if actual_visitors != n_visitors:
        raise SchemaViolation(
            f"Variant '{name}': n_visitors={n_visitors} but "
            f"DataFrame has {actual_visitors} rows"
        )

    n_conversions = variant.n_conversions
    actual_conversions = int(df["converted"].sum())
    if actual_conversions != n_conversions:
        raise SchemaViolation(
            f"Variant '{name}': n_conversions={n_conversions} but "
            f"DataFrame has {actual_conversions} converted rows"
        )

    total_revenue = variant.total_revenue
    actual_revenue = float(df["revenue"].sum())
    # ``not isclose(...)`` rather than ``abs(diff) > tol``: the latter is
    # False for NaN and would let a NaN total slip through.
    if not math.isclose(
        actual_revenue, total_revenue, rel_tol=1e-9, abs_tol=1e-6
    ):
        raise SchemaViolation(
            f"Variant '{name}': total_revenue={total_revenue} but "
            f"DataFrame revenue sums to {actual_revenue}"
        )
def _check_row_variant_name(variant_name: str, df: pd.DataFrame) -> None:
    """Check that every row's ``variant`` column equals ``variant_name``.

    Args:
        variant_name: The expected value (``VariantData.name``); also used
            in the error message.
        df: The variant's visitor DataFrame; must have a ``variant`` column.

    Raises:
        SchemaViolation: Reporting how many rows differ and the distinct
            offending values.
    """
    mismatched = df["variant"] != variant_name
    if mismatched.any():
        n_bad = int(mismatched.sum())
        # Sort by string form: offending values may mix strings with
        # None/NaN, which a plain sorted() cannot order (TypeError) and
        # which would mask the schema violation being reported.
        bad_values = sorted(
            df.loc[mismatched, "variant"].unique().tolist(), key=str
        )
        raise SchemaViolation(
            f"Variant '{variant_name}': {n_bad} rows have mismatched variant "
            f"column values: {bad_values}"
        )
def _check_row_experiment_id(
    variant_name: str, df: pd.DataFrame, expected_id: str
) -> None:
    """Check that every row's ``experiment_id`` equals ``expected_id``.

    Args:
        variant_name: Variant label used in the error message.
        df: The variant's visitor DataFrame; must have an
            ``experiment_id`` column.
        expected_id: The experiment id every row is required to carry.

    Raises:
        SchemaViolation: Reporting how many rows differ and the distinct
            offending values.
    """
    mismatched = df["experiment_id"] != expected_id
    if mismatched.any():
        n_bad = int(mismatched.sum())
        # Sort by string form: offending values may mix strings with
        # None/NaN, which a plain sorted() cannot order (TypeError) and
        # which would mask the schema violation being reported.
        bad_values = sorted(
            df.loc[mismatched, "experiment_id"].unique().tolist(), key=str
        )
        raise SchemaViolation(
            f"Variant '{variant_name}': {n_bad} rows have experiment_id "
            f"values {bad_values} instead of '{expected_id}'"
        )
def _check_global_visitor_id_uniqueness(data: ObservedExperimentData) -> None:
    """Ensure each ``visitor_id`` is seen only once across all variants.

    Variants are scanned in order and every id is remembered together
    with the variant it was first seen in; the scan stops at the first id
    that has already been recorded.

    Args:
        data: The experiment data whose variants are scanned.

    Raises:
        SchemaViolation: Naming the repeated ``visitor_id``, the variant
            it was first recorded under and the variant where it recurred.
    """
    first_owner: dict[str, str] = {}  # visitor_id -> variant that has it
    for arm in data.variants:
        for visitor_id in arm.visitors["visitor_id"]:
            if visitor_id not in first_owner:
                first_owner[visitor_id] = arm.name
                continue
            raise SchemaViolation(
                f"visitor_id '{visitor_id}' appears in both variant "
                f"'{first_owner[visitor_id]}' and variant '{arm.name}'"
            )
def _check_feature_column_consistency(data: ObservedExperimentData) -> None:
    """Require identical extra feature columns, with identical dtypes.

    "Extra" columns are those not listed in ``VISITOR_SCHEMA``. The first
    variant serves as the reference: column-set differences are reported
    for every variant that deviates from it, and dtypes are then compared
    column by column against it. Called only in strict mode.

    Args:
        data: The experiment data whose variants are compared.

    Raises:
        SchemaViolation: If the sets of extra columns differ between
            variants, or if a shared extra column has a different dtype
            in some variant than in the first one.
    """
    required = set(VISITOR_SCHEMA)

    # variant name -> {extra column -> dtype}
    extras_by_variant: dict[str, dict[str, Any]] = {}
    for arm in data.variants:
        frame = arm.visitors
        extras_by_variant[arm.name] = {
            column: frame[column].dtype
            for column in frame.columns
            if column not in required
        }

    # Step 1: the *sets* of extra columns must agree.
    columns_by_variant = {
        label: frozenset(extras)
        for label, extras in extras_by_variant.items()
    }
    if len(set(columns_by_variant.values())) > 1:
        baseline_name = data.variants[0].name
        baseline_cols = columns_by_variant[baseline_name]
        problems = []
        for label, cols in columns_by_variant.items():
            if cols == baseline_cols:
                continue
            absent = sorted(baseline_cols - cols)
            surplus = sorted(cols - baseline_cols)
            details = []
            if absent:
                details.append(f"missing {absent}")
            if surplus:
                details.append(f"extra {surplus}")
            problems.append(f"'{label}': {', '.join(details)}")
        raise SchemaViolation(
            "Feature columns differ across variants "
            f"(relative to '{baseline_name}'): {'; '.join(problems)}"
        )

    # Step 2: with matching column sets, dtypes must agree as well.
    if len(data.variants) < 2:
        return
    baseline_name = data.variants[0].name
    baseline_extras = extras_by_variant[baseline_name]
    for arm in data.variants[1:]:
        other_extras = extras_by_variant[arm.name]
        for column, baseline_dtype in baseline_extras.items():
            other_dtype = other_extras[column]
            if baseline_dtype != other_dtype:
                raise SchemaViolation(
                    f"Feature column '{column}' dtype mismatch: "
                    f"'{baseline_name}' has {baseline_dtype}, "
                    f"'{arm.name}' has {other_dtype}"
                )
def _is_numeric_dtype(dtype: np.dtype) -> bool:
    """Return whether ``dtype`` is an integer (signed/unsigned) or float.

    The decision is made on ``dtype.kind``; bool, complex, object and all
    other kinds count as non-numeric.
    """
    numeric_kinds = {"i", "u", "f"}
    return dtype.kind in numeric_kinds
def _check_clause(
    clause: RuleClause,
    combined_dtypes: dict[str, np.dtype],
) -> None:
    """Validate one rule clause against the known visitor columns.

    The clause's ``feature`` must be a key of ``combined_dtypes``.
    ``ComparisonRule`` and ``BetweenRule`` clauses then require a numeric
    dtype, while ``EqRule`` and ``InRule`` clauses require a non-numeric
    one. Clauses of any other type only get the existence check.

    Args:
        clause: The rule clause to validate.
        combined_dtypes: Mapping of column name to dtype.

    Raises:
        RuleViolation: If the feature is unknown or its dtype does not
            suit the clause type.
    """
    feature_name = clause.feature
    if feature_name not in combined_dtypes:
        raise RuleViolation(
            f"Rule clause references feature '{feature_name}' which does not "
            f"exist in visitor data. Available: {sorted(combined_dtypes)}"
        )

    column_dtype = combined_dtypes[feature_name]
    clause_kind = type(clause).__name__
    wants_numeric = isinstance(clause, (ComparisonRule, BetweenRule))
    wants_categorical = not wants_numeric and isinstance(
        clause, (EqRule, InRule)
    )

    if wants_numeric and not _is_numeric_dtype(column_dtype):
        raise RuleViolation(
            f"Rule clause {clause_kind} on feature '{feature_name}' "
            f"requires a numeric column, but dtype is {column_dtype}"
        )

    if wants_categorical and _is_numeric_dtype(column_dtype):
        raise RuleViolation(
            f"Rule clause {clause_kind} on feature '{feature_name}' "
            f"requires a categorical column, but dtype is {column_dtype} "
            f"(numeric)"
        )