"""Auto-selecting BCF entry point — ``pytyche.fit``.
Inspects the extracted outcome array from an
:class:`~pytyche.contracts.ObservedExperimentData` and dispatches to the
appropriate BCF fit function:
- ``fit_binary_bcf`` — Y ⊆ {0, 1} (conversion-rate metric)
- ``fit_hurdle_bcf`` — float Y with ≥ 30% zeros and a positive non-zero tail
- ``fit_continuous_bcf`` — float Y with < 30% zeros (mostly-positive revenue etc.)
The 30% threshold was chosen to match e-commerce revenue zero-inflation
characteristics: at zero share below 30% the hurdle model's binary sub-model
has too few signal events to fit reliably, and a continuous model on the
positive-heavy distribution gives equivalent results. The threshold is an
internal tunable (``_HURDLE_THRESHOLD``); it is not yet user-facing — expose
as a kwarg if user feedback warrants.
Decision table::
Y all-zero → ValueError (ambiguous: binary-all-zero vs hurdle-all-zero)
Y ⊆ {0.0, 1.0}, K=2 → fit_binary_bcf
Y ⊆ {0.0, 1.0}, K≥3 → NotImplementedError (reference: multi-arm-joint-hurdle-bcf)
zero_share ≥ 0.30 and some Y > 0, K≥2 → fit_hurdle_bcf (any K)
else, K=2 → fit_continuous_bcf
else, K≥3 → NotImplementedError (reference: multi-arm-joint-hurdle-bcf)
Public surface
--------------
``_dispatch_fit(observed)`` — pure function; returns the selected callable.
``fit(observed, ...)`` — calls _dispatch_fit then forwards to the result.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal
import numpy as np
from pytyche._internal.extraction import extract_fit_arrays
from pytyche.bcf.binary import fit_binary_bcf
from pytyche.bcf.continuous import fit_continuous_bcf
from pytyche.bcf.hurdle import fit_hurdle_bcf
if TYPE_CHECKING:
from pytyche.bcf.config import BinaryBCFResult, ContinuousBCFResult, HurdleBCFResult
from pytyche.calibrate.artifact import Calibration
from pytyche.contracts import ObservedExperimentData
# Internal tunable — not user-facing. Minimum share of exactly-zero outcomes
# (``mean(Y == 0)``) at which ``_dispatch_fit`` selects the hurdle model; the
# comparison is inclusive (``zero_share >= _HURDLE_THRESHOLD``).
# The 30% value matches e-commerce revenue zero-inflation: below this share
# the hurdle binary sub-model has too few signal events and a continuous model
# is equivalent. Expose as a kwarg if user feedback warrants.
_HURDLE_THRESHOLD: float = 0.30
def _dispatch_fit(observed: ObservedExperimentData) -> Callable[..., Any]:
    """Return the BCF fit function appropriate for *observed*.

    Pure function — performs no MCMC. Dispatches based on the extracted
    outcome array Y (via :func:`~pytyche._internal.extraction.extract_fit_arrays`)
    and the number of variants K (``len(observed.variants)``).

    An unknown metric raises the adapter's :class:`ValueError` naturally (the
    dispatch inspects the extracted Y, so extraction runs first and surfaces
    any metric errors before dispatch logic runs).

    Decision table (first matching row wins):

    - Y empty → :class:`ValueError` (nothing to dispatch on).
    - Y all-zero → :class:`ValueError` (ambiguous between binary-all-zero and
      hurdle-all-zero).
    - Y ⊆ {0.0, 1.0} and K=2 → :func:`~pytyche.bcf.binary.fit_binary_bcf`.
    - Y ⊆ {0.0, 1.0} and K≥3 → :class:`NotImplementedError` referencing
      ``multi-arm-joint-hurdle-bcf``.
    - ``zero_share = mean(Y == 0) >= _HURDLE_THRESHOLD`` and at least one
      value is strictly positive → :func:`~pytyche.bcf.hurdle.fit_hurdle_bcf`
      (any K).
    - else and K=2 → :func:`~pytyche.bcf.continuous.fit_continuous_bcf`.
    - else and K≥3 → :class:`NotImplementedError` referencing
      ``multi-arm-joint-hurdle-bcf``.

    Args:
        observed: Observed experiment data.

    Returns:
        One of :func:`fit_binary_bcf`, :func:`fit_hurdle_bcf`, or
        :func:`fit_continuous_bcf` — the function object itself, not a
        wrapper (identity-checked by tests).

    Raises:
        ValueError: Unknown metric (propagated from the extraction adapter).
        ValueError: Y is empty.
        ValueError: Y is all-zero (ambiguous dispatch).
        NotImplementedError: K≥3 with binary or continuous Y (only multi-arm
            hurdle ships in v0.2).
    """
    extract = extract_fit_arrays(observed)
    Y = extract.Y  # float array, shape (n,)
    K = len(observed.variants)

    # --- Empty guard ---
    # ``np.all`` over an empty array is vacuously True, so without this check
    # empty data would be misreported by the all-zero guard below.
    if np.size(Y) == 0:
        raise ValueError(
            "No outcome values were extracted — cannot select a BCF model "
            "from an empty outcome array. Check that the experiment data "
            "contains observations for the requested metric."
        )

    # --- All-zero guard (ambiguous dispatch) ---
    if np.all(Y == 0.0):
        raise ValueError(
            "All outcome values are zero — dispatch is ambiguous. This could "
            "be a binary experiment where every visitor had zero conversions, "
            "or a hurdle/revenue experiment where every visitor spent nothing. "
            "Verify the metric and data: 'binary' (conversion_rate) all-zero "
            "means no conversions were recorded; 'hurdle' (revenue_per_visitor) "
            "all-zero means no revenue was recorded. Re-check your data before "
            "fitting."
        )

    # --- Binary check: Y ⊆ {0.0, 1.0} ---
    unique_vals = np.unique(Y)
    is_binary = bool(np.all(np.isin(unique_vals, [0.0, 1.0])))
    if is_binary:
        if K >= 3:
            raise NotImplementedError(
                f"Binary outcome with K={K} variants is not yet supported by "
                "fit_binary_bcf (K=2 only). "
                "The upcoming multi-arm-joint-hurdle-bcf release will handle "
                "multi-arm binary and continuous experiments. "
                "For now, use a K=2 (control vs. single treatment) design."
            )
        return fit_binary_bcf

    # --- Hurdle check: zero_share >= threshold with positive non-zero tail ---
    zero_share = float(np.mean(Y == 0.0))
    has_positive_nonzero = bool(np.any(Y > 0.0))
    if zero_share >= _HURDLE_THRESHOLD and has_positive_nonzero:
        # Hurdle dispatches for any K (multi-arm joint hurdle ships in v0.2).
        return fit_hurdle_bcf

    # --- Continuous (fallback) ---
    # Reached when the zero share is below the threshold, or when it is at or
    # above the threshold but no value is strictly positive.
    if K >= 3:
        raise NotImplementedError(
            f"Continuous outcome with K={K} variants is not yet supported by "
            "fit_continuous_bcf (K=2 only). "
            "The upcoming multi-arm-joint-hurdle-bcf release will handle "
            "multi-arm binary and continuous experiments. "
            "For now, use a K=2 (control vs. single treatment) design."
        )
    return fit_continuous_bcf
def fit(
    observed: ObservedExperimentData,
    *,
    observed_copy: Literal["view", "deep", "ref"] = "view",
    calibration: Calibration | None = None,
    seed: int = 0,
    **kwargs: Any,
) -> BinaryBCFResult | ContinuousBCFResult | HurdleBCFResult:
    """Fit BCF on *observed*, auto-selecting the model from the outcome shape.

    Asks ``_dispatch_fit`` which of
    :func:`~pytyche.bcf.binary.fit_binary_bcf`,
    :func:`~pytyche.bcf.continuous.fit_continuous_bcf`, or
    :func:`~pytyche.bcf.hurdle.fit_hurdle_bcf` suits the outcome array
    extracted from *observed* (see ``_dispatch_fit`` for the full decision
    table), then calls it with the same arguments this function received.

    ``**kwargs`` are forwarded verbatim to the selected fit function.
    This means:

    - ``pooling=`` reaches :func:`~pytyche.bcf.hurdle.fit_hurdle_bcf` without
      issue.
    - ``pooling=`` passed when binary or continuous is dispatched raises
      :class:`TypeError` from :class:`~pytyche.bcf.config.GPUBCFConfig`
      (that is a caller error; the error is loud and intentional).

    Implementation note: ``fit`` calls ``_dispatch_fit``, which runs the
    ``extract_fit_arrays`` adapter to inspect Y; the selected wrapper then
    runs ``extract_fit_arrays`` again internally. That double extraction is
    deliberate for v0.2 — extraction is cheap relative to MCMC, and one
    correct path beats plumbing an extraction bypass. A future optimization
    can thread the arrays through if profiling shows the cost matters.

    Args:
        observed: Observed experiment data.
        observed_copy: Copy mode for stashing the observed data on the result.
            ``'view'`` (default) — zero-copy read-only view; ``'deep'`` —
            independent deep copy; ``'ref'`` — identity (no copy).
        calibration: Optional SBC-fitted artifact applied post-fit by the
            selected entry point — sugar for ``.apply_calibration(...)``
            on the result. ``None`` (default) returns the raw posterior.
        seed: Random seed forwarded to
            :class:`~pytyche.bcf.config.GPUBCFConfig`.
        **kwargs: Additional keyword arguments forwarded verbatim to the
            selected fit function (e.g., ``num_mcmc``, ``num_burnin``,
            ``num_trees_mu``, ``pooling``).

    Returns:
        :class:`~pytyche.bcf.config.BinaryBCFResult`,
        :class:`~pytyche.bcf.config.ContinuousBCFResult`, or
        :class:`~pytyche.bcf.config.HurdleBCFResult` — whichever the
        selected fit function returns.

    Raises:
        ValueError: Unknown metric or all-zero Y (ambiguous dispatch); or,
            propagated from the selected fit, a supplied ``calibration``
            whose fitted regime does not match the observed data.
        NotImplementedError: K≥3 with binary or continuous Y.
        TypeError: If ``**kwargs`` contains unknown
            :class:`~pytyche.bcf.config.GPUBCFConfig` field names, or if
            a kwarg collides with an explicit parameter of this function
            (standard Python duplicate-keyword-argument error).
    """
    # Dispatch first so metric and shape errors surface before any fitting.
    selected = _dispatch_fit(observed)
    return selected(
        observed,
        observed_copy=observed_copy,
        calibration=calibration,
        seed=seed,
        **kwargs,
    )