"""Public fit entry point for hurdle BCF.
The canonical public surface is ``fit_hurdle_bcf``, which accepts an
:class:`~pytyche.contracts.ObservedExperimentData` and dispatches to the
appropriate private raw-array core based on the ``pooling`` argument.
Private cores (raw-array, no observed stashing):
- ``_fit_joint_hurdle_bcf`` — joint shared-tree MCMC (model.py)
- ``_fit_independent_hurdle_bcf`` — two-stage binary+continuous (compose.py)
"""
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING, Literal
import numpy as np
from pytyche._internal.extraction import extract_fit_arrays, require_feature_columns
from pytyche._internal.observed import freeze_observed
from pytyche.bcf.config import GPUBCFConfig, HurdleBCFResult
from pytyche.bcf.hurdle.compose import _fit_independent_hurdle_bcf
from pytyche.bcf.hurdle.model import _fit_joint_hurdle_bcf
from pytyche.setup import _warn_if_no_cuda
if TYPE_CHECKING:
from pytyche.calibrate.artifact import Calibration
from pytyche.contracts import ObservedExperimentData
_VALID_COPY_MODES: tuple[str, ...] = ("view", "deep", "ref")
[docs]
def fit_hurdle_bcf(
    observed: ObservedExperimentData,
    *,
    pooling: Literal["joint", "independent"] = "joint",
    observed_copy: Literal["view", "deep", "ref"] = "view",
    calibration: Calibration | None = None,
    seed: int = 0,
    progress: bool = False,
    **kwargs,
) -> HurdleBCFResult:
    """Fit hurdle BCF on observed experiment data.

    Dispatches to the joint shared-tree MCMC (``pooling='joint'``) or the
    independent two-stage estimator (``pooling='independent'``) based on the
    ``pooling`` argument. All MCMC configuration is forwarded through
    ``**kwargs`` to :class:`~pytyche.bcf.config.GPUBCFConfig`.

    Args:
        observed: Observed experiment data. Must have at least two variants.
        pooling: ``'joint'`` (default) — shared-tree MCMC estimator (canonical);
            ``'independent'`` — two-stage binary + continuous estimator.
        observed_copy: Copy mode for stashing the observed data on the result.
            ``'view'`` (default) — zero-copy read-only view; ``'deep'`` —
            independent deep copy; ``'ref'`` — identity (no copy).
        calibration: Optional SBC-fitted artifact applied post-fit —
            sugar for ``.apply_calibration(calibration)`` on the returned
            result (one application path; the kwarg adds no numerics).
            ``None`` (default) returns the raw posterior.
        seed: Random seed forwarded to :class:`~pytyche.bcf.config.GPUBCFConfig`
            as ``random_seed``.
        progress: When True, the joint fit renders tqdm progress bars on
            stderr (GFR warm-start sweeps, MCMC chunks). The default False
            is fully silent. ``pooling='independent'`` accepts the flag
            but has no per-phase reporting hooks — its channel fits run
            inside JIT-compiled loops.
        **kwargs: Additional keyword arguments forwarded to
            :class:`~pytyche.bcf.config.GPUBCFConfig`. Unknown keys raise
            ``TypeError`` (frozen dataclass with no extra fields).

    Returns:
        :class:`~pytyche.bcf.config.HurdleBCFResult` with ``pooling`` and
        ``observed`` set to the stashed snapshot; calibrated iff
        ``calibration`` was supplied.

    Raises:
        ValueError: If ``pooling`` is not ``'joint'`` or ``'independent'``,
            if ``observed_copy`` is not one of ``'view'``, ``'deep'``,
            ``'ref'``, if fewer than two variants are supplied, if any
            variant contributes zero rows (missing treatment level), or —
            propagated from ``apply_calibration`` — if a supplied
            ``calibration``'s fitted regime (metric, n_treatments, pooling)
            does not match this fit.
        NotImplementedError: If ``pooling='independent'`` and K >= 3 (use
            ``pooling='joint'`` for multi-arm experiments).
        TypeError: If ``**kwargs`` contains unknown
            :class:`~pytyche.bcf.config.GPUBCFConfig` field names.
    """
    _warn_if_no_cuda()

    # --- Validate scalar arguments first: cheap, and fails fast before any
    # array extraction or config construction happens. ---
    if pooling not in ("joint", "independent"):
        raise ValueError(
            f"pooling must be 'joint' or 'independent'; got {pooling!r}."
        )
    if observed_copy not in _VALID_COPY_MODES:
        raise ValueError(
            f"observed_copy must be one of {_VALID_COPY_MODES!r}; got {observed_copy!r}."
        )

    # --- Extract arrays from observed ---
    extract = extract_fit_arrays(observed)
    require_feature_columns(extract)

    # K is the number of variants (arms), not inferred from Z.max().
    # Using len(observed.variants) catches the case where a variant contributes
    # zero rows — extract.Z would lack that level even though the arm exists.
    K = len(observed.variants)
    if K < 2:
        # A treatment-effect fit needs a control and at least one treatment;
        # with K == 1 the level guard below would pass trivially.
        raise ValueError(
            f"fit_hurdle_bcf requires at least two variants; got K={K}."
        )

    # --- Missing-level guard (fires before config construction / MCMC) ---
    # np.unique returns sorted unique values, so .tolist() is already an
    # ordered list of Python ints comparable to range(K).
    observed_levels = np.unique(extract.Z.astype(int)).tolist()
    expected_range = list(range(K))
    if observed_levels != expected_range:
        raise ValueError(
            f"Every treatment level 0..{K - 1} must appear in the observed data. "
            f"Observed levels: {observed_levels}; expected range: 0..{K - 1}. "
            "A missing level means a variant contributed zero rows — "
            "re-index or remove empty variants."
        )

    # --- Independent-pooling multi-arm guard ---
    if pooling == "independent" and K >= 3:
        raise NotImplementedError(
            f"pooling='independent' supports only K=2 (binary treatment); got K={K}. "
            "Use pooling='joint' for multi-arm (K >= 3) experiments."
        )

    # --- Build config (TypeError on unknown kwargs falls out naturally) ---
    config = GPUBCFConfig(random_seed=seed, **kwargs)

    # --- Freeze observed snapshot ---
    # observed_copy is already typed as the Literal freeze_observed expects
    # and was validated above, so it is passed through directly.
    frozen = freeze_observed(observed, observed_copy)

    # --- Dispatch to private raw-array core ---
    if pooling == "joint":
        result = _fit_joint_hurdle_bcf(
            extract.X, extract.Z, extract.Y, extract.propensity, config,
            progress=progress,
        )
        # Joint core already stamps pooling='joint'; stash observed on top.
        result = dataclasses.replace(result, observed=frozen)
    else:
        # The independent core returns a plain mapping of arrays; ``progress``
        # is intentionally not forwarded (no per-phase reporting hooks).
        raw = _fit_independent_hurdle_bcf(
            extract.X, extract.Z, extract.Y, extract.propensity, config,
        )
        result = _assemble_independent_result(raw, config, frozen)

    # Single calibration path shared by both pooling modes.
    if calibration is not None:
        result = result.apply_calibration(calibration)
    return result


def _assemble_independent_result(d, config: GPUBCFConfig, frozen) -> HurdleBCFResult:
    """Package the two-stage estimator's output mapping as a HurdleBCFResult.

    Args:
        d: Mapping returned by ``_fit_independent_hurdle_bcf``; must contain
            ``continuous_result``, ``p0_samples``, ``p1_samples``,
            ``sev0_samples``, ``sev1_samples``, ``rpv_cate_samples`` and
            ``wall_clock_seconds``.
        config: The config the fit ran with (chain / GFR counts are copied).
        frozen: The snapshot produced by ``freeze_observed``.

    Returns:
        An uncalibrated :class:`HurdleBCFResult` stamped ``pooling='independent'``.
    """
    # Extract continuous sub-result for tau0 derivation.
    # tau0 = 1 / sigma2_orig where sigma2_orig = sigma2 * y_std^2.
    # This is the global severity noise precision — the same quantity the joint
    # model estimates as tau_0 — so it is honest, not invented.
    cont = d["continuous_result"]
    y_std = float(cont.y_std)
    sigma2_orig = np.asarray(cont.sigma2_samples) * (y_std * y_std)
    # Floor sigma2 to avoid division by zero; 1e30 still fits in float32.
    tau0_out = (1.0 / np.maximum(sigma2_orig, 1e-30)).astype(np.float32)

    # Per-draw channel arrays, shape (n, S): n units by S posterior draws.
    p0_samples = d["p0_samples"]
    p1_samples = d["p1_samples"]
    sev0_samples = d["sev0_samples"]
    sev1_samples = d["sev1_samples"]

    # Per-unit posterior means: average each (n, S) array over the draw
    # axis (axis=1), yielding one value per unit.
    return HurdleBCFResult(
        rpv_cate_samples=np.asarray(d["rpv_cate_samples"], dtype=np.float32),
        p0_mean=p0_samples.mean(axis=1),
        p1_mean=p1_samples.mean(axis=1),
        sev0_mean=sev0_samples.mean(axis=1),
        sev1_mean=sev1_samples.mean(axis=1),
        tau0_samples=tau0_out,
        tau_hat_quantiles=None,
        wall_clock_seconds=float(d["wall_clock_seconds"]),
        num_chains=config.num_chains,
        num_gfr_sweeps=config.num_gfr_sweeps,
        diagnostics=None,
        phase_timing=None,
        p0_samples=p0_samples,
        p1_samples=p1_samples,
        sev0_samples=sev0_samples,
        sev1_samples=sev1_samples,
        p_samples=None,
        sev_samples=None,
        topology_history=None,
        observed=frozen,
        is_calibrated=False,
        pooling="independent",
    )