"""Configuration dataclass, result types, and small utilities for the GPU BCF.
This module holds the user-facing configuration object (``GPUBCFConfig``), the
three result dataclasses returned by the ``fit_*`` entry points, the
formula-driven ``compute_num_trees_tau`` helper, and the leaf-index dtype
selector used to size heap-layout tree arrays. Pure types and utilities — no
JIT-compiled code, no GPU device handles, no module-level state. Importing
this module is cheap and triggers no GPU work.
Import graph
------------
``gpu_bcf_config`` depends on JAX (for the leaf-index dtype helper),
numpy (for result-array typing), and scipy.stats (for the inverse-normal
quantile in ``compute_num_trees_tau``). It does not import from any sibling
``gpu_bcf_*`` module. The orchestrator and downstream modules import FROM
here, never the other way around.
Contents
--------
``_leaf_index_dtype`` — smallest unsigned int dtype for heap node indices at a given tree depth.
``GPUBCFConfig`` — frozen dataclass of MCMC and prior hyperparameters for the GPU BCF.
``compute_num_trees_tau`` — formula for the minimum tau-forest tree count at a target CI coverage.
``ContinuousBCFResult`` — result container for ``fit_continuous_bcf``.
``BinaryBCFResult`` — result container for ``fit_binary_bcf``.
``HurdleBCFResult`` — result container for ``fit_hurdle_bcf``.
"""
from __future__ import annotations
import dataclasses
import math
from typing import TYPE_CHECKING, Any, Literal
import jax.numpy as jnp
import numpy as np
from scipy.stats import norm as sp_norm
if TYPE_CHECKING:
from collections.abc import Sequence
from pytyche.analysis._policy_tree import PolicyTreeResult
from pytyche.analysis._truth import TruthComparison
from pytyche.bcf.diagnostics.topology import TopologyHistory
from pytyche.calibrate.artifact import Calibration
from pytyche.contracts import (
AnalysisResult,
CalibrationTruth,
DecisionThresholds,
DiscoveredSegment,
ObservedExperimentData,
RecommendationSummary,
)
# ---------------------------------------------------------------------------
# Leaf index dtype selection (VRAM optimization)
# ---------------------------------------------------------------------------
def _leaf_index_dtype(max_depth: int) -> jnp.dtype:
    """Smallest unsigned int dtype for heap node indices at given tree depth.

    bartz uses 1-indexed heap layout (root=1, left=2i, right=2i+1).
    Max node ID at depth d is 2^d - 1 (e.g., 63 at depth 6), so the
    narrowest unsigned type that can hold every node ID is returned to
    keep the tree index arrays small in VRAM.

    Args:
        max_depth: Maximum tree depth ``d``; node IDs span ``[1, 2**d - 1]``.

    Returns:
        ``jnp.uint8`` for ``d <= 8``, ``jnp.uint16`` for ``d <= 16`` and
        ``jnp.uint32`` for ``d <= 32``.

    Raises:
        ValueError: If ``max_depth`` is negative, or larger than 32 (node IDs
            would no longer fit in a 32-bit unsigned integer).
    """
    if max_depth < 0:
        raise ValueError(f"max_depth must be non-negative, got {max_depth}")
    max_node = 2 ** max_depth - 1
    if max_node <= 255:  # max_depth <= 8
        return jnp.uint8
    if max_node <= 65535:  # max_depth <= 16
        return jnp.uint16
    if max_node <= 4294967295:  # max_depth <= 32
        # Unsigned, like the narrower branches (previously a signed int32,
        # contradicting the docstring and halving the representable range).
        return jnp.uint32
    raise ValueError(
        f"max_depth={max_depth} needs node IDs up to 2**{max_depth} - 1, "
        "which does not fit in a 32-bit unsigned integer"
    )
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
[docs]
@dataclasses.dataclass(frozen=True)
class GPUBCFConfig:
    """Sampling configuration for GPU BCF via bartz.

    Frozen (immutable) dataclass. No validation is performed here — any
    constraints noted below are expected to be checked by the consumer.

    Attributes
    ----------
    num_burnin:
        Number of MCMC burn-in iterations (discarded).
    num_mcmc:
        Number of MCMC samples to retain for posterior inference.
    num_trees_mu:
        Number of trees in the prognostic (mu) forest.
    num_trees_tau:
        Number of trees in the treatment effect (tau) forest.
    max_depth:
        Maximum tree depth (controls p_nonterminal array length).
    alpha_mu / beta_mu:
        Tree prior hyperparameters for mu forest.
    alpha_tau / beta_tau:
        Tree prior hyperparameters for tau forest (tighter = more regularized).
    num_cuts:
        Number of quantile-based split cutpoints per covariate.
    random_seed:
        Seed for JAX PRNG.
    num_chains:
        Number of parallel MCMC chains (vmapped). 1 = single-chain (legacy).
    diagnostic_interval:
        Iterations per chunk for between-chunk diagnostics. Must divide both
        num_burnin and num_mcmc evenly (not checked by this class).
    thin_factor:
        Keep every thin_factor-th sample during MCMC (1 = keep all).
    num_gfr_sweeps:
        Number of grow-from-root (GFR) warm-start sweeps.
    min_samples_leaf:
        Minimum number of observations per leaf. NOTE(review): where this is
        enforced (GFR warm start and/or MCMC proposals) is not visible in
        this module — confirm against the sampler.
    gfr_backend:
        Backend for the GFR warm start: ``"gpu"`` = JAX, ``"cpu"`` = StochTree.
    trace_path:
        If set, a raw channel ``.npz`` is written per chunk; ``None`` disables.
    var_tau_sev:
        Heteroscedastic severity: signal-to-noise used for gamma prior
        moment-matching (Linero default).
    kappa_sev:
        Heteroscedastic severity: Normal-Gamma prior precision scaling for beta.
    tau0_a_prior / tau0_b_prior:
        Shape / rate of the prior on the global precision ``tau_0``.
    freeze_gamma:
        Debug switch: hold gamma at its initial value and skip recentering.
    retain_channel_samples:
        Retain per-draw channel arrays on CPU. The conversion/severity
        decomposition needs them; set False to skip the transfer at sweep
        scale.
    focal_severity:
        Upweight the severity LML by ``n / n_conv`` so each converter
        contributes equally to split decisions.
    per_leaf_gamma:
        True: severity channel uses per-leaf Normal-Gamma sampling with a
        multiplicative tau_hat update. False (default): parallel-LML
        Normal-Normal path mirroring bartz, tau_hat fixed at 1.0 and the
        global tau_0 absorbing all severity precision (~3-5x faster per
        iteration on GPU).
    retain_topology_history:
        If True, retain per-iter per-tree topology hashes and move metadata on
        ``HurdleBCFResult.topology_history`` for mobility diagnostics. Default
        False (no retention, byte-identical to pre-feature behaviour).
    """
    num_burnin: int = 200
    num_mcmc: int = 200
    num_trees_mu: int = 200
    num_trees_tau: int = 50
    max_depth: int = 6
    alpha_mu: float = 0.95
    beta_mu: float = 2.0
    alpha_tau: float = 0.75
    beta_tau: float = 3.0
    num_cuts: int = 100
    random_seed: int = 42
    num_chains: int = 1
    diagnostic_interval: int = 50
    thin_factor: int = 1
    num_gfr_sweeps: int = 5
    min_samples_leaf: int = 5
    gfr_backend: str = "gpu" # grow-from-root (GFR) warm start: "gpu" = JAX, "cpu" = StochTree
    trace_path: str | None = None # If set, write raw channel .npz per chunk
    # Heteroscedastic severity (per-leaf precision) hyperparameters
    var_tau_sev: float = 0.5 # Signal-to-noise for gamma prior moment-matching (Linero default)
    kappa_sev: float = 1.0 # Normal-Gamma prior precision scaling for beta
    tau0_a_prior: float = 1.0 # Global precision shape prior
    tau0_b_prior: float = 1.0 # Global precision rate prior
    freeze_gamma: bool = False # Debug: hold gamma at init, skip recentering
    retain_channel_samples: bool = True # Retain per-draw channel arrays on CPU (the conversion/severity decomposition needs them; set False to skip the transfer at sweep scale)
    focal_severity: bool = False # Upweight severity LML by n/n_conv so each converter contributes equally to splits
    # When True, severity channel uses per-leaf Normal-Gamma sampling with a
    # multiplicative tau_hat update. When False (default), severity uses the
    # parallel-LML Normal-Normal path mirroring bartz; tau_hat stays at 1.0
    # and the global tau_0 absorbs all severity precision. The False path is
    # ~3-5x faster per iteration on GPU; the True path is preserved for
    # backward compatibility and rarely needed in practice.
    per_leaf_gamma: bool = False
    # Retain per-iter per-tree topology hashes + move metadata across the full
    # burn-in + MCMC trace. When True, ``HurdleBCFResult.topology_history`` is
    # populated with a ``TopologyHistory`` instance for downstream mobility
    # diagnostics; when False (default), it stays ``None`` and the fit's
    # wall-clock + PRNG behaviour is bitwise-identical to HEAD pre-this-change.
    # Memory cost: ~14 bytes per (chain, iter, tree) entry — ~4 MB at the
    # production default of 4 chains x 300 iters x 250 trees, ~140 MB at
    # 10000 iters with the same shape. Opt-in for predictability.
    retain_topology_history: bool = False
[docs]
def compute_num_trees_tau(
    n: int,
    d_tau: float = 3.0,
    sigma_tau: float = 0.5,
    coverage: float = 0.90,
    floor: int = 50,
    ceiling: int = 400,
) -> int:
    """Formula-driven tau tree count for target CI coverage.

    T_min = ceil((d_tau * sigma_tau * sqrt(n) / (2 * z))^(2/3))

    where ``z = Phi^{-1}(1 - (1 - coverage) / 2)`` is the two-sided standard
    normal quantile for the requested coverage (1.645 at 90%).

    The tau forest's piecewise-constant approximation has O(1) bias that
    dominates the posterior at large n (O(1/sqrt(n)) concentration). This
    formula computes the minimum T to keep bias below the CI half-width
    at the target coverage level. The result is clamped to
    ``[floor, ceiling]``; if ``floor > ceiling`` the ceiling wins.

    Parameters
    ----------
    n : cumulative sample size (non-negative)
    d_tau : effective CATE dimensionality (treatment effect modifiers), >= 0
    sigma_tau : CATE heterogeneity on standardized/probit scale, >= 0
    coverage : target CI coverage level in (0, 1] (0.90 = 90%); 1.0 makes
        the half-width unbounded, so the floor is returned
    floor : minimum tree count (BCF needs some for variable selection)
    ceiling : maximum tree count (VRAM bound)

    Returns
    -------
    int
        Tree count for the tau forest.

    Raises
    ------
    ValueError
        If ``n`` is negative, ``d_tau`` or ``sigma_tau`` is negative, or
        ``coverage`` is outside ``(0, 1]``. (Previously these surfaced as a
        math-domain error, an ``OverflowError`` from ``ceil(inf)`` at
        ``coverage=0``, or a NaN-to-int conversion error.)
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if d_tau < 0 or sigma_tau < 0:
        raise ValueError(
            f"d_tau and sigma_tau must be non-negative, got d_tau={d_tau}, "
            f"sigma_tau={sigma_tau}"
        )
    # Written as a negated chained comparison so NaN coverage is rejected too.
    if not 0.0 < coverage <= 1.0:
        raise ValueError(f"coverage must be in (0, 1], got {coverage}")
    z = sp_norm.ppf(1 - (1 - coverage) / 2)
    t_min = (d_tau * sigma_tau * math.sqrt(n) / (2 * z)) ** (2 / 3)
    return int(min(max(math.ceil(t_min), floor), ceiling))
# ---------------------------------------------------------------------------
# Result containers
# ---------------------------------------------------------------------------
[docs]
@dataclasses.dataclass(frozen=True)
class ContinuousBCFResult:
"""Result from continuous BCF.
Attributes
----------
mu_samples:
``(n, num_mcmc)`` prognostic predictions (standardized).
tau_samples:
``(n, num_mcmc)`` treatment effects (standardized).
sigma2_samples:
``(num_mcmc,)`` error variance (standardized).
y_bar:
Mean of the outcome used for standardization.
y_std:
Standard deviation of the outcome used for standardization.
wall_clock_seconds:
Wall-clock time for the fit in seconds.
observed:
The ``ObservedExperimentData`` the fit consumed, attached to the
result so the analysis methods can reach the visitor rows and
variant metadata. ``None`` when constructed by private raw-array
helpers; populated by the public fit wrappers.
is_calibrated:
``True`` only after ``apply_calibration`` has been called on this
result. Defaults to ``False``.
calibration:
The ``Calibration`` artifact attached by ``apply_calibration``;
``None`` on fresh fits. The v0.2 artifact scope is interval
corrections only — it is consumed where interval summaries are
built, never to transform sample arrays.
"""
mu_samples: np.ndarray # (n, num_mcmc) prognostic predictions (standardized)
tau_samples: np.ndarray # (n, num_mcmc) treatment effects (standardized)
sigma2_samples: np.ndarray # (num_mcmc,) error variance (standardized)
y_bar: float
y_std: float
wall_clock_seconds: float
observed: ObservedExperimentData | None = dataclasses.field(default=None, kw_only=True)
is_calibrated: bool = dataclasses.field(default=False, kw_only=True)
calibration: Calibration | None = dataclasses.field(default=None, kw_only=True, repr=False)
[docs]
def thompson_allocation(
self,
segments: Sequence[DiscoveredSegment],
epsilon: float = 0.02,
) -> dict[int, dict[str, float]]:
"""Per-segment traffic split: each arm's weight is the posterior
probability that it is the segment's best arm.
Thompson sampling at segment granularity: per segment, each
posterior draw votes for its best arm (the largest member-mean
contrast, or control when none is positive); an arm's weight is
its win frequency over draws.
Args:
segments: Segments to allocate over (only ``id`` and ``rule``
are consumed); membership is resolved against
``self.observed``.
epsilon: Safety-net exploration floor — arms below
``epsilon / K`` are raised to the floor and the rest
rescaled, so no arm's traffic is starved to zero; inert
when every arm is already above it. NOT the
dial for how much traffic stays on control — that is
``min_control_weight`` / ``min_explore_weight`` on
``pt.sequential_experiment``; rarely worth overriding.
Returns:
``{segment.id: {variant_name: weight}}`` — inner dicts in
variant order (control first), each summing to 1.
Raises:
ValueError: When ``self.observed`` is ``None``.
"""
from pytyche.analysis._thompson import thompson_allocation as _impl
return _impl(self, segments=segments, epsilon=epsilon)
[docs]
def fit_policy_tree(
self,
*,
max_depth: int = 3,
min_segment_share: float = 0.10,
n_bootstrap: int = 50,
bootstrap_seed: int = 0,
) -> PolicyTreeResult:
"""Discover interpretable segments from the posterior's
per-visitor treatment effects, by fitting a shallow decision tree.
Each visitor is labeled with the arm the posterior expects to be
best for them (largest posterior-mean lift, or control when no
lift is positive); a multiclass decision tree is fit on the
visitors' features, and each leaf becomes a ``DiscoveredSegment``
carrying an exact membership rule, gate estimate/CI, per-arm best
probabilities, Thompson allocation, and bootstrap-replicability
stability.
Args:
max_depth: Maximum tree depth.
min_segment_share: Minimum fraction of visitors per leaf
(sklearn ``min_weight_fraction_leaf``).
n_bootstrap: Bootstrap tree refits behind ``stability_score``;
``0`` skips stability (NaN sentinel plus ``UserWarning``).
bootstrap_seed: Seed for the bootstrap resampling RNG.
Returns:
``PolicyTreeResult`` with one segment per leaf, ordered by
sklearn leaf id; ``result.observed`` is ``self.observed`` by
identity.
Raises:
ValueError: When ``self.observed`` is ``None``.
"""
from pytyche.analysis._policy_tree import fit_policy_tree as _impl
return _impl(
self,
max_depth=max_depth,
min_segment_share=min_segment_share,
n_bootstrap=n_bootstrap,
bootstrap_seed=bootstrap_seed,
)
[docs]
def apply_calibration(self, calibration: Calibration) -> ContinuousBCFResult:
"""Return a new posterior with *calibration* attached.
Attach, don't transform: the artifact is stashed on the returned
copy (``is_calibrated=True``); every sample array is shared with
this posterior by identity. The correction currently applies to
intervals only — probabilities and expected losses stay raw;
corrected CIs appear where interval summaries are built. K = 2
experiments only (per-contrast recalibration for K >= 3 is not
yet implemented).
Args:
calibration: SBC-fitted ``Calibration`` whose regime (metric,
n_treatments) must match ``self.observed``.
Returns:
New ``ContinuousBCFResult`` carrying the artifact; the
original is untouched.
Raises:
ValueError: When ``self.observed`` is ``None``, or on a
regime mismatch (message names the mismatched dimensions).
NotImplementedError: At K >= 3.
"""
from pytyche.analysis._calibrate import apply_calibration as _impl
return _impl(self, calibration)
[docs]
def recommendation_summary(
self,
treatment: str,
segment: DiscoveredSegment | None = None,
*,
thresholds: DecisionThresholds | None = None,
min_practical_effect: float = 0.02,
) -> RecommendationSummary:
"""Act-now SHIP / CONTINUE / STOP recommendation for one treatment.
The treatment's metric-native contrast draws are scoped
(``segment=None`` is the global all-visitors snapshot; a segment
restricts to its rule's members), reduced to per-draw mean lift,
and summarized under the legacy ``compare.variants`` decision
rule. v0.2 raw scope: probabilities and expected losses come from
the raw draws even on a calibrated posterior — interval
corrections land where intervals are built.
Args:
treatment: Treatment variant name (vs control).
segment: ``None`` for the global snapshot; a
``DiscoveredSegment`` restricts the computation to its
members.
thresholds: Decision thresholds; ``DecisionThresholds()``
defaults when ``None``.
min_practical_effect: Minimum meaningful lift for
``probability_better`` / ``probability_harmful``.
Returns:
``RecommendationSummary`` with the decision, its evidence,
and ``expected_value_of_one_more_round`` always populated
(closed-form preposterior EVSI; formula in
``docs/concepts/decision-theoretic-inputs.md``).
Raises:
ValueError: When ``self.observed`` is ``None``, when
*treatment* is not a treatment name, or when the
segment's rule matches zero visitors.
"""
from pytyche.analysis._recommendation import (
recommendation_summary as _impl,
)
return _impl(
self,
treatment,
segment=segment,
thresholds=thresholds,
min_practical_effect=min_practical_effect,
)
[docs]
def analyze(
self,
*,
max_depth: int = 3,
min_segment_share: float = 0.10,
n_bootstrap: int = 50,
bootstrap_seed: int = 0,
) -> AnalysisResult:
"""The canonical one-call analysis summary for this posterior.
Composes per-treatment ``Comparison`` summaries, the embedded
policy-tree segmentation (keyword arguments forward to it), the
global ``RecommendationSummary`` for the best challenger, and the
posterior-mean per-visitor CATEs. Anything needing posterior
samples goes through ``analysis.posterior``.
Args:
max_depth: Embedded policy tree depth.
min_segment_share: Minimum per-leaf population share.
n_bootstrap: Stability bootstrap count (``0`` skips stability
with a ``UserWarning``).
bootstrap_seed: Stability bootstrap seed.
Returns:
``AnalysisResult``; ``analysis.is_calibrated`` reads through
to this posterior's flag.
Raises:
ValueError: When ``self.observed`` is ``None``.
"""
from pytyche.analysis._analyze import analyze as _impl
return _impl(
self,
max_depth=max_depth,
min_segment_share=min_segment_share,
n_bootstrap=n_bootstrap,
bootstrap_seed=bootstrap_seed,
)
[docs]
def evaluate_against_truth(
self,
tree: PolicyTreeResult,
truth: CalibrationTruth | None,
) -> TruthComparison:
"""Sim-mode evaluation of *tree*'s policy against ground truth.
Args:
tree: The fitted policy whose assignments are evaluated.
truth: Ground truth from the simulation path; ``None`` in
real-data mode (raises — nothing to evaluate against).
Returns:
``TruthComparison`` (cate_rmse, policy_accuracy, and the
realized-RPV trio with the oracle gap).
Raises:
RuntimeError: When *truth* is ``None`` (real-data mode).
ValueError: When ``self.observed`` is ``None`` or the truth
lacks the K-appropriate contrast / potential-outcome
fields.
"""
from pytyche.analysis._truth import evaluate_against_truth as _impl
return _impl(self, tree=tree, truth=truth)
[docs]
def has_credible_segments(self, threshold: float = 0.80) -> bool:
"""Whether some discovered segment clears *threshold* stability.
Runs ``fit_policy_tree`` at its defaults (deterministic given the
default ``bootstrap_seed``) and checks for a segment with
``stability_score >= threshold``. The 0.80 default matches the
default graduation rule's SHIP-gate stability threshold.
Args:
threshold: Minimum bootstrap-replicability stability score.
Returns:
``True`` iff at least one discovered segment clears it.
"""
tree = self.fit_policy_tree()
return any(
score >= threshold for score in tree.stability_scores.values()
)
[docs]
def has_decomposition(self) -> bool:
"""Whether this posterior carries the conversion/severity split.
Returns:
``False`` — only hurdle posteriors carry the
conversion/severity decomposition.
"""
return False
[docs]
@dataclasses.dataclass(frozen=True)
class BinaryBCFResult:
"""Result from binary (probit) BCF.
Attributes
----------
mu_samples:
``(n, num_mcmc)`` prognostic predictions (probit scale).
tau_samples:
``(n, num_mcmc)`` treatment effects (probit scale).
wall_clock_seconds:
Wall-clock time for the fit in seconds.
observed:
The ``ObservedExperimentData`` the fit consumed, attached to the
result so the analysis methods can reach the visitor rows and
variant metadata. ``None`` when constructed by private raw-array
helpers; populated by the public fit wrappers.
is_calibrated:
``True`` only after ``apply_calibration`` has been called on this
result. Defaults to ``False``.
calibration:
The ``Calibration`` artifact attached by ``apply_calibration``;
``None`` on fresh fits. The v0.2 artifact scope is interval
corrections only — it is consumed where interval summaries are
built, never to transform sample arrays.
"""
mu_samples: np.ndarray # (n, num_mcmc) prognostic predictions (probit scale)
tau_samples: np.ndarray # (n, num_mcmc) treatment effects (probit scale)
wall_clock_seconds: float
observed: ObservedExperimentData | None = dataclasses.field(default=None, kw_only=True)
is_calibrated: bool = dataclasses.field(default=False, kw_only=True)
calibration: Calibration | None = dataclasses.field(default=None, kw_only=True, repr=False)
[docs]
def thompson_allocation(
self,
segments: Sequence[DiscoveredSegment],
epsilon: float = 0.02,
) -> dict[int, dict[str, float]]:
"""Per-segment traffic split: each arm's weight is the posterior
probability that it is the segment's best arm.
Thompson sampling at segment granularity: per segment, each
posterior draw votes for its best arm (the largest member-mean
contrast, or control when none is positive); an arm's weight is
its win frequency over draws.
Args:
segments: Segments to allocate over (only ``id`` and ``rule``
are consumed); membership is resolved against
``self.observed``.
epsilon: Safety-net exploration floor — arms below
``epsilon / K`` are raised to the floor and the rest
rescaled, so no arm's traffic is starved to zero; inert
when every arm is already above it. NOT the
dial for how much traffic stays on control — that is
``min_control_weight`` / ``min_explore_weight`` on
``pt.sequential_experiment``; rarely worth overriding.
Returns:
``{segment.id: {variant_name: weight}}`` — inner dicts in
variant order (control first), each summing to 1.
Raises:
ValueError: When ``self.observed`` is ``None``.
"""
from pytyche.analysis._thompson import thompson_allocation as _impl
return _impl(self, segments=segments, epsilon=epsilon)
[docs]
def fit_policy_tree(
self,
*,
max_depth: int = 3,
min_segment_share: float = 0.10,
n_bootstrap: int = 50,
bootstrap_seed: int = 0,
) -> PolicyTreeResult:
"""Discover interpretable segments from the posterior's
per-visitor treatment effects, by fitting a shallow decision tree.
Each visitor is labeled with the arm the posterior expects to be
best for them (largest posterior-mean lift, or control when no
lift is positive); a multiclass decision tree is fit on the
visitors' features, and each leaf becomes a ``DiscoveredSegment``
carrying an exact membership rule, gate estimate/CI, per-arm best
probabilities, Thompson allocation, and bootstrap-replicability
stability.
Args:
max_depth: Maximum tree depth.
min_segment_share: Minimum fraction of visitors per leaf
(sklearn ``min_weight_fraction_leaf``).
n_bootstrap: Bootstrap tree refits behind ``stability_score``;
``0`` skips stability (NaN sentinel plus ``UserWarning``).
bootstrap_seed: Seed for the bootstrap resampling RNG.
Returns:
``PolicyTreeResult`` with one segment per leaf, ordered by
sklearn leaf id; ``result.observed`` is ``self.observed`` by
identity.
Raises:
ValueError: When ``self.observed`` is ``None``.
"""
from pytyche.analysis._policy_tree import fit_policy_tree as _impl
return _impl(
self,
max_depth=max_depth,
min_segment_share=min_segment_share,
n_bootstrap=n_bootstrap,
bootstrap_seed=bootstrap_seed,
)
[docs]
def apply_calibration(self, calibration: Calibration) -> BinaryBCFResult:
"""Return a new posterior with *calibration* attached.
Attach, don't transform: the artifact is stashed on the returned
copy (``is_calibrated=True``); every sample array is shared with
this posterior by identity. The correction currently applies to
intervals only — probabilities and expected losses stay raw;
corrected CIs appear where interval summaries are built. K = 2
experiments only (per-contrast recalibration for K >= 3 is not
yet implemented).
Args:
calibration: SBC-fitted ``Calibration`` whose regime (metric,
n_treatments) must match ``self.observed``.
Returns:
New ``BinaryBCFResult`` carrying the artifact; the original
is untouched.
Raises:
ValueError: When ``self.observed`` is ``None``, or on a
regime mismatch (message names the mismatched dimensions).
NotImplementedError: At K >= 3.
"""
from pytyche.analysis._calibrate import apply_calibration as _impl
return _impl(self, calibration)
[docs]
def recommendation_summary(
self,
treatment: str,
segment: DiscoveredSegment | None = None,
*,
thresholds: DecisionThresholds | None = None,
min_practical_effect: float = 0.02,
) -> RecommendationSummary:
"""Act-now SHIP / CONTINUE / STOP recommendation for one treatment.
The treatment's metric-native contrast draws are scoped
(``segment=None`` is the global all-visitors snapshot; a segment
restricts to its rule's members), reduced to per-draw mean lift,
and summarized under the legacy ``compare.variants`` decision
rule. v0.2 raw scope: probabilities and expected losses come from
the raw draws even on a calibrated posterior — interval
corrections land where intervals are built.
Args:
treatment: Treatment variant name (vs control).
segment: ``None`` for the global snapshot; a
``DiscoveredSegment`` restricts the computation to its
members.
thresholds: Decision thresholds; ``DecisionThresholds()``
defaults when ``None``.
min_practical_effect: Minimum meaningful lift for
``probability_better`` / ``probability_harmful``.
Returns:
``RecommendationSummary`` with the decision, its evidence,
and ``expected_value_of_one_more_round`` always populated
(closed-form preposterior EVSI; formula in
``docs/concepts/decision-theoretic-inputs.md``).
Raises:
ValueError: When ``self.observed`` is ``None``, when
*treatment* is not a treatment name, or when the
segment's rule matches zero visitors.
"""
from pytyche.analysis._recommendation import (
recommendation_summary as _impl,
)
return _impl(
self,
treatment,
segment=segment,
thresholds=thresholds,
min_practical_effect=min_practical_effect,
)
[docs]
def analyze(
self,
*,
max_depth: int = 3,
min_segment_share: float = 0.10,
n_bootstrap: int = 50,
bootstrap_seed: int = 0,
) -> AnalysisResult:
"""The canonical one-call analysis summary for this posterior.
Composes per-treatment ``Comparison`` summaries, the embedded
policy-tree segmentation (keyword arguments forward to it), the
global ``RecommendationSummary`` for the best challenger, and the
posterior-mean per-visitor CATEs. Anything needing posterior
samples goes through ``analysis.posterior``.
Args:
max_depth: Embedded policy tree depth.
min_segment_share: Minimum per-leaf population share.
n_bootstrap: Stability bootstrap count (``0`` skips stability
with a ``UserWarning``).
bootstrap_seed: Stability bootstrap seed.
Returns:
``AnalysisResult``; ``analysis.is_calibrated`` reads through
to this posterior's flag.
Raises:
ValueError: When ``self.observed`` is ``None``.
"""
from pytyche.analysis._analyze import analyze as _impl
return _impl(
self,
max_depth=max_depth,
min_segment_share=min_segment_share,
n_bootstrap=n_bootstrap,
bootstrap_seed=bootstrap_seed,
)
[docs]
def evaluate_against_truth(
self,
tree: PolicyTreeResult,
truth: CalibrationTruth | None,
) -> TruthComparison:
"""Sim-mode evaluation of *tree*'s policy against ground truth.
Args:
tree: The fitted policy whose assignments are evaluated.
truth: Ground truth from the simulation path; ``None`` in
real-data mode (raises — nothing to evaluate against).
Returns:
``TruthComparison`` (cate_rmse, policy_accuracy, and the
realized-RPV trio with the oracle gap).
Raises:
RuntimeError: When *truth* is ``None`` (real-data mode).
ValueError: When ``self.observed`` is ``None`` or the truth
lacks the K-appropriate contrast / potential-outcome
fields.
"""
from pytyche.analysis._truth import evaluate_against_truth as _impl
return _impl(self, tree=tree, truth=truth)
[docs]
def has_credible_segments(self, threshold: float = 0.80) -> bool:
"""Whether some discovered segment clears *threshold* stability.
Runs ``fit_policy_tree`` at its defaults (deterministic given the
default ``bootstrap_seed``) and checks for a segment with
``stability_score >= threshold``. The 0.80 default matches the
default graduation rule's SHIP-gate stability threshold.
Args:
threshold: Minimum bootstrap-replicability stability score.
Returns:
``True`` iff at least one discovered segment clears it.
"""
tree = self.fit_policy_tree()
return any(
score >= threshold for score in tree.stability_scores.values()
)
[docs]
def has_decomposition(self) -> bool:
"""Whether this posterior carries the conversion/severity split.
Returns:
``False`` — only hurdle posteriors carry the
conversion/severity decomposition.
"""
return False
[docs]
@dataclasses.dataclass(frozen=True)
class HurdleBCFResult:
"""Result from joint shared-tree hurdle BCF.
Each tree simultaneously estimates conversion (probit) and severity
(log-revenue) parameters via shared tree structure. This couples the
two channels so splits are jointly informative.
RPV CATEs are composed on-GPU (float32) and transferred to CPU for
policy tree fitting. Channel-level per-draw arrays (p0, p1, sev0, sev1)
are retained by default (``retain_channel_samples=True``) — the
conversion/severity decomposition is the headline output of the hurdle
approach and needs the per-draw channel arrays for its credible
intervals. Set ``retain_channel_samples=False`` to skip the GPU→CPU
transfer when memory matters more than the decomposition (e.g. large-n
sweep contexts that only consume the composed RPV contrasts).
When num_chains > 1, samples are concatenated across chains:
S_total = (num_mcmc / thin_factor) * num_chains.
**Arm-count dispatch** (``K = int(Z.max()) + 1``). At K = 2 (binary arm) the
legacy paired fields are populated — ``p0_mean``/``p1_mean``/``sev0_mean``/
``sev1_mean`` ``(n,)`` and, when ``retain_channel_samples=True``,
``p0_samples``/``p1_samples``/``sev0_samples``/``sev1_samples`` ``(n, S_total)``
— ``rpv_cate_samples`` is ``(n, S_total)``, and the per-arm fields
``p_samples``/``sev_samples`` are ``None``. At K >= 3 (multi-arm) the per-arm
fields are populated instead — ``p_samples``/``sev_samples`` ``(n, S_total, K)``
(when retained) and ``rpv_cate_samples`` ``(n, S_total, K - 1)`` (the jointly
sampled contrast posterior) — and the legacy paired fields are ``None``. The
two field families are never populated together. ``tau0_samples`` ``(S_total,)``
and the ``sigma2_samples = 1 / tau0_samples`` property are scalar at every K
(each visitor sees one outcome, so the severity residual is scalar per visitor
— there is no per-arm severity precision).
The ``topology_history`` field is populated only when the producing fit
set ``GPUBCFConfig.retain_topology_history=True``. When the flag is off
(default), the field is ``None`` and the fit's wall-clock + PRNG state is
bitwise-identical to HEAD pre-this-change.
Attributes
----------
rpv_cate_samples:
``(n, S_total)`` float32 — composed on GPU, transferred to CPU.
p0_mean:
``(n,)`` float32 — E[Φ(μ_b + b₀·τ_b)]; None at K>=3.
p1_mean:
``(n,)`` float32 — E[Φ(μ_b + b₁·τ_b)]; None at K>=3.
sev0_mean:
``(n,)`` float32 — E[exp(μ_c + b₀·τ_c + σ²/2)]; None at K>=3.
sev1_mean:
``(n,)`` float32 — E[exp(μ_c + b₁·τ_c + σ²/2)]; None at K>=3.
tau0_samples:
``(S_total,)`` float32 — global precision.
tau_hat_quantiles:
``(S_total, 5)`` [q05,q25,q50,q75,q95] or None.
wall_clock_seconds:
Wall-clock time for the fit in seconds.
num_chains:
Number of parallel MCMC chains used.
num_gfr_sweeps:
Number of GFR warm-start sweeps performed.
diagnostics:
Dict of diagnostic values (rhat_tau0, per_chain_ess, etc.), or None.
phase_timing:
Dict of per-phase wall-clock breakdown, or None.
p0_samples:
jax.Array ``(n, S_total)`` — P(convert|control) per draw; None if not retained.
p1_samples:
jax.Array ``(n, S_total)`` — P(convert|treated) per draw; None if not retained.
sev0_samples:
jax.Array ``(n, S_total)`` — E[sev|control,convert] per draw; None if not retained.
sev1_samples:
jax.Array ``(n, S_total)`` — E[sev|treated,convert] per draw; None if not retained.
p_samples:
jax.Array ``(n, S_total, K)`` — per-arm P(convert) per draw; None at K=2.
sev_samples:
jax.Array ``(n, S_total, K)`` — per-arm E[sev|convert] per draw; None at K=2.
topology_history:
Topology retention trace; populated only when the producing fit set
``GPUBCFConfig.retain_topology_history=True``. ``None`` otherwise.
observed:
The ``ObservedExperimentData`` the fit consumed, attached to the
result so the analysis methods can reach the visitor rows and
variant metadata. ``None`` when constructed by private raw-array
helpers; populated by the public fit wrappers.
is_calibrated:
``True`` only after ``apply_calibration`` has been called on this
result. Defaults to ``False``.
calibration:
The ``Calibration`` artifact attached by ``apply_calibration``;
``None`` on fresh fits. The v0.2 artifact scope is interval
corrections only — it is consumed where interval summaries are
built, never to transform sample arrays.
pooling:
Provenance of the fit: ``"joint"`` = shared-tree canonical fit;
``"independent"`` = two-stage baseline (binary + continuous fitted
separately). Required — caller must always populate.
"""
rpv_cate_samples: np.ndarray # (n, S_total) float32 — composed on GPU, transferred to CPU
p0_mean: np.ndarray | None # (n,) float32 — E[Φ(μ_b + b₀·τ_b)]; None at K>=3
p1_mean: np.ndarray | None # (n,) float32 — E[Φ(μ_b + b₁·τ_b)]; None at K>=3
sev0_mean: np.ndarray | None # (n,) float32 — E[exp(μ_c + b₀·τ_c + σ²/2)]; None at K>=3
sev1_mean: np.ndarray | None # (n,) float32 — E[exp(μ_c + b₁·τ_c + σ²/2)]; None at K>=3
tau0_samples: np.ndarray # (S_total,) float32 — global precision
tau_hat_quantiles: np.ndarray | None # (S_total, 5) [q05,q25,q50,q75,q95] or None
wall_clock_seconds: float
num_chains: int = 1
num_gfr_sweeps: int = 0
diagnostics: dict | None = None # rhat_tau0, per_chain_ess, etc.
phase_timing: dict | None = None # per-phase wall-clock breakdown
# Per-draw channel arrays — JAX arrays on GPU, or None if not retained
p0_samples: Any | None = None # jax.Array (n, S_total) — P(convert|control) per draw
p1_samples: Any | None = None # jax.Array (n, S_total) — P(convert|treated) per draw
sev0_samples: Any | None = None # jax.Array (n, S_total) — E[sev|control,convert] per draw
sev1_samples: Any | None = None # jax.Array (n, S_total) — E[sev|treated,convert] per draw
p_samples: Any | None = None # jax.Array (n, S_total, K) — per-arm P(convert) per draw; None at K=2
sev_samples: Any | None = None # jax.Array (n, S_total, K) — per-arm E[sev|convert] per draw; None at K=2
# Optional retained topology trace, populated only when the producing fit
# set ``GPUBCFConfig.retain_topology_history=True``. ``None`` otherwise.
topology_history: "TopologyHistory | None" = None # noqa: UP037 — string form pinned by RESULT_FIELDS_SNAPSHOTS contract
observed: ObservedExperimentData | None = dataclasses.field(default=None, kw_only=True)
is_calibrated: bool = dataclasses.field(default=False, kw_only=True)
calibration: Calibration | None = dataclasses.field(default=None, kw_only=True, repr=False)
pooling: Literal["joint", "independent"] = dataclasses.field(kw_only=True)
@property
def sigma2_samples(self) -> np.ndarray:
    """Variance-scale view of the global precision draws.

    Computes ``1 / tau0_samples`` elementwise on every access; nothing
    is cached. This is a backward-compatibility shim for downstream code
    that works in the sigma² (variance) parameterisation instead of
    the precision one.
    """
    precision = self.tau0_samples
    return 1.0 / precision
[docs]
def thompson_allocation(
    self,
    segments: Sequence[DiscoveredSegment],
    epsilon: float = 0.02,
) -> dict[int, dict[str, float]]:
    """Segment-level Thompson-sampling traffic split.

    For every segment, each posterior draw votes for the arm it ranks
    best: the arm with the largest member-mean contrast, or control
    when no contrast is positive. An arm's weight is the fraction of
    draws it wins, i.e. the posterior probability that it is that
    segment's best arm.

    Args:
        segments: Segments to allocate over. Only each segment's ``id``
            and ``rule`` are read; members are resolved against
            ``self.observed``.
        epsilon: Safety-net exploration floor. Any arm whose weight
            falls below ``epsilon / K`` is lifted to that floor and the
            remaining weights are rescaled, so no arm is starved of
            traffic. It has no effect when every arm is already above
            the floor. It is NOT the knob for how much traffic stays on
            control; that is ``min_control_weight`` /
            ``min_explore_weight`` on ``pt.sequential_experiment``.
            Overriding it is rarely worthwhile.

    Returns:
        ``{segment.id: {variant_name: weight}}``. Each inner dict is in
        variant order (control first) and sums to 1.

    Raises:
        ValueError: If ``self.observed`` is ``None``.
    """
    # Function-scope import keeps importing this module cheap.
    from pytyche.analysis._thompson import thompson_allocation as _allocate

    allocation = _allocate(self, segments=segments, epsilon=epsilon)
    return allocation
[docs]
def fit_policy_tree(
    self,
    *,
    max_depth: int = 3,
    min_segment_share: float = 0.10,
    n_bootstrap: int = 50,
    bootstrap_seed: int = 0,
) -> PolicyTreeResult:
    """Turn per-visitor treatment effects into interpretable segments
    by fitting a shallow decision tree.

    Every visitor is tagged with the arm the posterior expects to serve
    them best: largest posterior-mean lift, or control if no lift is
    positive. A multiclass decision tree is then fit on the visitors'
    features. Each resulting leaf becomes a ``DiscoveredSegment`` that
    carries:

    - an exact membership rule,
    - a gate estimate with CI,
    - per-arm best probabilities,
    - a Thompson allocation,
    - a bootstrap-replicability stability score.

    Args:
        max_depth: Upper bound on tree depth.
        min_segment_share: Smallest fraction of visitors allowed in a
            leaf (sklearn's ``min_weight_fraction_leaf``).
        n_bootstrap: Number of bootstrap tree refits behind
            ``stability_score``. ``0`` skips stability (NaN sentinel
            plus a ``UserWarning``).
        bootstrap_seed: Seed for the bootstrap resampling RNG.

    Returns:
        ``PolicyTreeResult`` with one segment per leaf, in sklearn
        leaf-id order. Its ``observed`` is this posterior's
        ``self.observed`` (the same object).

    Raises:
        ValueError: If ``self.observed`` is ``None``.
    """
    # Function-scope import keeps importing this module cheap.
    from pytyche.analysis._policy_tree import fit_policy_tree as _fit_tree

    options = {
        "max_depth": max_depth,
        "min_segment_share": min_segment_share,
        "n_bootstrap": n_bootstrap,
        "bootstrap_seed": bootstrap_seed,
    }
    return _fit_tree(self, **options)
[docs]
def apply_calibration(self, calibration: Calibration) -> HurdleBCFResult:
    """Return a copy of this posterior with *calibration* attached.

    Nothing is transformed. The artifact is stored on the returned copy,
    which is flagged ``is_calibrated=True`` and shares every sample
    array with this posterior by identity.

    The correction currently applies to intervals only; probabilities
    and expected losses stay raw. Corrected CIs appear wherever interval
    summaries are built. Only K = 2 experiments are supported, because
    per-contrast recalibration for K >= 3 is not implemented yet.

    Args:
        calibration: SBC-fitted ``Calibration``. Its regime (metric,
            n_treatments) must match ``self.observed``.

    Returns:
        A new ``HurdleBCFResult`` carrying the artifact. The original
        posterior is left untouched.

    Raises:
        ValueError: If ``self.observed`` is ``None``, or on a regime
            mismatch (the message names the mismatched dimensions).
        NotImplementedError: For K >= 3.
    """
    # Function-scope import keeps importing this module cheap.
    from pytyche.analysis._calibrate import apply_calibration as _attach

    calibrated = _attach(self, calibration)
    return calibrated
[docs]
def recommendation_summary(
    self,
    treatment: str,
    segment: DiscoveredSegment | None = None,
    *,
    thresholds: DecisionThresholds | None = None,
    min_practical_effect: float = 0.02,
) -> RecommendationSummary:
    """Act-now SHIP / CONTINUE / STOP recommendation for one treatment.

    The treatment's metric-native contrast draws are first scoped:

    - ``segment=None`` gives the global snapshot over all visitors;
    - otherwise the draws are restricted to the members of the
      segment's rule.

    They are then reduced to a per-draw mean lift and summarized with
    the legacy ``compare.variants`` decision rule. Under the v0.2 raw
    scope, probabilities and expected losses come from the raw draws
    even on a calibrated posterior; interval corrections land where
    intervals are built.

    Args:
        treatment: Name of the treatment variant compared against
            control.
        segment: ``None`` for the global snapshot, or a
            ``DiscoveredSegment`` whose members restrict the
            computation.
        thresholds: Decision thresholds. ``DecisionThresholds()``
            defaults apply when ``None``.
        min_practical_effect: Smallest lift considered meaningful for
            ``probability_better`` / ``probability_harmful``.

    Returns:
        ``RecommendationSummary`` holding the decision and its evidence.
        ``expected_value_of_one_more_round`` is always populated
        (closed-form preposterior EVSI; the formula is in
        ``docs/concepts/decision-theoretic-inputs.md``).

    Raises:
        ValueError: If ``self.observed`` is ``None``, if *treatment* is
            not a treatment name, or if the segment's rule matches no
            visitors.
    """
    # Function-scope import keeps importing this module cheap.
    from pytyche.analysis._recommendation import (
        recommendation_summary as _recommend,
    )

    options = {
        "segment": segment,
        "thresholds": thresholds,
        "min_practical_effect": min_practical_effect,
    }
    return _recommend(self, treatment, **options)
[docs]
def analyze(
    self,
    *,
    max_depth: int = 3,
    min_segment_share: float = 0.10,
    n_bootstrap: int = 50,
    bootstrap_seed: int = 0,
) -> AnalysisResult:
    """One-call canonical analysis summary of this posterior.

    The summary bundles:

    - per-treatment ``Comparison`` summaries;
    - the embedded policy-tree segmentation (the keyword arguments are
      forwarded to it);
    - the global ``RecommendationSummary`` for the best challenger;
    - the posterior-mean per-visitor CATEs.

    Anything that needs posterior samples goes through
    ``analysis.posterior``.

    Args:
        max_depth: Depth of the embedded policy tree.
        min_segment_share: Minimum population share per leaf.
        n_bootstrap: Number of stability bootstraps. ``0`` skips
            stability and emits a ``UserWarning``.
        bootstrap_seed: Seed for the stability bootstrap.

    Returns:
        ``AnalysisResult``. Its ``analysis.is_calibrated`` reads through
        to this posterior's flag.

    Raises:
        ValueError: If ``self.observed`` is ``None``.
    """
    # Function-scope import keeps importing this module cheap.
    from pytyche.analysis._analyze import analyze as _analyze

    options = {
        "max_depth": max_depth,
        "min_segment_share": min_segment_share,
        "n_bootstrap": n_bootstrap,
        "bootstrap_seed": bootstrap_seed,
    }
    return _analyze(self, **options)
[docs]
def evaluate_against_truth(
    self,
    tree: PolicyTreeResult,
    truth: CalibrationTruth | None,
) -> TruthComparison:
    """Score *tree*'s policy against simulation ground truth.

    Args:
        tree: Fitted policy whose assignments are evaluated.
        truth: Ground truth from the simulation path. Pass ``None`` in
            real-data mode; that raises, since there is nothing to
            compare against.

    Returns:
        ``TruthComparison`` with cate_rmse, policy_accuracy, and the
        realized-RPV trio including the oracle gap.

    Raises:
        RuntimeError: If *truth* is ``None`` (real-data mode).
        ValueError: If ``self.observed`` is ``None``, or if the truth
            lacks the K-appropriate contrast / potential-outcome
            fields.
    """
    # Function-scope import keeps importing this module cheap.
    from pytyche.analysis._truth import evaluate_against_truth as _evaluate

    comparison = _evaluate(self, tree=tree, truth=truth)
    return comparison
[docs]
def has_credible_segments(self, threshold: float = 0.80) -> bool:
    """Report whether any discovered segment is stable enough to trust.

    Calls ``self.fit_policy_tree()`` with all of its defaults, which is
    deterministic given the default ``bootstrap_seed``. It then scans
    the tree's stability scores for one at or above *threshold*. The
    tree is refit on every call.

    The 0.80 default matches the default graduation rule's SHIP-gate
    stability threshold.

    Args:
        threshold: Minimum bootstrap-replicability stability score.

    Returns:
        ``True`` as soon as one score is ``>= threshold``. Otherwise
        ``False``, including when there are no scores.

    Raises:
        ValueError: Propagated from ``fit_policy_tree`` when
            ``self.observed`` is ``None``.
    """
    policy = self.fit_policy_tree()
    # Short-circuits on the first qualifying score.
    for score in policy.stability_scores.values():
        if score >= threshold:
            return True
    return False
[docs]
def has_decomposition(self) -> bool:
    """Report whether the posterior separates conversion from severity.

    Returns:
        Always ``True``: a hurdle posterior is built from a conversion
        channel and a severity channel.
    """
    return True