Source code for pytyche.validation

"""Runtime validators for v2 contracts.

These validators make the contract types executable — they enforce
invariants at construction time rather than discovering violations
downstream.

Used by both generators and loaders. Fail-closed: no silent acceptance
of malformed data.

Functions:

- ``validate_observed_data(data)`` — checks visitor schema, dtypes,
  invariants.
- ``validate_alignment(array, data)`` — confirms array length matches
  visitors.
- ``validate_rule(rule, data)`` — confirms rule features exist and are
  type-compatible.
"""

from __future__ import annotations

from typing import Any

import numpy as np
import pandas as pd

from pytyche.contracts import (
    VISITOR_SCHEMA,
    AlignedVisitorArray,
    BetweenRule,
    ComparisonRule,
    EqRule,
    InRule,
    ObservedExperimentData,
    RuleClause,
    SegmentRule,
    VariantData,
)


class SchemaViolation(Exception):
    """Raised when observed data violates the visitor schema contract."""
class AlignmentViolation(Exception):
    """Raised when a per-visitor array is misaligned with visitor rows."""
class RuleViolation(Exception):
    """Raised when a segment rule references invalid features or types."""
def validate_observed_data(
    data: ObservedExperimentData,
    *,
    strict: bool = True,
) -> None:
    """Validate that all variant DataFrames conform to ``VISITOR_SCHEMA``.

    Per-variant checks:

    1. All required columns are present.
    2. Column dtypes are compatible with the schema.
    3. ``revenue >= 0`` for all rows.
    4. No duplicate ``visitor_id`` within a variant.
    5. Summary fields stored on the variant agree with the DataFrame
       contents (see ``_check_summary_consistency`` for the exact fields
       that are compared).
    6. Every row's ``variant`` column matches ``VariantData.name``.
    7. Every row's ``experiment_id`` matches ``data.experiment_id``.

    Cross-variant checks:

    8. No ``visitor_id`` appears in more than one variant (a visitor can
       only be assigned to one arm).

    Strict-mode checks (``strict=True``, the default):

    9. All variants have the same set of extra feature columns (beyond
       ``VISITOR_SCHEMA``).
    10. Extra feature columns have consistent dtypes across variants.

    Set ``strict=False`` when feature-column asymmetry across variants is
    intentional. Example: a treatment arm collects an extra survey
    response column that doesn't exist for control::

        # Treatment adds a post-checkout "why did you buy?" column.
        # Control visitors never see the survey, so the column is absent.
        validate_observed_data(data, strict=False)

    Args:
        data: The observed experiment data to validate.
        strict: If ``True`` (default), require feature-column consistency
            across variants. If ``False``, skip cross-variant column
            checks.

    Raises:
        SchemaViolation: On any violation, with a message identifying the
            variant and the specific problem.
    """
    # Per-variant checks run in a fixed order: required columns must be
    # confirmed first because every later check indexes those columns.
    for variant in data.variants:
        df = variant.visitors
        _check_required_columns(variant.name, df)
        _check_dtypes(variant.name, df)
        _check_revenue_non_negative(variant.name, df)
        _check_unique_visitor_ids(variant.name, df)
        _check_summary_consistency(variant)
        _check_row_variant_name(variant.name, df)
        _check_row_experiment_id(variant.name, df, data.experiment_id)

    # Cross-variant checks only make sense once each variant is valid on
    # its own.
    _check_global_visitor_id_uniqueness(data)

    if strict:
        _check_feature_column_consistency(data)
def validate_alignment(
    array: AlignedVisitorArray,
    data: ObservedExperimentData,
) -> None:
    """Confirm that a per-visitor array is aligned with concatenated visitors.

    The expected length is the sum of ``n_visitors`` across all variants,
    which equals the row count of::

        pd.concat([v.visitors for v in data.variants], ignore_index=True)

    Args:
        array: The aligned array to check.
        data: The experiment data providing the visitor count.

    Raises:
        AlignmentViolation: If the array length doesn't match.
    """
    total_visitors = sum(v.n_visitors for v in data.variants)
    if array.n_visitors != total_visitors:
        raise AlignmentViolation(
            f"AlignedVisitorArray has n_visitors={array.n_visitors}, "
            f"but experiment has {total_visitors} total visitors "
            f"across {len(data.variants)} variants"
        )
def validate_rule(
    rule: SegmentRule,
    data: ObservedExperimentData,
) -> None:
    """Confirm that a segment rule's features exist in visitor data and
    have compatible types.

    Checks:

    1. Each clause's ``feature`` column exists in at least one variant's
       DataFrame.
    2. Numeric rules (``ComparisonRule``, ``BetweenRule``) reference
       numeric columns.
    3. Categorical rules (``EqRule``, ``InRule``) reference non-numeric
       columns.

    Does NOT enforce allowed categorical values — no domain registry in
    Phase 1 scope.

    Args:
        rule: The segment rule to validate.
        data: The experiment data providing the column schema.

    Raises:
        RuleViolation: If a feature is missing or type-incompatible.
    """
    # Build a combined column→dtype map from all variants. When a column
    # appears in several variants, the dtype from the first variant that
    # has it wins.
    combined_dtypes: dict[str, Any] = {}
    for variant in data.variants:
        for col in variant.visitors.columns:
            if col not in combined_dtypes:
                combined_dtypes[col] = variant.visitors[col].dtype

    for clause in rule.clauses:
        _check_clause(clause, combined_dtypes)
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _check_required_columns(variant_name: str, df: pd.DataFrame) -> None:
    """Check that all VISITOR_SCHEMA columns are present.

    Raises:
        SchemaViolation: Listing the missing columns in sorted order.
    """
    missing = set(VISITOR_SCHEMA) - set(df.columns)
    if missing:
        raise SchemaViolation(
            f"Variant '{variant_name}': missing required columns: "
            f"{sorted(missing)}"
        )


def _check_dtypes(variant_name: str, df: pd.DataFrame) -> None:
    """Check that required columns have compatible dtypes.

    Compatibility is judged by dtype *kind*, not exact width, so e.g.
    int32 satisfies an int64 schema entry. Schema entries whose kind is
    not float, integer, bool or object are not checked here.

    Raises:
        SchemaViolation: On the first column whose kind is incompatible.
    """
    for col, expected_dtype_str in VISITOR_SCHEMA.items():
        actual = df[col].dtype
        expected = np.dtype(expected_dtype_str)

        # Allow compatible types: int32 for int64, float32 for float64, etc.
        if expected.kind == "f":
            # Float columns: accept any float kind.
            if actual.kind != "f":
                raise SchemaViolation(
                    f"Variant '{variant_name}': column '{col}' should be "
                    f"float, got {actual}"
                )
        elif expected.kind == "i":
            # Integer columns: accept any integer kind, signed or unsigned.
            if actual.kind not in ("i", "u"):
                raise SchemaViolation(
                    f"Variant '{variant_name}': column '{col}' should be "
                    f"integer, got {actual}"
                )
        elif expected.kind == "b":
            # Boolean columns.
            if actual.kind != "b":
                raise SchemaViolation(
                    f"Variant '{variant_name}': column '{col}' should be "
                    f"bool, got {actual}"
                )
        elif expected.kind == "O":
            # Object columns: accept object or string kind.
            if actual.kind not in ("O", "U", "S"):
                raise SchemaViolation(
                    f"Variant '{variant_name}': column '{col}' should be "
                    f"string/object, got {actual}"
                )


def _check_revenue_non_negative(variant_name: str, df: pd.DataFrame) -> None:
    """Check that revenue >= 0 for all rows.

    NaN revenue is rejected as well: ``NaN < 0`` is False, so a plain
    negativity test would silently accept rows that do not satisfy
    ``revenue >= 0``.

    Raises:
        SchemaViolation: If any row has negative or missing revenue.
    """
    revenue = df["revenue"]

    negative = revenue < 0
    if negative.any():
        n_neg = int(negative.sum())
        raise SchemaViolation(
            f"Variant '{variant_name}': {n_neg} rows have negative revenue"
        )

    missing = revenue.isna()
    if missing.any():
        n_nan = int(missing.sum())
        raise SchemaViolation(
            f"Variant '{variant_name}': {n_nan} rows have missing (NaN) revenue"
        )


def _check_unique_visitor_ids(variant_name: str, df: pd.DataFrame) -> None:
    """Check no duplicate visitor_id within a variant.

    Raises:
        SchemaViolation: Reporting how many rows repeat an earlier id.
    """
    duplicated = df["visitor_id"].duplicated()
    if duplicated.any():
        n_dup = int(duplicated.sum())
        raise SchemaViolation(
            f"Variant '{variant_name}': {n_dup} duplicate visitor_id values"
        )


def _check_summary_consistency(variant: VariantData) -> None:
    """Check that summary fields match DataFrame contents.

    Compares ``n_visitors`` with the row count, ``n_conversions`` with the
    sum of the ``converted`` column, and ``total_revenue`` with the sum of
    the ``revenue`` column (absolute tolerance 1e-6).

    Raises:
        SchemaViolation: On the first summary field that disagrees.
    """
    name = variant.name
    df = variant.visitors
    n_visitors = variant.n_visitors
    n_conversions = variant.n_conversions
    total_revenue = variant.total_revenue

    actual_visitors = len(df)
    if actual_visitors != n_visitors:
        raise SchemaViolation(
            f"Variant '{name}': n_visitors={n_visitors} but "
            f"DataFrame has {actual_visitors} rows"
        )

    actual_conversions = int(df["converted"].sum().item())
    if actual_conversions != n_conversions:
        raise SchemaViolation(
            f"Variant '{name}': n_conversions={n_conversions} but "
            f"DataFrame has {actual_conversions} converted rows"
        )

    actual_revenue = float(df["revenue"].sum().item())
    # Float comparison with tolerance for summation rounding. Written as
    # ``not (diff <= tol)`` so that a NaN total (for which every comparison
    # is False) is rejected instead of silently accepted.
    if not abs(actual_revenue - total_revenue) <= 1e-6:
        raise SchemaViolation(
            f"Variant '{name}': total_revenue={total_revenue} but "
            f"DataFrame revenue sums to {actual_revenue}"
        )


def _check_row_variant_name(variant_name: str, df: pd.DataFrame) -> None:
    """Check that every row's variant column matches VariantData.name.

    Raises:
        SchemaViolation: Reporting the count and distinct offending values.
    """
    mismatched = df["variant"] != variant_name
    if mismatched.any():
        n_bad = int(mismatched.sum())
        # key=str keeps the sort from raising TypeError when the offending
        # values are of mixed types (e.g. None alongside strings).
        bad_values = sorted(
            df.loc[mismatched, "variant"].unique().tolist(), key=str
        )
        raise SchemaViolation(
            f"Variant '{variant_name}': {n_bad} rows have mismatched variant "
            f"column values: {bad_values}"
        )


def _check_row_experiment_id(
    variant_name: str, df: pd.DataFrame, expected_id: str
) -> None:
    """Check that every row's experiment_id matches the experiment.

    Raises:
        SchemaViolation: Reporting the count and distinct offending values.
    """
    mismatched = df["experiment_id"] != expected_id
    if mismatched.any():
        n_bad = int(mismatched.sum())
        # key=str keeps the sort from raising TypeError when the offending
        # values are of mixed types (e.g. None alongside strings).
        bad_values = sorted(
            df.loc[mismatched, "experiment_id"].unique().tolist(), key=str
        )
        raise SchemaViolation(
            f"Variant '{variant_name}': {n_bad} rows have experiment_id "
            f"values {bad_values} instead of '{expected_id}'"
        )


def _check_global_visitor_id_uniqueness(data: ObservedExperimentData) -> None:
    """Check no visitor_id appears in more than one variant.

    Raises:
        SchemaViolation: Naming the first repeated id and the two variants
            it was found in.
    """
    seen: dict[str, str] = {}  # visitor_id -> variant_name
    for variant in data.variants:
        ids = variant.visitors["visitor_id"]
        for vid in ids:
            if vid in seen:
                raise SchemaViolation(
                    f"visitor_id '{vid}' appears in both variant "
                    f"'{seen[vid]}' and variant '{variant.name}'"
                )
            seen[vid] = variant.name


def _check_feature_column_consistency(data: ObservedExperimentData) -> None:
    """Check that all variants have the same extra columns with same dtypes.

    Only runs in strict mode. Extra columns are those beyond VISITOR_SCHEMA.

    Raises:
        SchemaViolation: If the sets of extra columns differ, or if a shared
            extra column has a different dtype than in the first variant.
    """
    schema_cols = set(VISITOR_SCHEMA)

    variant_extras: dict[str, dict[str, Any]] = {}
    for variant in data.variants:
        extras = {
            col: variant.visitors[col].dtype
            for col in variant.visitors.columns
            if col not in schema_cols
        }
        variant_extras[variant.name] = extras

    # Check column sets are identical.
    all_col_sets = {
        name: frozenset(extras.keys())
        for name, extras in variant_extras.items()
    }
    unique_sets = set(all_col_sets.values())
    if len(unique_sets) > 1:
        # Report diff between first variant and each differing one.
        ref_name = data.variants[0].name
        ref_cols = all_col_sets[ref_name]
        problems = []
        for name, cols in all_col_sets.items():
            if cols != ref_cols:
                only_ref = sorted(ref_cols - cols)
                only_other = sorted(cols - ref_cols)
                parts = []
                if only_ref:
                    parts.append(f"missing {only_ref}")
                if only_other:
                    parts.append(f"extra {only_other}")
                problems.append(f"'{name}': {', '.join(parts)}")
        raise SchemaViolation(
            f"Feature columns differ across variants "
            f"(relative to '{ref_name}'): {'; '.join(problems)}"
        )

    # Check dtypes are consistent across variants.
    if len(data.variants) < 2:
        return

    ref_name = data.variants[0].name
    ref_extras = variant_extras[ref_name]
    for variant in data.variants[1:]:
        their_extras = variant_extras[variant.name]
        for col, ref_dtype in ref_extras.items():
            their_dtype = their_extras[col]
            if ref_dtype != their_dtype:
                raise SchemaViolation(
                    f"Feature column '{col}' dtype mismatch: "
                    f"'{ref_name}' has {ref_dtype}, "
                    f"'{variant.name}' has {their_dtype}"
                )


def _is_numeric_dtype(dtype: np.dtype) -> bool:
    """Check if a numpy dtype is numeric (int or float).

    Signed and unsigned integers and floats count as numeric; every other
    kind (including bool) does not.
    """
    return dtype.kind in ("i", "u", "f")


def _check_clause(
    clause: RuleClause,
    combined_dtypes: dict[str, np.dtype],
) -> None:
    """Validate a single rule clause against available columns.

    Args:
        clause: The clause to check; its ``feature`` names the column.
        combined_dtypes: Column name → dtype map of the visitor data.

    Raises:
        RuleViolation: If the feature is unknown, if a numeric rule targets
            a non-numeric column, or if a categorical rule targets a
            numeric column.
    """
    feature = clause.feature

    if feature not in combined_dtypes:
        raise RuleViolation(
            f"Rule clause references feature '{feature}' which does not "
            f"exist in visitor data. Available: {sorted(combined_dtypes)}"
        )

    dtype = combined_dtypes[feature]

    if isinstance(clause, (ComparisonRule, BetweenRule)):
        # Numeric rules require numeric columns.
        if not _is_numeric_dtype(dtype):
            raise RuleViolation(
                f"Rule clause {type(clause).__name__} on feature '{feature}' "
                f"requires a numeric column, but dtype is {dtype}"
            )
    elif isinstance(clause, (EqRule, InRule)):
        # Categorical rules require non-numeric columns.
        if _is_numeric_dtype(dtype):
            raise RuleViolation(
                f"Rule clause {type(clause).__name__} on feature '{feature}' "
                f"requires a categorical column, but dtype is {dtype} "
                f"(numeric)"
            )