# Source code for pytyche.experiment.schedule

"""Schedule protocol + shipped variants for per-round visitor counts.

A schedule answers one question per round: how many visitors does round
``rounds_completed`` get? ``None`` signals stop. ``GeometricSchedule`` is
the doubling-batch default (Perchet 2016; Esfandiari 2021; Che & Namkoong
2023); ``FixedSchedule`` and ``ExplicitSchedule`` cover flat and
fully-custom plans. Custom implementations only need the two protocol
members — no shipped base class is required.
"""

from __future__ import annotations

import dataclasses
from collections.abc import Sequence
from typing import Protocol, runtime_checkable

__all__ = [
    "ExplicitSchedule",
    "FixedSchedule",
    "GeometricSchedule",
    "Schedule",
]


@runtime_checkable
class Schedule(Protocol):
    """Structural contract for per-round visitor counts.

    ``SequentialExperiment`` consumes objects of this shape. Any object that
    exposes the two members below satisfies the protocol; inheriting from
    this class is not required.
    """

    @property
    def n_rounds(self) -> int | None:
        """Total number of rounds in the plan, or ``None`` when open-ended."""

    def next_round_size(self, rounds_completed: int) -> int | None:
        """Return the visitor count for the round after ``rounds_completed``.

        Args:
            rounds_completed: Number of rounds already finished.

        Returns:
            The number of visitors for the next round, or ``None`` to signal
            that the schedule is exhausted and the loop should stop.
        """
@dataclasses.dataclass(frozen=True)
class GeometricSchedule:
    """Doubling-batch schedule: round ``i`` gets ``initial * growth**i`` visitors.

    ``n_rounds=None`` is open-ended — the operator stops the loop.
    Non-integer products are rounded with the built-in ``round``, so exact
    ties go to the even integer (``22.5`` becomes ``22``).

    Note:
        With ``growth < 1`` the product shrinks every round and can
        eventually round down to ``0``.

    Attributes:
        initial: Visitor count of round 0; must be positive.
        growth: Multiplier applied per completed round; must be positive.
        n_rounds: Number of rounds to run, or ``None`` for no upper bound.
            When given it must be positive.

    Raises:
        ValueError: If ``initial`` or ``growth`` is not positive, or if
            ``n_rounds`` is given and is not positive.
    """

    initial: int
    growth: float = 2.0
    n_rounds: int | None = None

    def __post_init__(self) -> None:
        if self.initial <= 0:
            raise ValueError(f"initial must be positive, got {self.initial}")
        if self.growth <= 0:
            raise ValueError(f"growth must be positive, got {self.growth}")
        # A bounded schedule with zero or negative rounds can never run a
        # round; reject it up front, matching FixedSchedule's n_rounds check.
        if self.n_rounds is not None and self.n_rounds <= 0:
            raise ValueError(f"n_rounds must be positive, got {self.n_rounds}")

    def next_round_size(self, rounds_completed: int) -> int | None:
        """Return the size of the next round, or ``None`` once bounded rounds run out.

        Args:
            rounds_completed: Number of rounds already finished; used as the
                exponent applied to ``growth``.

        Returns:
            ``initial * growth**rounds_completed`` rounded to an int, or
            ``None`` when ``n_rounds`` is set and has been reached.
        """
        if self.n_rounds is not None and rounds_completed >= self.n_rounds:
            return None
        return int(round(self.initial * self.growth**rounds_completed))

    def __repr__(self) -> str:
        """Render the growth rule plus either the round bound or ``open-ended``."""
        bound = f"n_rounds={self.n_rounds}" if self.n_rounds is not None else "open-ended"
        return f"GeometricSchedule({self.initial:,} × {self.growth}^round, {bound})"
@dataclasses.dataclass(frozen=True)
class FixedSchedule:
    """Bounded schedule that hands every round the same visitor count.

    Attributes:
        per_round: Visitors assigned to each round; must be positive.
        n_rounds: Number of rounds to run; must be positive.

    Raises:
        ValueError: If ``per_round`` or ``n_rounds`` is not positive.
    """

    per_round: int
    n_rounds: int

    def __post_init__(self) -> None:
        # Both fields share one rule; check them in declaration order so the
        # first offending field is the one reported.
        for field_name in ("per_round", "n_rounds"):
            value = getattr(self, field_name)
            if value <= 0:
                raise ValueError(f"{field_name} must be positive, got {value}")

    def next_round_size(self, rounds_completed: int) -> int | None:
        """Return ``per_round`` while rounds remain, otherwise ``None``.

        Args:
            rounds_completed: Number of rounds already finished.
        """
        has_rounds_left = rounds_completed < self.n_rounds
        return self.per_round if has_rounds_left else None

    def __repr__(self) -> str:
        """Render as ``FixedSchedule(<per_round>/round × <n_rounds>)``."""
        return "FixedSchedule({:,}/round × {})".format(self.per_round, self.n_rounds)
@dataclasses.dataclass(frozen=True)
class ExplicitSchedule:
    """User-supplied per-round visitor counts; one round per entry.

    The supplied sizes are copied into a tuple at construction time, and all
    validation runs against that tuple.

    Attributes:
        sizes: Visitor count for each round, in order. Stored as a tuple.

    Raises:
        ValueError: If ``sizes`` is empty or contains a non-positive entry.
    """

    sizes: Sequence[int]

    def __post_init__(self) -> None:
        # Snapshot into a tuple BEFORE validating. Checking the raw argument
        # lets a one-shot iterator through: it is truthy, gets drained by the
        # positivity check, and would then be stored as an empty tuple.
        raw = self.sizes
        # None keeps raising the "non-empty" ValueError rather than a TypeError.
        sizes = () if raw is None else tuple(raw)
        if not sizes:
            raise ValueError("sizes must be non-empty")
        if any(s <= 0 for s in sizes):
            raise ValueError(f"all sizes must be positive, got {list(sizes)}")
        # The dataclass is frozen, so bypass its __setattr__ guard.
        object.__setattr__(self, "sizes", sizes)

    @property
    def n_rounds(self) -> int:
        """Number of rounds in the plan (one per entry in ``sizes``)."""
        return len(self.sizes)

    def next_round_size(self, rounds_completed: int) -> int | None:
        """Return the size at index ``rounds_completed``, or ``None`` past the end.

        Args:
            rounds_completed: Number of rounds already finished; used as the
                index into ``sizes``.
        """
        if rounds_completed >= len(self.sizes):
            return None
        return self.sizes[rounds_completed]

    def __repr__(self) -> str:
        """Render the sizes as a list of thousands-separated strings."""
        return f"ExplicitSchedule({[f'{s:,}' for s in self.sizes]})"