Source code for pytyche.experiment.schedule
"""Schedule protocol + shipped variants for per-round visitor counts.
A schedule answers one question per round: given that ``rounds_completed``
rounds have already run, how many visitors does the next round get?
``None`` signals stop. ``GeometricSchedule`` is
the doubling-batch default (Perchet 2016; Esfandiari 2021; Che & Namkoong
2023); ``FixedSchedule`` and ``ExplicitSchedule`` cover flat and
fully-custom plans. Custom implementations only need the two protocol
members — no shipped base class is required.
"""
from __future__ import annotations
import dataclasses
from collections.abc import Sequence
from typing import Protocol, runtime_checkable
# Public API of this module, listed alphabetically.
__all__ = ["ExplicitSchedule", "FixedSchedule", "GeometricSchedule", "Schedule"]
[docs]
@runtime_checkable
class Schedule(Protocol):
"""Per-round visitor-count contract consumed by ``SequentialExperiment``."""
@property
def n_rounds(self) -> int | None: ...
def next_round_size(self, rounds_completed: int) -> int | None: ...
[docs]
@dataclasses.dataclass(frozen=True)
class GeometricSchedule:
"""Doubling-batch schedule: round ``i`` gets ``initial * growth**i`` visitors.
``n_rounds=None`` is open-ended — the operator stops the loop.
Non-integer products round to the nearest int.
"""
initial: int
growth: float = 2.0
n_rounds: int | None = None
def __post_init__(self) -> None:
if self.initial <= 0:
raise ValueError(f"initial must be positive, got {self.initial}")
if self.growth <= 0:
raise ValueError(f"growth must be positive, got {self.growth}")
def next_round_size(self, rounds_completed: int) -> int | None:
if self.n_rounds is not None and rounds_completed >= self.n_rounds:
return None
return int(round(self.initial * self.growth**rounds_completed))
def __repr__(self) -> str:
bound = f"n_rounds={self.n_rounds}" if self.n_rounds is not None else "open-ended"
return f"GeometricSchedule({self.initial:,} × {self.growth}^round, {bound})"
[docs]
@dataclasses.dataclass(frozen=True)
class FixedSchedule:
"""Flat per-round visitor count for a bounded number of rounds."""
per_round: int
n_rounds: int
def __post_init__(self) -> None:
if self.per_round <= 0:
raise ValueError(f"per_round must be positive, got {self.per_round}")
if self.n_rounds <= 0:
raise ValueError(f"n_rounds must be positive, got {self.n_rounds}")
def next_round_size(self, rounds_completed: int) -> int | None:
if rounds_completed >= self.n_rounds:
return None
return self.per_round
def __repr__(self) -> str:
return f"FixedSchedule({self.per_round:,}/round × {self.n_rounds})"
[docs]
@dataclasses.dataclass(frozen=True)
class ExplicitSchedule:
"""User-supplied per-round visitor counts; one round per entry."""
sizes: Sequence[int]
def __post_init__(self) -> None:
if not self.sizes:
raise ValueError("sizes must be non-empty")
if any(s <= 0 for s in self.sizes):
raise ValueError(f"all sizes must be positive, got {list(self.sizes)}")
object.__setattr__(self, "sizes", tuple(self.sizes))
@property
def n_rounds(self) -> int:
return len(self.sizes)
def next_round_size(self, rounds_completed: int) -> int | None:
if rounds_completed >= len(self.sizes):
return None
return self.sizes[rounds_completed]
def __repr__(self) -> str:
return f"ExplicitSchedule({[f'{s:,}' for s in self.sizes]})"