# Source code for pytyche.bcf.hurdle

"""Public fit entry point for hurdle BCF.

The canonical public surface is ``fit_hurdle_bcf``, which accepts an
:class:`~pytyche.contracts.ObservedExperimentData` and dispatches to the
appropriate private raw-array core based on the ``pooling`` argument.

Private cores (raw-array, no observed stashing):
    - ``_fit_joint_hurdle_bcf`` — joint shared-tree MCMC (model.py)
    - ``_fit_independent_hurdle_bcf`` — two-stage binary+continuous (compose.py)
"""

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, Literal

import numpy as np

from pytyche._internal.extraction import extract_fit_arrays, require_feature_columns
from pytyche._internal.observed import freeze_observed
from pytyche.bcf.config import GPUBCFConfig, HurdleBCFResult
from pytyche.bcf.hurdle.compose import _fit_independent_hurdle_bcf
from pytyche.bcf.hurdle.model import _fit_joint_hurdle_bcf
from pytyche.setup import _warn_if_no_cuda

if TYPE_CHECKING:
    from pytyche.calibrate.artifact import Calibration
    from pytyche.contracts import ObservedExperimentData

# Accepted values for the ``observed_copy`` argument of ``fit_hurdle_bcf``.
# The function checks membership in this tuple at runtime and raises
# ``ValueError`` for anything else, before the observed snapshot is frozen.
_VALID_COPY_MODES: tuple[str, ...] = ("view", "deep", "ref")


def fit_hurdle_bcf(
    observed: ObservedExperimentData,
    *,
    pooling: Literal["joint", "independent"] = "joint",
    observed_copy: Literal["view", "deep", "ref"] = "view",
    calibration: Calibration | None = None,
    seed: int = 0,
    progress: bool = False,
    **kwargs,
) -> HurdleBCFResult:
    """Fit hurdle BCF on observed experiment data.

    Dispatches to the joint shared-tree MCMC (``pooling='joint'``) or the
    independent two-stage estimator (``pooling='independent'``) based on the
    ``pooling`` argument. All MCMC configuration is forwarded through
    ``**kwargs`` to :class:`~pytyche.bcf.config.GPUBCFConfig`.

    Validation runs in this order, and the first failure wins: ``pooling``
    value, array extraction / feature-column check, variant count,
    treatment-level coverage, independent multi-arm guard, ``observed_copy``
    value, config construction.

    Args:
        observed: Observed experiment data. Must have at least two variants.
        pooling: ``'joint'`` (default) — shared-tree MCMC estimator
            (canonical); ``'independent'`` — two-stage binary + continuous
            estimator.
        observed_copy: Copy mode for stashing the observed data on the
            result. ``'view'`` (default) — zero-copy read-only view;
            ``'deep'`` — independent deep copy; ``'ref'`` — identity
            (no copy).
        calibration: Optional SBC-fitted artifact applied post-fit — sugar
            for ``.apply_calibration(calibration)`` on the returned result
            (one application path; the kwarg adds no numerics). ``None``
            (default) returns the raw posterior.
        seed: Random seed forwarded to
            :class:`~pytyche.bcf.config.GPUBCFConfig` as ``random_seed``.
        progress: When True, the joint fit renders tqdm progress bars on
            stderr (GFR warm-start sweeps, MCMC chunks). The default False
            is fully silent. ``pooling='independent'`` accepts the flag but
            has no per-phase reporting hooks — its channel fits run inside
            JIT-compiled loops.
        **kwargs: Additional keyword arguments forwarded to
            :class:`~pytyche.bcf.config.GPUBCFConfig`. Unknown keys raise
            ``TypeError`` (frozen dataclass with no extra fields).

    Returns:
        :class:`~pytyche.bcf.config.HurdleBCFResult` with ``pooling`` and
        ``observed`` set to the stashed snapshot; calibrated iff
        ``calibration`` was supplied.

    Raises:
        ValueError: If ``pooling`` is not ``'joint'`` or ``'independent'``,
            if ``observed`` has fewer than two variants, if any variant
            contributes zero rows (missing treatment level), if
            ``observed_copy`` is not a recognised copy mode, or —
            propagated from ``apply_calibration`` — if a supplied
            ``calibration``'s fitted regime (metric, n_treatments, pooling)
            does not match this fit.
        NotImplementedError: If ``pooling='independent'`` and K >= 3 (use
            ``pooling='joint'`` for multi-arm experiments).
        TypeError: If ``**kwargs`` contains unknown
            :class:`~pytyche.bcf.config.GPUBCFConfig` field names.
    """
    _warn_if_no_cuda()

    # --- Validate pooling ---
    if pooling not in ("joint", "independent"):
        raise ValueError(
            f"pooling must be 'joint' or 'independent'; got {pooling!r}."
        )

    # --- Extract arrays from observed ---
    extract = extract_fit_arrays(observed)
    require_feature_columns(extract)

    # K is the number of variants (arms), not inferred from Z.max().
    # Using len(observed.variants) catches the case where a variant
    # contributes zero rows — extract.Z would lack that level even though
    # the arm exists.
    K = len(observed.variants)

    # --- Minimum-arm guard ---
    # The documented contract requires at least two variants. Without this
    # check a single-arm input satisfies the level-coverage test below
    # ([0] == [0]) and would reach the fit cores with no contrast to
    # estimate.
    # NOTE(review): upstream extraction may already reject this case —
    # confirm; if so this guard is redundant but harmless.
    if K < 2:
        raise ValueError(
            f"observed must contain at least two variants; got {K}."
        )

    # --- Missing-level guard (fires before config construction / MCMC) ---
    # np.unique returns its values already sorted and .tolist() yields plain
    # Python ints, so the result compares directly against range(K).
    observed_levels = np.unique(extract.Z.astype(int)).tolist()
    if observed_levels != list(range(K)):
        raise ValueError(
            f"Every treatment level 0..{K - 1} must appear in the observed data. "
            f"Observed levels: {observed_levels}; expected range: 0..{K - 1}. "
            "A missing level means a variant contributed zero rows — "
            "re-index or remove empty variants."
        )

    # --- Independent-pooling multi-arm guard ---
    if pooling == "independent" and K >= 3:
        raise NotImplementedError(
            f"pooling='independent' supports only K=2 (binary treatment); got K={K}. "
            "Use pooling='joint' for multi-arm (K >= 3) experiments."
        )

    # --- Validate observed_copy before config construction ---
    if observed_copy not in _VALID_COPY_MODES:
        raise ValueError(
            f"observed_copy must be one of {_VALID_COPY_MODES!r}; got {observed_copy!r}."
        )
    # Narrow to the Literal type that freeze_observed expects.
    copy_mode: Literal["view", "deep", "ref"] = observed_copy  # type: ignore[assignment]

    # --- Build config (TypeError on unknown kwargs falls out naturally) ---
    config = GPUBCFConfig(random_seed=seed, **kwargs)

    # --- Freeze observed snapshot ---
    frozen = freeze_observed(observed, copy_mode)

    # --- Dispatch to private raw-array core ---
    if pooling == "joint":
        result = _fit_joint_hurdle_bcf(
            extract.X,
            extract.Z,
            extract.Y,
            extract.propensity,
            config,
            progress=progress,
        )
        # Joint core already stamps pooling='joint'; stash observed on top.
        result = dataclasses.replace(result, observed=frozen)
    else:
        # pooling == "independent" (``progress`` is intentionally not
        # forwarded; the independent core takes no such argument here).
        d = _fit_independent_hurdle_bcf(
            extract.X,
            extract.Z,
            extract.Y,
            extract.propensity,
            config,
        )

        # Extract continuous sub-result for tau0 derivation.
        # tau0 = 1 / sigma2_orig where sigma2_orig = sigma2 * y_std^2.
        # This is the global severity noise precision — the same quantity
        # the joint model estimates as tau_0 — so it is honest, not invented.
        cont = d["continuous_result"]
        y_std = float(cont.y_std)
        sigma2_orig = np.asarray(cont.sigma2_samples) * (y_std * y_std)
        # The 1e-30 floor keeps the reciprocal finite if a variance draw is 0.
        tau0_out = (1.0 / np.maximum(sigma2_orig, 1e-30)).astype(np.float32)

        # Per-draw channel arrays, shape (n, S) per the independent core.
        p0_samples = d["p0_samples"]
        p1_samples = d["p1_samples"]
        sev0_samples = d["sev0_samples"]
        sev1_samples = d["sev1_samples"]

        result = HurdleBCFResult(
            rpv_cate_samples=np.asarray(d["rpv_cate_samples"], dtype=np.float32),
            # Channel means: average over axis 1 (the draw axis of the
            # (n, S) arrays), leaving one value per row.
            p0_mean=p0_samples.mean(axis=1),
            p1_mean=p1_samples.mean(axis=1),
            sev0_mean=sev0_samples.mean(axis=1),
            sev1_mean=sev1_samples.mean(axis=1),
            tau0_samples=tau0_out,
            tau_hat_quantiles=None,
            wall_clock_seconds=float(d["wall_clock_seconds"]),
            num_chains=config.num_chains,
            num_gfr_sweeps=config.num_gfr_sweeps,
            diagnostics=None,
            phase_timing=None,
            p0_samples=p0_samples,
            p1_samples=p1_samples,
            sev0_samples=sev0_samples,
            sev1_samples=sev1_samples,
            p_samples=None,
            sev_samples=None,
            topology_history=None,
            observed=frozen,
            is_calibrated=False,
            pooling="independent",
        )

    # Single calibration path shared by both estimators.
    if calibration is not None:
        result = result.apply_calibration(calibration)
    return result