Source code for pytyche.bcf.dispatch

"""Auto-selecting BCF entry point — ``pytyche.fit``.

Inspects the extracted outcome array from an
:class:`~pytyche.contracts.ObservedExperimentData` and dispatches to the
appropriate BCF fit function:

- ``fit_binary_bcf``    — Y ⊆ {0, 1} (conversion-rate metric)
- ``fit_hurdle_bcf``    — float Y with ≥ 30% zeros and a positive non-zero tail
- ``fit_continuous_bcf``— float Y with < 30% zeros (all-positive revenue etc.)

The 30% threshold was chosen to match e-commerce revenue zero-inflation
characteristics: when the zero share is below 30%, the hurdle model's binary
sub-model has too few signal events to fit reliably, and a continuous model
on the positive-heavy distribution gives equivalent results.  The threshold
is an internal tunable (``_HURDLE_THRESHOLD``) and is not yet user-facing;
expose it as a kwarg if user feedback warrants.

Decision table::

    Y all-zero                     → ValueError (ambiguous: binary-all-zero vs hurdle-all-zero)
    Y ⊆ {0.0, 1.0}, K=2            → fit_binary_bcf
    Y ⊆ {0.0, 1.0}, K≥3            → NotImplementedError (reference: multi-arm-joint-hurdle-bcf)
    zero_share ≥ 0.30, some Y > 0  → fit_hurdle_bcf  (any K)
    else, K=2                      → fit_continuous_bcf
    else, K≥3                      → NotImplementedError (reference: multi-arm-joint-hurdle-bcf)

Public surface
--------------
``_dispatch_fit(observed)`` — pure function; returns the selected callable.
``fit(observed, ...)``      — calls _dispatch_fit then forwards to the result.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal

import numpy as np

from pytyche._internal.extraction import extract_fit_arrays
from pytyche.bcf.binary import fit_binary_bcf
from pytyche.bcf.continuous import fit_continuous_bcf
from pytyche.bcf.hurdle import fit_hurdle_bcf

if TYPE_CHECKING:
    from pytyche.bcf.config import BinaryBCFResult, ContinuousBCFResult, HurdleBCFResult
    from pytyche.calibrate.artifact import Calibration
    from pytyche.contracts import ObservedExperimentData

# Internal tunable — not user-facing.  The 30% threshold matches e-commerce
# revenue zero-inflation: below this share the hurdle binary sub-model has too
# few signal events and a continuous model is equivalent.  Expose as a kwarg
# if user feedback warrants.
_HURDLE_THRESHOLD: float = 0.30


def _dispatch_fit(observed: ObservedExperimentData) -> Callable[..., Any]:
    """Return the BCF fit function appropriate for *observed*.

    Pure function — performs no MCMC.  Dispatches based on the extracted
    outcome array Y (via :func:`~pytyche._internal.extraction.extract_fit_arrays`)
    and the number of variants K.

    An unknown metric raises the adapter's :class:`ValueError` naturally (the
    dispatch inspects the extracted Y, so extraction runs first and surfaces
    any metric errors before dispatch logic runs).

    Decision table:

    - Y all-zero → :class:`ValueError` (ambiguous between binary-all-zero and
      hurdle-all-zero).
    - Y ⊆ {0.0, 1.0} and K=2 → :func:`~pytyche.bcf.binary.fit_binary_bcf`.
    - Y ⊆ {0.0, 1.0} and K≥3 → :class:`NotImplementedError` referencing
      ``multi-arm-joint-hurdle-bcf``.
    - ``zero_share >= _HURDLE_THRESHOLD`` (with ``zero_share = mean(Y == 0)``)
      and at least one positive value in Y →
      :func:`~pytyche.bcf.hurdle.fit_hurdle_bcf` (any K).
    - else and K=2 → :func:`~pytyche.bcf.continuous.fit_continuous_bcf`.
    - else and K≥3 → :class:`NotImplementedError` referencing
      ``multi-arm-joint-hurdle-bcf``.

    Args:
        observed: Observed experiment data.

    Returns:
        One of :func:`fit_binary_bcf`, :func:`fit_hurdle_bcf`, or
        :func:`fit_continuous_bcf` — the same object as the public wrapper
        (identity-checked by tests).

    Raises:
        ValueError: Unknown metric (propagated from the extraction adapter).
        ValueError: Y is all-zero (ambiguous dispatch).
        NotImplementedError: K≥3 with binary or continuous Y (only multi-arm
            hurdle ships in v0.2).
    """
    extract = extract_fit_arrays(observed)
    Y = extract.Y  # float array, shape (n,)
    K = len(observed.variants)

    # --- All-zero guard (ambiguous dispatch) ---
    if np.all(Y == 0.0):
        raise ValueError(
            "All outcome values are zero — dispatch is ambiguous.  This could "
            "be a binary experiment where every visitor had zero conversions, "
            "or a hurdle/revenue experiment where every visitor spent nothing.  "
            "Verify the metric and data: 'binary' (conversion_rate) all-zero "
            "means no conversions were recorded; 'hurdle' (revenue_per_visitor) "
            "all-zero means no revenue was recorded.  Re-check your data before "
            "fitting."
        )

    # --- Binary check: Y ⊆ {0.0, 1.0} ---
    unique_vals = np.unique(Y)
    is_binary = np.all(np.isin(unique_vals, [0.0, 1.0]))

    if is_binary:
        if K >= 3:
            raise NotImplementedError(
                f"Binary outcome with K={K} variants is not yet supported by "
                "fit_binary_bcf (K=2 only).  "
                "The upcoming multi-arm-joint-hurdle-bcf release will handle "
                "multi-arm binary and continuous experiments.  "
                "For now, use a K=2 (control vs. single treatment) design."
            )
        return fit_binary_bcf

    # --- Hurdle check: zero_share >= threshold with positive non-zero tail ---
    zero_share = float(np.mean(Y == 0.0))
    has_positive_nonzero = bool(np.any(Y > 0.0))

    if zero_share >= _HURDLE_THRESHOLD and has_positive_nonzero:
        # Hurdle dispatches for any K (multi-arm joint hurdle ships in v0.2).
        return fit_hurdle_bcf

    # --- Continuous ---
    if K >= 3:
        raise NotImplementedError(
            f"Continuous outcome with K={K} variants is not yet supported by "
            "fit_continuous_bcf (K=2 only).  "
            "The upcoming multi-arm-joint-hurdle-bcf release will handle "
            "multi-arm binary and continuous experiments.  "
            "For now, use a K=2 (control vs. single treatment) design."
        )
    return fit_continuous_bcf


def fit(
    observed: ObservedExperimentData,
    *,
    observed_copy: Literal["view", "deep", "ref"] = "view",
    calibration: Calibration | None = None,
    seed: int = 0,
    **kwargs: Any,
) -> BinaryBCFResult | ContinuousBCFResult | HurdleBCFResult:
    """Fit BCF on *observed*, auto-selecting the model from the outcome shape.

    Dispatches to :func:`~pytyche.bcf.binary.fit_binary_bcf`,
    :func:`~pytyche.bcf.continuous.fit_continuous_bcf`, or
    :func:`~pytyche.bcf.hurdle.fit_hurdle_bcf` based on the outcome array
    extracted from *observed* (see ``_dispatch_fit`` for the full decision
    table).

    ``**kwargs`` are forwarded verbatim to the selected fit function.  This
    means:

    - ``pooling=`` reaches :func:`~pytyche.bcf.hurdle.fit_hurdle_bcf` without
      issue.
    - ``pooling=`` passed when binary or continuous is dispatched raises
      :class:`TypeError` from :class:`~pytyche.bcf.config.GPUBCFConfig` (that
      is a caller error; the error is loud and intentional).

    Implementation note: ``fit`` calls ``_dispatch_fit``, which runs the
    ``extract_fit_arrays`` adapter to inspect Y; then the selected wrapper
    runs ``extract_fit_arrays`` again internally.  That double extraction is
    deliberate for v0.2 — extraction is cheap relative to MCMC and
    one-correct-path beats plumbing an extraction bypass.  A future optimizer
    can thread the arrays through if profiling shows cost.

    Args:
        observed: Observed experiment data.
        observed_copy: Copy mode for stashing the observed data on the result.
            ``'view'`` (default) — zero-copy read-only view;
            ``'deep'`` — independent deep copy;
            ``'ref'`` — identity (no copy).
        calibration: Optional SBC-fitted artifact applied post-fit by the
            selected entry point — sugar for ``.apply_calibration(...)`` on
            the result.  ``None`` (default) returns the raw posterior.
        seed: Random seed forwarded to
            :class:`~pytyche.bcf.config.GPUBCFConfig`.
        **kwargs: Additional keyword arguments forwarded verbatim to the
            selected fit function (e.g., ``num_mcmc``, ``num_burnin``,
            ``num_trees_mu``, ``pooling``).

    Returns:
        :class:`~pytyche.bcf.config.BinaryBCFResult`,
        :class:`~pytyche.bcf.config.ContinuousBCFResult`, or
        :class:`~pytyche.bcf.config.HurdleBCFResult` — whichever the selected
        fit function returns.

    Raises:
        ValueError: Unknown metric or all-zero Y (ambiguous dispatch); or,
            propagated from the selected fit, a supplied ``calibration`` whose
            fitted regime does not match the observed data.
        NotImplementedError: K≥3 with binary or continuous Y.
        TypeError: If ``**kwargs`` contains unknown
            :class:`~pytyche.bcf.config.GPUBCFConfig` field names, or if a
            kwarg collides with an explicit parameter of this function
            (standard Python duplicate-keyword-argument error).
    """
    selected = _dispatch_fit(observed)
    return selected(
        observed,
        observed_copy=observed_copy,
        calibration=calibration,
        seed=seed,
        **kwargs,
    )
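
The branch order of the decision table and the verbatim forwarding of ``**kwargs`` can be traced with a small stand-alone sketch. This is not the pytyche implementation: ``choose_model``, ``toy_fit``, the three ``stub_*`` functions, and the ``pooling`` default are hypothetical stand-ins, and plain Python lists replace the extracted NumPy array. Only the 0.30 threshold and the order of the checks are taken from the module above.

```python
from __future__ import annotations

from typing import Any, Callable, Sequence

HURDLE_THRESHOLD = 0.30  # same value as _HURDLE_THRESHOLD in the module


# Hypothetical stand-ins for the real fit functions.  They only report which
# model was chosen and which keyword arguments reached them.
def stub_binary(y: Sequence[float], *, seed: int = 0) -> dict[str, Any]:
    return {"model": "binary", "seed": seed}


def stub_hurdle(
    y: Sequence[float], *, seed: int = 0, pooling: str | None = None
) -> dict[str, Any]:
    return {"model": "hurdle", "seed": seed, "pooling": pooling}


def stub_continuous(y: Sequence[float], *, seed: int = 0) -> dict[str, Any]:
    return {"model": "continuous", "seed": seed}


def choose_model(y: Sequence[float], k: int) -> Callable[..., dict[str, Any]]:
    """Mirror the branch order of the decision table (assumed sketch)."""
    # All-zero guard comes first: binary and hurdle are indistinguishable.
    if all(v == 0.0 for v in y):
        raise ValueError("all-zero outcome: dispatch is ambiguous")
    # Binary check: every value is 0.0 or 1.0.
    if all(v in (0.0, 1.0) for v in y):
        if k >= 3:
            raise NotImplementedError("binary outcome supports K=2 only")
        return stub_binary
    # Hurdle check: enough zeros and at least one positive value, any K.
    zero_share = sum(v == 0.0 for v in y) / len(y)
    if zero_share >= HURDLE_THRESHOLD and any(v > 0.0 for v in y):
        return stub_hurdle
    # Everything else is continuous, which is K=2 only.
    if k >= 3:
        raise NotImplementedError("continuous outcome supports K=2 only")
    return stub_continuous


def toy_fit(y: Sequence[float], k: int, *, seed: int = 0, **kwargs: Any) -> dict[str, Any]:
    """Select a stub, then forward seed and any extra kwargs unchanged."""
    selected = choose_model(y, k)
    return selected(y, seed=seed, **kwargs)


revenue = [0.0, 0.0, 12.5, 0.0, 30.0]  # 60% zeros with a positive tail
print(toy_fit(revenue, k=3, pooling="full"))  # the hurdle stub is selected

conversions = [0.0, 1.0, 1.0, 0.0]
try:
    # The binary stub has no ``pooling`` parameter, so forwarding fails loudly.
    toy_fit(conversions, k=2, pooling="full")
except TypeError as exc:
    print("TypeError:", exc)
```

Because the keyword arguments are passed through untouched, the selected function is the only place where an unsupported option can be rejected, which is why the error surfaces as a ``TypeError`` from the callee and not from the dispatcher.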