Source code for pytyche.bcf.config

"""Configuration dataclass, result types, and small utilities for the GPU BCF.

This module holds the user-facing configuration object (``GPUBCFConfig``), the
three result dataclasses returned by the ``fit_*`` entry points, the
formula-driven ``compute_num_trees_tau`` helper, and the leaf-index dtype
selector used to size heap-layout tree arrays. Pure types and utilities — no
JIT-compiled code, no GPU device handles, no module-level state. Importing
this module is cheap and triggers no GPU work.

Import graph
------------
``gpu_bcf_config`` depends on JAX (for the leaf-index dtype helper),
numpy (for result-array typing), and scipy.stats (for the inverse-normal
quantile in ``compute_num_trees_tau``). It does not import from any sibling
``gpu_bcf_*`` module. The orchestrator and downstream modules import FROM
here, never the other way around.

Contents
--------
``_leaf_index_dtype`` — smallest integer dtype (``uint8`` / ``uint16``, falling back to ``int32``) for heap node indices at a given tree depth.
``GPUBCFConfig`` — frozen dataclass of MCMC and prior hyperparameters for the GPU BCF.
``compute_num_trees_tau`` — formula for the minimum tau-forest tree count at a target CI coverage.
``ContinuousBCFResult`` — result container for ``fit_continuous_bcf``.
``BinaryBCFResult`` — result container for ``fit_binary_bcf``.
``HurdleBCFResult`` — result container for ``fit_hurdle_bcf``.

"""

from __future__ import annotations

import dataclasses
import math
from typing import TYPE_CHECKING, Any, Literal

import jax.numpy as jnp
import numpy as np
from scipy.stats import norm as sp_norm

if TYPE_CHECKING:
    from collections.abc import Sequence

    from pytyche.analysis._policy_tree import PolicyTreeResult
    from pytyche.analysis._truth import TruthComparison
    from pytyche.bcf.diagnostics.topology import TopologyHistory
    from pytyche.calibrate.artifact import Calibration
    from pytyche.contracts import (
        AnalysisResult,
        CalibrationTruth,
        DecisionThresholds,
        DiscoveredSegment,
        ObservedExperimentData,
        RecommendationSummary,
    )


# ---------------------------------------------------------------------------
# Leaf index dtype selection (VRAM optimization)
# ---------------------------------------------------------------------------


def _leaf_index_dtype(max_depth: int) -> jnp.dtype:
    """Pick a compact integer dtype able to hold every heap node index.

    Trees follow bartz's 1-indexed heap layout (root = 1, children of node
    ``i`` at ``2i`` and ``2i + 1``); the largest node ID considered for a
    tree of depth ``max_depth`` is ``2**max_depth - 1`` (63 at depth 6).

    Returns ``jnp.uint8`` while that ID fits in 8 bits (``max_depth <= 8``),
    ``jnp.uint16`` while it fits in 16 bits (``max_depth <= 16``), and the
    signed ``jnp.int32`` for anything deeper.
    """
    largest_node_id = 2 ** max_depth - 1
    # Narrow candidates in increasing width; the first one that fits wins.
    narrow_candidates = ((0xFF, jnp.uint8), (0xFFFF, jnp.uint16))
    for capacity, dtype in narrow_candidates:
        if largest_node_id <= capacity:
            return dtype
    return jnp.int32


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------


@dataclasses.dataclass(frozen=True)
class GPUBCFConfig:
    """Immutable bundle of MCMC and prior settings for the GPU BCF (bartz).

    The class is a plain frozen dataclass: it stores the values below and
    performs no validation of its own.

    Attributes
    ----------
    num_burnin:
        Burn-in MCMC iterations; these draws are discarded.
    num_mcmc:
        MCMC draws retained for posterior inference.
    num_trees_mu:
        Tree count of the prognostic (mu) forest.
    num_trees_tau:
        Tree count of the treatment-effect (tau) forest.
    max_depth:
        Maximum tree depth (controls the p_nonterminal array length).
    alpha_mu / beta_mu:
        Tree-prior hyperparameters of the mu forest.
    alpha_tau / beta_tau:
        Tree-prior hyperparameters of the tau forest (tighter = more
        regularized).
    num_cuts:
        Quantile-based split cutpoints per covariate.
    random_seed:
        Seed for the JAX PRNG.
    num_chains:
        Parallel (vmapped) MCMC chains; 1 is the legacy single-chain mode.
    diagnostic_interval:
        Iterations per chunk for between-chunk diagnostics. Must divide
        both ``num_burnin`` and ``num_mcmc`` evenly (not checked here).
    thin_factor:
        Keep every ``thin_factor``-th MCMC sample (1 keeps all).
    num_gfr_sweeps:
        Grow-from-root (GFR) warm-start sweep count — presumably the number
        of sweeps to run before MCMC; confirm against the sampler.
    min_samples_leaf:
        Presumably the minimum number of observations allowed in a leaf;
        not described in this module — confirm against the sampler.
    gfr_backend:
        Backend for the GFR warm start: ``"gpu"`` = JAX, ``"cpu"`` =
        StochTree.
    trace_path:
        When set, raw channel ``.npz`` files are written per chunk.
    var_tau_sev:
        Signal-to-noise for gamma-prior moment matching (Linero default).
    kappa_sev:
        Normal-Gamma prior precision scaling for beta.
    tau0_a_prior / tau0_b_prior:
        Shape and rate of the global-precision prior.
    freeze_gamma:
        Debug switch: hold gamma at its initial value and skip recentering.
    retain_channel_samples:
        Retain per-draw channel arrays on CPU. The conversion/severity
        decomposition needs them; set ``False`` to skip the transfer at
        sweep scale.
    focal_severity:
        Upweight the severity LML by n/n_conv so each converter contributes
        equally to splits.
    per_leaf_gamma:
        Selects the severity sampling path (see the comment on the field).
    retain_topology_history:
        If ``True``, retain per-iter per-tree topology hashes and move
        metadata on ``HurdleBCFResult.topology_history`` for mobility
        diagnostics. Default ``False`` (no retention, byte-identical to
        pre-feature behaviour).
    """

    # --- Chain length and forest sizes ------------------------------------
    num_burnin: int = 200
    num_mcmc: int = 200
    num_trees_mu: int = 200
    num_trees_tau: int = 50
    max_depth: int = 6

    # --- Tree priors and split grid ---------------------------------------
    alpha_mu: float = 0.95
    beta_mu: float = 2.0
    alpha_tau: float = 0.75
    beta_tau: float = 3.0
    num_cuts: int = 100

    # --- Sampler control --------------------------------------------------
    random_seed: int = 42
    num_chains: int = 1
    diagnostic_interval: int = 50
    thin_factor: int = 1
    num_gfr_sweeps: int = 5
    min_samples_leaf: int = 5
    # Grow-from-root (GFR) warm start: "gpu" = JAX, "cpu" = StochTree.
    gfr_backend: str = "gpu"
    # If set, write raw channel .npz per chunk.
    trace_path: str | None = None

    # --- Heteroscedastic severity (per-leaf precision) hyperparameters -----
    # Signal-to-noise for gamma prior moment-matching (Linero default).
    var_tau_sev: float = 0.5
    # Normal-Gamma prior precision scaling for beta.
    kappa_sev: float = 1.0
    # Global precision shape prior.
    tau0_a_prior: float = 1.0
    # Global precision rate prior.
    tau0_b_prior: float = 1.0
    # Debug: hold gamma at init, skip recentering.
    freeze_gamma: bool = False
    # Retain per-draw channel arrays on CPU (the conversion/severity
    # decomposition needs them; set False to skip the transfer at sweep
    # scale).
    retain_channel_samples: bool = True
    # Upweight severity LML by n/n_conv so each converter contributes
    # equally to splits.
    focal_severity: bool = False

    # When True, the severity channel uses per-leaf Normal-Gamma sampling
    # with a multiplicative tau_hat update. When False (default), severity
    # uses the parallel-LML Normal-Normal path mirroring bartz; tau_hat stays
    # at 1.0 and the global tau_0 absorbs all severity precision. The False
    # path is ~3-5x faster per iteration on GPU; the True path is preserved
    # for backward compatibility and rarely needed in practice.
    per_leaf_gamma: bool = False

    # Retain per-iter per-tree topology hashes + move metadata across the
    # full burn-in + MCMC trace. When True,
    # ``HurdleBCFResult.topology_history`` is populated with a
    # ``TopologyHistory`` instance for downstream mobility diagnostics; when
    # False (default), it stays ``None`` and the fit's wall-clock + PRNG
    # behaviour is bitwise-identical to HEAD pre-this-change.
    # Memory cost: ~14 bytes per (chain, iter, tree) entry — ~4 MB at the
    # production default of 4 chains x 300 iters x 250 trees, ~140 MB at
    # 10000 iters with the same shape. Opt-in for predictability.
    retain_topology_history: bool = False
def compute_num_trees_tau(
    n: int,
    d_tau: float = 3.0,
    sigma_tau: float = 0.5,
    coverage: float = 0.90,
    floor: int = 50,
    ceiling: int = 400,
) -> int:
    """Formula-driven tau tree count for target CI coverage.

    T_min = ceil((d_tau * sigma_tau * sqrt(n) / (2 * z))^(2/3))

    where ``z`` is the two-sided standard-normal quantile for ``coverage``.
    The tau forest's piecewise-constant approximation has O(1) bias that
    dominates the posterior at large n (O(1/sqrt(n)) concentration). This
    formula computes the minimum T to keep bias below the CI half-width at
    the target coverage level. The result is clamped to ``[floor, ceiling]``;
    if ``floor > ceiling`` the ceiling wins.

    Parameters
    ----------
    n : cumulative sample size; must be non-negative
    d_tau : effective CATE dimensionality (treatment effect modifiers);
        must be non-negative
    sigma_tau : CATE heterogeneity on standardized/probit scale; must be
        non-negative
    coverage : target CI coverage level (0.90 = 90%); must satisfy
        ``0 < coverage <= 1``
    floor : minimum tree count (BCF needs some for variable selection)
    ceiling : maximum tree count (VRAM bound)

    Returns
    -------
    int
        Tree count for the tau forest.

    Raises
    ------
    ValueError
        If ``n``, ``d_tau`` or ``sigma_tau`` is negative, or ``coverage``
        lies outside ``(0, 1]`` (including NaN).
    """
    # Validate up front: without these checks bad inputs surface as opaque
    # failures further down (z == 0 -> inf -> OverflowError in math.ceil;
    # NaN -> "cannot convert float NaN to integer").
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n!r}")
    if d_tau < 0:
        raise ValueError(f"d_tau must be non-negative, got {d_tau!r}")
    if sigma_tau < 0:
        raise ValueError(f"sigma_tau must be non-negative, got {sigma_tau!r}")
    # Written as a negated chained comparison so NaN is rejected too.
    if not 0.0 < coverage <= 1.0:
        raise ValueError(f"coverage must be in (0, 1], got {coverage!r}")

    # Two-sided quantile: coverage = 0.90 -> z = Phi^{-1}(0.95).
    # coverage == 1 gives z = inf, which drives t_min to 0 (-> floor).
    z = sp_norm.ppf(1 - (1 - coverage) / 2)
    t_min = (d_tau * sigma_tau * math.sqrt(n) / (2 * z)) ** (2 / 3)
    return int(min(max(math.ceil(t_min), floor), ceiling))
# --------------------------------------------------------------------------- # Result containers # ---------------------------------------------------------------------------
[docs] @dataclasses.dataclass(frozen=True) class ContinuousBCFResult: """Result from continuous BCF. Attributes ---------- mu_samples: ``(n, num_mcmc)`` prognostic predictions (standardized). tau_samples: ``(n, num_mcmc)`` treatment effects (standardized). sigma2_samples: ``(num_mcmc,)`` error variance (standardized). y_bar: Mean of the outcome used for standardization. y_std: Standard deviation of the outcome used for standardization. wall_clock_seconds: Wall-clock time for the fit in seconds. observed: The ``ObservedExperimentData`` the fit consumed, attached to the result so the analysis methods can reach the visitor rows and variant metadata. ``None`` when constructed by private raw-array helpers; populated by the public fit wrappers. is_calibrated: ``True`` only after ``apply_calibration`` has been called on this result. Defaults to ``False``. calibration: The ``Calibration`` artifact attached by ``apply_calibration``; ``None`` on fresh fits. The v0.2 artifact scope is interval corrections only — it is consumed where interval summaries are built, never to transform sample arrays. """ mu_samples: np.ndarray # (n, num_mcmc) prognostic predictions (standardized) tau_samples: np.ndarray # (n, num_mcmc) treatment effects (standardized) sigma2_samples: np.ndarray # (num_mcmc,) error variance (standardized) y_bar: float y_std: float wall_clock_seconds: float observed: ObservedExperimentData | None = dataclasses.field(default=None, kw_only=True) is_calibrated: bool = dataclasses.field(default=False, kw_only=True) calibration: Calibration | None = dataclasses.field(default=None, kw_only=True, repr=False)
[docs] def thompson_allocation( self, segments: Sequence[DiscoveredSegment], epsilon: float = 0.02, ) -> dict[int, dict[str, float]]: """Per-segment traffic split: each arm's weight is the posterior probability that it is the segment's best arm. Thompson sampling at segment granularity: per segment, each posterior draw votes for its best arm (the largest member-mean contrast, or control when none is positive); an arm's weight is its win frequency over draws. Args: segments: Segments to allocate over (only ``id`` and ``rule`` are consumed); membership is resolved against ``self.observed``. epsilon: Safety-net exploration floor — arms below ``epsilon / K`` are raised to the floor and the rest rescaled, so no arm's traffic is starved to zero; inert when every arm is already above it. NOT the dial for how much traffic stays on control — that is ``min_control_weight`` / ``min_explore_weight`` on ``pt.sequential_experiment``; rarely worth overriding. Returns: ``{segment.id: {variant_name: weight}}`` — inner dicts in variant order (control first), each summing to 1. Raises: ValueError: When ``self.observed`` is ``None``. """ from pytyche.analysis._thompson import thompson_allocation as _impl return _impl(self, segments=segments, epsilon=epsilon)
[docs] def fit_policy_tree( self, *, max_depth: int = 3, min_segment_share: float = 0.10, n_bootstrap: int = 50, bootstrap_seed: int = 0, ) -> PolicyTreeResult: """Discover interpretable segments from the posterior's per-visitor treatment effects, by fitting a shallow decision tree. Each visitor is labeled with the arm the posterior expects to be best for them (largest posterior-mean lift, or control when no lift is positive); a multiclass decision tree is fit on the visitors' features, and each leaf becomes a ``DiscoveredSegment`` carrying an exact membership rule, gate estimate/CI, per-arm best probabilities, Thompson allocation, and bootstrap-replicability stability. Args: max_depth: Maximum tree depth. min_segment_share: Minimum fraction of visitors per leaf (sklearn ``min_weight_fraction_leaf``). n_bootstrap: Bootstrap tree refits behind ``stability_score``; ``0`` skips stability (NaN sentinel plus ``UserWarning``). bootstrap_seed: Seed for the bootstrap resampling RNG. Returns: ``PolicyTreeResult`` with one segment per leaf, ordered by sklearn leaf id; ``result.observed`` is ``self.observed`` by identity. Raises: ValueError: When ``self.observed`` is ``None``. """ from pytyche.analysis._policy_tree import fit_policy_tree as _impl return _impl( self, max_depth=max_depth, min_segment_share=min_segment_share, n_bootstrap=n_bootstrap, bootstrap_seed=bootstrap_seed, )
[docs] def apply_calibration(self, calibration: Calibration) -> ContinuousBCFResult: """Return a new posterior with *calibration* attached. Attach, don't transform: the artifact is stashed on the returned copy (``is_calibrated=True``); every sample array is shared with this posterior by identity. The correction currently applies to intervals only — probabilities and expected losses stay raw; corrected CIs appear where interval summaries are built. K = 2 experiments only (per-contrast recalibration for K >= 3 is not yet implemented). Args: calibration: SBC-fitted ``Calibration`` whose regime (metric, n_treatments) must match ``self.observed``. Returns: New ``ContinuousBCFResult`` carrying the artifact; the original is untouched. Raises: ValueError: When ``self.observed`` is ``None``, or on a regime mismatch (message names the mismatched dimensions). NotImplementedError: At K >= 3. """ from pytyche.analysis._calibrate import apply_calibration as _impl return _impl(self, calibration)
[docs] def recommendation_summary( self, treatment: str, segment: DiscoveredSegment | None = None, *, thresholds: DecisionThresholds | None = None, min_practical_effect: float = 0.02, ) -> RecommendationSummary: """Act-now SHIP / CONTINUE / STOP recommendation for one treatment. The treatment's metric-native contrast draws are scoped (``segment=None`` is the global all-visitors snapshot; a segment restricts to its rule's members), reduced to per-draw mean lift, and summarized under the legacy ``compare.variants`` decision rule. v0.2 raw scope: probabilities and expected losses come from the raw draws even on a calibrated posterior — interval corrections land where intervals are built. Args: treatment: Treatment variant name (vs control). segment: ``None`` for the global snapshot; a ``DiscoveredSegment`` restricts the computation to its members. thresholds: Decision thresholds; ``DecisionThresholds()`` defaults when ``None``. min_practical_effect: Minimum meaningful lift for ``probability_better`` / ``probability_harmful``. Returns: ``RecommendationSummary`` with the decision, its evidence, and ``expected_value_of_one_more_round`` always populated (closed-form preposterior EVSI; formula in ``docs/concepts/decision-theoretic-inputs.md``). Raises: ValueError: When ``self.observed`` is ``None``, when *treatment* is not a treatment name, or when the segment's rule matches zero visitors. """ from pytyche.analysis._recommendation import ( recommendation_summary as _impl, ) return _impl( self, treatment, segment=segment, thresholds=thresholds, min_practical_effect=min_practical_effect, )
[docs] def analyze( self, *, max_depth: int = 3, min_segment_share: float = 0.10, n_bootstrap: int = 50, bootstrap_seed: int = 0, ) -> AnalysisResult: """The canonical one-call analysis summary for this posterior. Composes per-treatment ``Comparison`` summaries, the embedded policy-tree segmentation (keyword arguments forward to it), the global ``RecommendationSummary`` for the best challenger, and the posterior-mean per-visitor CATEs. Anything needing posterior samples goes through ``analysis.posterior``. Args: max_depth: Embedded policy tree depth. min_segment_share: Minimum per-leaf population share. n_bootstrap: Stability bootstrap count (``0`` skips stability with a ``UserWarning``). bootstrap_seed: Stability bootstrap seed. Returns: ``AnalysisResult``; ``analysis.is_calibrated`` reads through to this posterior's flag. Raises: ValueError: When ``self.observed`` is ``None``. """ from pytyche.analysis._analyze import analyze as _impl return _impl( self, max_depth=max_depth, min_segment_share=min_segment_share, n_bootstrap=n_bootstrap, bootstrap_seed=bootstrap_seed, )
[docs] def evaluate_against_truth( self, tree: PolicyTreeResult, truth: CalibrationTruth | None, ) -> TruthComparison: """Sim-mode evaluation of *tree*'s policy against ground truth. Args: tree: The fitted policy whose assignments are evaluated. truth: Ground truth from the simulation path; ``None`` in real-data mode (raises — nothing to evaluate against). Returns: ``TruthComparison`` (cate_rmse, policy_accuracy, and the realized-RPV trio with the oracle gap). Raises: RuntimeError: When *truth* is ``None`` (real-data mode). ValueError: When ``self.observed`` is ``None`` or the truth lacks the K-appropriate contrast / potential-outcome fields. """ from pytyche.analysis._truth import evaluate_against_truth as _impl return _impl(self, tree=tree, truth=truth)
[docs] def has_credible_segments(self, threshold: float = 0.80) -> bool: """Whether some discovered segment clears *threshold* stability. Runs ``fit_policy_tree`` at its defaults (deterministic given the default ``bootstrap_seed``) and checks for a segment with ``stability_score >= threshold``. The 0.80 default matches the default graduation rule's SHIP-gate stability threshold. Args: threshold: Minimum bootstrap-replicability stability score. Returns: ``True`` iff at least one discovered segment clears it. """ tree = self.fit_policy_tree() return any( score >= threshold for score in tree.stability_scores.values() )
[docs] def has_decomposition(self) -> bool: """Whether this posterior carries the conversion/severity split. Returns: ``False`` — only hurdle posteriors carry the conversion/severity decomposition. """ return False
[docs] @dataclasses.dataclass(frozen=True) class BinaryBCFResult: """Result from binary (probit) BCF. Attributes ---------- mu_samples: ``(n, num_mcmc)`` prognostic predictions (probit scale). tau_samples: ``(n, num_mcmc)`` treatment effects (probit scale). wall_clock_seconds: Wall-clock time for the fit in seconds. observed: The ``ObservedExperimentData`` the fit consumed, attached to the result so the analysis methods can reach the visitor rows and variant metadata. ``None`` when constructed by private raw-array helpers; populated by the public fit wrappers. is_calibrated: ``True`` only after ``apply_calibration`` has been called on this result. Defaults to ``False``. calibration: The ``Calibration`` artifact attached by ``apply_calibration``; ``None`` on fresh fits. The v0.2 artifact scope is interval corrections only — it is consumed where interval summaries are built, never to transform sample arrays. """ mu_samples: np.ndarray # (n, num_mcmc) prognostic predictions (probit scale) tau_samples: np.ndarray # (n, num_mcmc) treatment effects (probit scale) wall_clock_seconds: float observed: ObservedExperimentData | None = dataclasses.field(default=None, kw_only=True) is_calibrated: bool = dataclasses.field(default=False, kw_only=True) calibration: Calibration | None = dataclasses.field(default=None, kw_only=True, repr=False)
[docs] def thompson_allocation( self, segments: Sequence[DiscoveredSegment], epsilon: float = 0.02, ) -> dict[int, dict[str, float]]: """Per-segment traffic split: each arm's weight is the posterior probability that it is the segment's best arm. Thompson sampling at segment granularity: per segment, each posterior draw votes for its best arm (the largest member-mean contrast, or control when none is positive); an arm's weight is its win frequency over draws. Args: segments: Segments to allocate over (only ``id`` and ``rule`` are consumed); membership is resolved against ``self.observed``. epsilon: Safety-net exploration floor — arms below ``epsilon / K`` are raised to the floor and the rest rescaled, so no arm's traffic is starved to zero; inert when every arm is already above it. NOT the dial for how much traffic stays on control — that is ``min_control_weight`` / ``min_explore_weight`` on ``pt.sequential_experiment``; rarely worth overriding. Returns: ``{segment.id: {variant_name: weight}}`` — inner dicts in variant order (control first), each summing to 1. Raises: ValueError: When ``self.observed`` is ``None``. """ from pytyche.analysis._thompson import thompson_allocation as _impl return _impl(self, segments=segments, epsilon=epsilon)
[docs] def fit_policy_tree( self, *, max_depth: int = 3, min_segment_share: float = 0.10, n_bootstrap: int = 50, bootstrap_seed: int = 0, ) -> PolicyTreeResult: """Discover interpretable segments from the posterior's per-visitor treatment effects, by fitting a shallow decision tree. Each visitor is labeled with the arm the posterior expects to be best for them (largest posterior-mean lift, or control when no lift is positive); a multiclass decision tree is fit on the visitors' features, and each leaf becomes a ``DiscoveredSegment`` carrying an exact membership rule, gate estimate/CI, per-arm best probabilities, Thompson allocation, and bootstrap-replicability stability. Args: max_depth: Maximum tree depth. min_segment_share: Minimum fraction of visitors per leaf (sklearn ``min_weight_fraction_leaf``). n_bootstrap: Bootstrap tree refits behind ``stability_score``; ``0`` skips stability (NaN sentinel plus ``UserWarning``). bootstrap_seed: Seed for the bootstrap resampling RNG. Returns: ``PolicyTreeResult`` with one segment per leaf, ordered by sklearn leaf id; ``result.observed`` is ``self.observed`` by identity. Raises: ValueError: When ``self.observed`` is ``None``. """ from pytyche.analysis._policy_tree import fit_policy_tree as _impl return _impl( self, max_depth=max_depth, min_segment_share=min_segment_share, n_bootstrap=n_bootstrap, bootstrap_seed=bootstrap_seed, )
[docs] def apply_calibration(self, calibration: Calibration) -> BinaryBCFResult: """Return a new posterior with *calibration* attached. Attach, don't transform: the artifact is stashed on the returned copy (``is_calibrated=True``); every sample array is shared with this posterior by identity. The correction currently applies to intervals only — probabilities and expected losses stay raw; corrected CIs appear where interval summaries are built. K = 2 experiments only (per-contrast recalibration for K >= 3 is not yet implemented). Args: calibration: SBC-fitted ``Calibration`` whose regime (metric, n_treatments) must match ``self.observed``. Returns: New ``BinaryBCFResult`` carrying the artifact; the original is untouched. Raises: ValueError: When ``self.observed`` is ``None``, or on a regime mismatch (message names the mismatched dimensions). NotImplementedError: At K >= 3. """ from pytyche.analysis._calibrate import apply_calibration as _impl return _impl(self, calibration)
[docs] def recommendation_summary( self, treatment: str, segment: DiscoveredSegment | None = None, *, thresholds: DecisionThresholds | None = None, min_practical_effect: float = 0.02, ) -> RecommendationSummary: """Act-now SHIP / CONTINUE / STOP recommendation for one treatment. The treatment's metric-native contrast draws are scoped (``segment=None`` is the global all-visitors snapshot; a segment restricts to its rule's members), reduced to per-draw mean lift, and summarized under the legacy ``compare.variants`` decision rule. v0.2 raw scope: probabilities and expected losses come from the raw draws even on a calibrated posterior — interval corrections land where intervals are built. Args: treatment: Treatment variant name (vs control). segment: ``None`` for the global snapshot; a ``DiscoveredSegment`` restricts the computation to its members. thresholds: Decision thresholds; ``DecisionThresholds()`` defaults when ``None``. min_practical_effect: Minimum meaningful lift for ``probability_better`` / ``probability_harmful``. Returns: ``RecommendationSummary`` with the decision, its evidence, and ``expected_value_of_one_more_round`` always populated (closed-form preposterior EVSI; formula in ``docs/concepts/decision-theoretic-inputs.md``). Raises: ValueError: When ``self.observed`` is ``None``, when *treatment* is not a treatment name, or when the segment's rule matches zero visitors. """ from pytyche.analysis._recommendation import ( recommendation_summary as _impl, ) return _impl( self, treatment, segment=segment, thresholds=thresholds, min_practical_effect=min_practical_effect, )
[docs] def analyze( self, *, max_depth: int = 3, min_segment_share: float = 0.10, n_bootstrap: int = 50, bootstrap_seed: int = 0, ) -> AnalysisResult: """The canonical one-call analysis summary for this posterior. Composes per-treatment ``Comparison`` summaries, the embedded policy-tree segmentation (keyword arguments forward to it), the global ``RecommendationSummary`` for the best challenger, and the posterior-mean per-visitor CATEs. Anything needing posterior samples goes through ``analysis.posterior``. Args: max_depth: Embedded policy tree depth. min_segment_share: Minimum per-leaf population share. n_bootstrap: Stability bootstrap count (``0`` skips stability with a ``UserWarning``). bootstrap_seed: Stability bootstrap seed. Returns: ``AnalysisResult``; ``analysis.is_calibrated`` reads through to this posterior's flag. Raises: ValueError: When ``self.observed`` is ``None``. """ from pytyche.analysis._analyze import analyze as _impl return _impl( self, max_depth=max_depth, min_segment_share=min_segment_share, n_bootstrap=n_bootstrap, bootstrap_seed=bootstrap_seed, )
[docs] def evaluate_against_truth( self, tree: PolicyTreeResult, truth: CalibrationTruth | None, ) -> TruthComparison: """Sim-mode evaluation of *tree*'s policy against ground truth. Args: tree: The fitted policy whose assignments are evaluated. truth: Ground truth from the simulation path; ``None`` in real-data mode (raises — nothing to evaluate against). Returns: ``TruthComparison`` (cate_rmse, policy_accuracy, and the realized-RPV trio with the oracle gap). Raises: RuntimeError: When *truth* is ``None`` (real-data mode). ValueError: When ``self.observed`` is ``None`` or the truth lacks the K-appropriate contrast / potential-outcome fields. """ from pytyche.analysis._truth import evaluate_against_truth as _impl return _impl(self, tree=tree, truth=truth)
[docs] def has_credible_segments(self, threshold: float = 0.80) -> bool: """Whether some discovered segment clears *threshold* stability. Runs ``fit_policy_tree`` at its defaults (deterministic given the default ``bootstrap_seed``) and checks for a segment with ``stability_score >= threshold``. The 0.80 default matches the default graduation rule's SHIP-gate stability threshold. Args: threshold: Minimum bootstrap-replicability stability score. Returns: ``True`` iff at least one discovered segment clears it. """ tree = self.fit_policy_tree() return any( score >= threshold for score in tree.stability_scores.values() )
[docs] def has_decomposition(self) -> bool: """Whether this posterior carries the conversion/severity split. Returns: ``False`` — only hurdle posteriors carry the conversion/severity decomposition. """ return False
[docs] @dataclasses.dataclass(frozen=True) class HurdleBCFResult: """Result from joint shared-tree hurdle BCF. Each tree simultaneously estimates conversion (probit) and severity (log-revenue) parameters via shared tree structure. This couples the two channels so splits are jointly informative. RPV CATEs are composed on-GPU (float32) and transferred to CPU for policy tree fitting. Channel-level per-draw arrays (p0, p1, sev0, sev1) are retained by default (``retain_channel_samples=True``) — the conversion/severity decomposition is the headline output of the hurdle approach and needs the per-draw channel arrays for its credible intervals. Set ``retain_channel_samples=False`` to skip the GPU→CPU transfer when memory matters more than the decomposition (e.g. large-n sweep contexts that only consume the composed RPV contrasts). When num_chains > 1, samples are concatenated across chains: S_total = (num_mcmc / thin_factor) * num_chains. **Arm-count dispatch** (``K = int(Z.max()) + 1``). At K = 2 (binary arm) the legacy paired fields are populated — ``p0_mean``/``p1_mean``/``sev0_mean``/ ``sev1_mean`` ``(n,)`` and, when ``retain_channel_samples=True``, ``p0_samples``/``p1_samples``/``sev0_samples``/``sev1_samples`` ``(n, S_total)`` — ``rpv_cate_samples`` is ``(n, S_total)``, and the per-arm fields ``p_samples``/``sev_samples`` are ``None``. At K >= 3 (multi-arm) the per-arm fields are populated instead — ``p_samples``/``sev_samples`` ``(n, S_total, K)`` (when retained) and ``rpv_cate_samples`` ``(n, S_total, K - 1)`` (the jointly sampled contrast posterior) — and the legacy paired fields are ``None``. The two field families are never populated together. ``tau0_samples`` ``(S_total,)`` and the ``sigma2_samples = 1 / tau0_samples`` property are scalar at every K (each visitor sees one outcome, so the severity residual is scalar per visitor — there is no per-arm severity precision). 
The ``topology_history`` field is populated only when the producing fit set ``GPUBCFConfig.retain_topology_history=True``. When the flag is off (default), the field is ``None`` and the fit's wall-clock + PRNG state is bitwise-identical to HEAD pre-this-change. Attributes ---------- rpv_cate_samples: ``(n, S_total)`` float32 — composed on GPU, transferred to CPU. p0_mean: ``(n,)`` float32 — E[Φ(μ_b + b₀·τ_b)]; None at K>=3. p1_mean: ``(n,)`` float32 — E[Φ(μ_b + b₁·τ_b)]; None at K>=3. sev0_mean: ``(n,)`` float32 — E[exp(μ_c + b₀·τ_c + σ²/2)]; None at K>=3. sev1_mean: ``(n,)`` float32 — E[exp(μ_c + b₁·τ_c + σ²/2)]; None at K>=3. tau0_samples: ``(S_total,)`` float32 — global precision. tau_hat_quantiles: ``(S_total, 5)`` [q05,q25,q50,q75,q95] or None. wall_clock_seconds: Wall-clock time for the fit in seconds. num_chains: Number of parallel MCMC chains used. num_gfr_sweeps: Number of GFR warm-start sweeps performed. diagnostics: Dict of diagnostic values (rhat_tau0, per_chain_ess, etc.), or None. phase_timing: Dict of per-phase wall-clock breakdown, or None. p0_samples: jax.Array ``(n, S_total)`` — P(convert|control) per draw; None if not retained. p1_samples: jax.Array ``(n, S_total)`` — P(convert|treated) per draw; None if not retained. sev0_samples: jax.Array ``(n, S_total)`` — E[sev|control,convert] per draw; None if not retained. sev1_samples: jax.Array ``(n, S_total)`` — E[sev|treated,convert] per draw; None if not retained. p_samples: jax.Array ``(n, S_total, K)`` — per-arm P(convert) per draw; None at K=2. sev_samples: jax.Array ``(n, S_total, K)`` — per-arm E[sev|convert] per draw; None at K=2. topology_history: Topology retention trace; populated only when the producing fit set ``GPUBCFConfig.retain_topology_history=True``. ``None`` otherwise. observed: The ``ObservedExperimentData`` the fit consumed, attached to the result so the analysis methods can reach the visitor rows and variant metadata. 
``None`` when constructed by private raw-array helpers; populated by the public fit wrappers. is_calibrated: ``True`` only after ``apply_calibration`` has been called on this result. Defaults to ``False``. calibration: The ``Calibration`` artifact attached by ``apply_calibration``; ``None`` on fresh fits. The v0.2 artifact scope is interval corrections only — it is consumed where interval summaries are built, never to transform sample arrays. pooling: Provenance of the fit: ``"joint"`` = shared-tree canonical fit; ``"independent"`` = two-stage baseline (binary + continuous fitted separately). Required — caller must always populate. """ rpv_cate_samples: np.ndarray # (n, S_total) float32 — composed on GPU, transferred to CPU p0_mean: np.ndarray | None # (n,) float32 — E[Φ(μ_b + b₀·τ_b)]; None at K>=3 p1_mean: np.ndarray | None # (n,) float32 — E[Φ(μ_b + b₁·τ_b)]; None at K>=3 sev0_mean: np.ndarray | None # (n,) float32 — E[exp(μ_c + b₀·τ_c + σ²/2)]; None at K>=3 sev1_mean: np.ndarray | None # (n,) float32 — E[exp(μ_c + b₁·τ_c + σ²/2)]; None at K>=3 tau0_samples: np.ndarray # (S_total,) float32 — global precision tau_hat_quantiles: np.ndarray | None # (S_total, 5) [q05,q25,q50,q75,q95] or None wall_clock_seconds: float num_chains: int = 1 num_gfr_sweeps: int = 0 diagnostics: dict | None = None # rhat_tau0, per_chain_ess, etc. 
phase_timing: dict | None = None # per-phase wall-clock breakdown # Per-draw channel arrays — JAX arrays on GPU, or None if not retained p0_samples: Any | None = None # jax.Array (n, S_total) — P(convert|control) per draw p1_samples: Any | None = None # jax.Array (n, S_total) — P(convert|treated) per draw sev0_samples: Any | None = None # jax.Array (n, S_total) — E[sev|control,convert] per draw sev1_samples: Any | None = None # jax.Array (n, S_total) — E[sev|treated,convert] per draw p_samples: Any | None = None # jax.Array (n, S_total, K) — per-arm P(convert) per draw; None at K=2 sev_samples: Any | None = None # jax.Array (n, S_total, K) — per-arm E[sev|convert] per draw; None at K=2 # Optional retained topology trace, populated only when the producing fit # set ``GPUBCFConfig.retain_topology_history=True``. ``None`` otherwise. topology_history: "TopologyHistory | None" = None # noqa: UP037 — string form pinned by RESULT_FIELDS_SNAPSHOTS contract observed: ObservedExperimentData | None = dataclasses.field(default=None, kw_only=True) is_calibrated: bool = dataclasses.field(default=False, kw_only=True) calibration: Calibration | None = dataclasses.field(default=None, kw_only=True, repr=False) pooling: Literal["joint", "independent"] = dataclasses.field(kw_only=True) @property def sigma2_samples(self) -> np.ndarray: """Return ``1 / tau0_samples`` as a sigma² view. Backward-compat shim for downstream code that consumes the variance parameterisation rather than the precision one. """ return 1.0 / self.tau0_samples
def thompson_allocation(
    self,
    segments: Sequence[DiscoveredSegment],
    epsilon: float = 0.02,
) -> dict[int, dict[str, float]]:
    """Compute a Thompson-sampling traffic split for each segment.

    Within a segment, every posterior draw votes for its best arm — the
    arm with the largest member-mean contrast, or control when no contrast
    is positive. An arm's weight is its win frequency over draws, i.e. the
    posterior probability that it is the segment's best arm.

    Args:
        segments: Segments to allocate over. Only ``id`` and ``rule`` are
            consumed; membership is resolved against ``self.observed``.
        epsilon: Safety-net exploration floor. Arms whose weight falls
            below ``epsilon / K`` are raised to the floor and the rest
            rescaled, so no arm's traffic is starved to zero. Inert when
            every arm is already above the floor. This is NOT the dial
            for how much traffic stays on control — that is
            ``min_control_weight`` / ``min_explore_weight`` on
            ``pt.sequential_experiment``. Rarely worth overriding.

    Returns:
        ``{segment.id: {variant_name: weight}}``. Inner dicts are in
        variant order (control first) and each sums to 1.

    Raises:
        ValueError: When ``self.observed`` is ``None``.
    """
    # Imported at call time rather than at module level (presumably to keep
    # this module import-light and free of import cycles — confirm).
    from pytyche.analysis._thompson import thompson_allocation as _allocate

    allocation = _allocate(self, segments=segments, epsilon=epsilon)
    return allocation
def fit_policy_tree(
    self,
    *,
    max_depth: int = 3,
    min_segment_share: float = 0.10,
    n_bootstrap: int = 50,
    bootstrap_seed: int = 0,
) -> PolicyTreeResult:
    """Fit a shallow decision tree that turns per-visitor effects into segments.

    Every visitor is labeled with the arm the posterior expects to be best
    for them (largest posterior-mean lift, or control when no lift is
    positive). A multiclass decision tree is then fit on the visitors'
    features, and each leaf becomes a ``DiscoveredSegment`` carrying an
    exact membership rule, gate estimate/CI, per-arm best probabilities,
    Thompson allocation, and bootstrap-replicability stability.

    Args:
        max_depth: Maximum tree depth.
        min_segment_share: Minimum fraction of visitors per leaf (sklearn
            ``min_weight_fraction_leaf``).
        n_bootstrap: Number of bootstrap tree refits behind
            ``stability_score``; ``0`` skips stability (NaN sentinel plus
            ``UserWarning``).
        bootstrap_seed: Seed for the bootstrap resampling RNG.

    Returns:
        ``PolicyTreeResult`` with one segment per leaf, ordered by sklearn
        leaf id; ``result.observed`` is ``self.observed`` by identity.

    Raises:
        ValueError: When ``self.observed`` is ``None``.
    """
    # Imported at call time rather than at module level (presumably to keep
    # this module import-light and free of import cycles — confirm).
    from pytyche.analysis._policy_tree import fit_policy_tree as _fit_tree

    # Forward every tuning knob unchanged, by keyword.
    tree_options = {
        "max_depth": max_depth,
        "min_segment_share": min_segment_share,
        "n_bootstrap": n_bootstrap,
        "bootstrap_seed": bootstrap_seed,
    }
    return _fit_tree(self, **tree_options)
def apply_calibration(self, calibration: Calibration) -> HurdleBCFResult:
    """Attach *calibration* to a copy of this posterior and return the copy.

    Attach, don't transform: the artifact is stashed on the returned copy
    (``is_calibrated=True``), and every sample array is shared with this
    posterior by identity. The correction currently applies to intervals
    only — probabilities and expected losses stay raw, and corrected CIs
    appear where interval summaries are built. Supported for K = 2
    experiments only (per-contrast recalibration for K >= 3 is not yet
    implemented).

    Args:
        calibration: SBC-fitted ``Calibration`` whose regime (metric,
            n_treatments) must match ``self.observed``.

    Returns:
        New ``HurdleBCFResult`` carrying the artifact; the original is
        untouched.

    Raises:
        ValueError: When ``self.observed`` is ``None``, or on a regime
            mismatch (the message names the mismatched dimensions).
        NotImplementedError: At K >= 3.
    """
    # Imported at call time rather than at module level (presumably to keep
    # this module import-light and free of import cycles — confirm).
    from pytyche.analysis._calibrate import apply_calibration as _attach

    calibrated = _attach(self, calibration)
    return calibrated
def recommendation_summary(
    self,
    treatment: str,
    segment: DiscoveredSegment | None = None,
    *,
    thresholds: DecisionThresholds | None = None,
    min_practical_effect: float = 0.02,
) -> RecommendationSummary:
    """Produce the act-now SHIP / CONTINUE / STOP call for one treatment.

    The treatment's metric-native contrast draws are scoped
    (``segment=None`` is the global all-visitors snapshot; a segment
    restricts to its rule's members), reduced to per-draw mean lift, and
    summarized under the legacy ``compare.variants`` decision rule.

    v0.2 raw scope: probabilities and expected losses come from the raw
    draws even on a calibrated posterior — interval corrections land where
    intervals are built.

    Args:
        treatment: Treatment variant name (vs control).
        segment: ``None`` for the global snapshot; a ``DiscoveredSegment``
            restricts the computation to its members.
        thresholds: Decision thresholds; ``DecisionThresholds()`` defaults
            are used when ``None``.
        min_practical_effect: Minimum meaningful lift for
            ``probability_better`` / ``probability_harmful``.

    Returns:
        ``RecommendationSummary`` with the decision, its evidence, and
        ``expected_value_of_one_more_round`` always populated (closed-form
        preposterior EVSI; formula in
        ``docs/concepts/decision-theoretic-inputs.md``).

    Raises:
        ValueError: When ``self.observed`` is ``None``, when *treatment*
            is not a treatment name, or when the segment's rule matches
            zero visitors.
    """
    # Imported at call time rather than at module level (presumably to keep
    # this module import-light and free of import cycles — confirm).
    from pytyche.analysis._recommendation import recommendation_summary as _summarize

    # ``treatment`` stays positional; the remaining options go by keyword.
    summary_options = {
        "segment": segment,
        "thresholds": thresholds,
        "min_practical_effect": min_practical_effect,
    }
    return _summarize(self, treatment, **summary_options)
def analyze(
    self,
    *,
    max_depth: int = 3,
    min_segment_share: float = 0.10,
    n_bootstrap: int = 50,
    bootstrap_seed: int = 0,
) -> AnalysisResult:
    """Build the canonical one-call analysis summary for this posterior.

    The summary composes per-treatment ``Comparison`` summaries, the
    embedded policy-tree segmentation (the keyword arguments forward to
    it), the global ``RecommendationSummary`` for the best challenger, and
    the posterior-mean per-visitor CATEs. Anything needing posterior
    samples goes through ``analysis.posterior``.

    Args:
        max_depth: Embedded policy tree depth.
        min_segment_share: Minimum per-leaf population share.
        n_bootstrap: Stability bootstrap count (``0`` skips stability with
            a ``UserWarning``).
        bootstrap_seed: Stability bootstrap seed.

    Returns:
        ``AnalysisResult``; ``analysis.is_calibrated`` reads through to
        this posterior's flag.

    Raises:
        ValueError: When ``self.observed`` is ``None``.
    """
    # Imported at call time rather than at module level (presumably to keep
    # this module import-light and free of import cycles — confirm).
    from pytyche.analysis._analyze import analyze as _run_analysis

    # Forward every policy-tree knob unchanged, by keyword.
    tree_options = {
        "max_depth": max_depth,
        "min_segment_share": min_segment_share,
        "n_bootstrap": n_bootstrap,
        "bootstrap_seed": bootstrap_seed,
    }
    return _run_analysis(self, **tree_options)
def evaluate_against_truth(
    self,
    tree: PolicyTreeResult,
    truth: CalibrationTruth | None,
) -> TruthComparison:
    """Score *tree*'s policy against simulation ground truth.

    Args:
        tree: The fitted policy whose assignments are evaluated.
        truth: Ground truth from the simulation path; ``None`` in
            real-data mode (raises — nothing to evaluate against).

    Returns:
        ``TruthComparison`` (cate_rmse, policy_accuracy, and the
        realized-RPV trio with the oracle gap).

    Raises:
        RuntimeError: When *truth* is ``None`` (real-data mode).
        ValueError: When ``self.observed`` is ``None`` or the truth lacks
            the K-appropriate contrast / potential-outcome fields.
    """
    # Imported at call time rather than at module level (presumably to keep
    # this module import-light and free of import cycles — confirm).
    from pytyche.analysis._truth import evaluate_against_truth as _evaluate

    comparison = _evaluate(self, tree=tree, truth=truth)
    return comparison
def has_credible_segments(self, threshold: float = 0.80) -> bool:
    """Report whether any discovered segment reaches *threshold* stability.

    Calls ``fit_policy_tree`` with its defaults (deterministic given the
    default ``bootstrap_seed``) and looks for a segment whose
    ``stability_score`` is at least *threshold*. The 0.80 default matches
    the default graduation rule's SHIP-gate stability threshold.

    Args:
        threshold: Minimum bootstrap-replicability stability score.

    Returns:
        ``True`` iff at least one discovered segment clears it.
    """
    stability_by_segment = self.fit_policy_tree().stability_scores
    # Stop at the first qualifying segment; an empty mapping yields False.
    for stability in stability_by_segment.values():
        if stability >= threshold:
            return True
    return False
def has_decomposition(self) -> bool:
    """Report whether this posterior exposes the conversion/severity split.

    Returns:
        Always ``True``: the hurdle posterior decomposes into the
        conversion and severity channels.
    """
    return True