# Source code for pytyche.bcf.binary

"""Phase 2 — independent binary/probit BCF (Albert-Chib data augmentation).

Two-forest Bayesian Causal Forest for binary outcomes, executed on GPU via
bartz's JAX BART sampler with Albert-Chib probit data augmentation. Maintains
a prognostic forest (``mu``) and a treatment-effect forest (``tau``) updated
by interleaved backfitting sweeps, with the latent ``z`` resampled from a
one-sided truncated normal conditional on the binary observation after each
mu/tau sweep.

The MCMC loop is compiled as a single XLA program via ``lax.scan`` — no
Python dispatch between iterations. Error variance is fixed at 1 (probit
unit-variance identification); the only stochastic update beyond the trees
is the latent-z step.

Public API
----------
``fit_binary_bcf`` — observed-based entry point returning ``BinaryBCFResult``.

Internal helpers
----------------
``_fit_binary_bcf``         — raw-array private core (called by the public
                              wrapper and by internal callers such as
                              ``pytyche.bcf.hurdle.compose``).
``_init_mu_forest_binary``  — initialize prognostic forest state on float32
                              latent z (bartz disallows ``error_scale`` on
                              bool y, so we drive z ourselves).
``_init_tau_forest_binary`` — initialize treatment-effect forest with
                              compensating ``leaf_prior_cov_inv`` (4×) since
                              we skip the basis-weighted ``error_scale`` on
                              the binary path.
``_run_binary_bcf_loop``    — host-side wrapper that prepares inputs and
                              dispatches to the JIT'd scan.
``_binary_bcf_scan``        — JIT'd two-forest MCMC with embedded
                              Albert-Chib z-update via
                              ``truncated_normal_onesided``.

Import graph
------------
``pytyche.bcf.binary`` imports configuration types from ``pytyche.bcf.config``
and the shared preprocessing helpers ``_compute_basis`` /
``_preprocess_covariates`` from ``pytyche.bcf.preprocess``. Since those helpers
were extracted into their own module (Stage C), no lazy imports are needed:
everything is imported at module top level without a circular dependency.
"""

from __future__ import annotations

import dataclasses
import time
from dataclasses import replace
from functools import partial
from typing import TYPE_CHECKING, Literal

import jax
import jax.numpy as jnp
import numpy as np
from bartz.grove import evaluate_forest
from bartz.mcmcstep import (
    OutcomeType,
    init,
    make_p_nonterminal,
)
from jax import lax, random

from pytyche._internal.extraction import extract_fit_arrays, require_feature_columns
from pytyche._internal.observed import freeze_observed
from pytyche.bcf._bartz_compat import step_trees, truncated_normal_onesided
from pytyche.bcf.config import BinaryBCFResult, GPUBCFConfig
from pytyche.bcf.preprocess import _compute_basis, _preprocess_covariates
from pytyche.setup import _warn_if_no_cuda

if TYPE_CHECKING:
    from pytyche.calibrate.artifact import Calibration
    from pytyche.contracts import ObservedExperimentData

_VALID_COPY_MODES: tuple[str, ...] = ("view", "deep", "ref")

# ---------------------------------------------------------------------------
# Phase 2: Binary (probit) BCF — public observed-based wrapper
# ---------------------------------------------------------------------------


def fit_binary_bcf(
    observed: ObservedExperimentData,
    *,
    observed_copy: Literal["view", "deep", "ref"] = "view",
    calibration: Calibration | None = None,
    seed: int = 0,
    **kwargs,
) -> BinaryBCFResult:
    """Fit binary BCF (Albert-Chib probit) on observed experiment data.

    Accepts an :class:`~pytyche.contracts.ObservedExperimentData` and
    dispatches to the private raw-array core ``_fit_binary_bcf``. All cheap
    argument/data validation runs before the config is built and before the
    observed snapshot is frozen, so invalid calls fail fast without paying for
    a (possibly deep) copy.

    Args:
        observed: Observed experiment data. Must have a binary outcome column
            (metric ``'conversion_rate'``); all outcome values must be in
            ``{0, 1}`` after float cast.
        observed_copy: Copy mode for stashing the observed data on the result.
            ``'view'`` (default) — zero-copy read-only view; ``'deep'`` —
            independent deep copy; ``'ref'`` — identity (no copy).
        calibration: Optional SBC-fitted artifact applied post-fit — sugar
            for ``.apply_calibration(calibration)`` on the returned result
            (one application path; the kwarg adds no numerics). ``None``
            (default) returns the raw posterior.
        seed: Random seed forwarded to
            :class:`~pytyche.bcf.config.GPUBCFConfig` as ``random_seed``.
        **kwargs: Additional keyword arguments forwarded to
            :class:`~pytyche.bcf.config.GPUBCFConfig`.

    Returns:
        :class:`~pytyche.bcf.config.BinaryBCFResult` with ``observed`` set to
        the stashed snapshot; calibrated iff ``calibration`` was supplied.

    Raises:
        ValueError: If ``observed_copy`` is not one of ``'view'``, ``'deep'``,
            ``'ref'``; if any variant contributes zero rows (missing
            treatment level); if the extracted outcome contains values
            outside ``{0, 1}``; or — propagated from ``apply_calibration`` —
            if a supplied ``calibration``'s fitted regime does not match the
            observed data.
        NotImplementedError: If K >= 3 (use ``fit_hurdle_bcf`` for
            multi-arm).
        TypeError: If ``**kwargs`` contains unknown
            :class:`~pytyche.bcf.config.GPUBCFConfig` field names.
    """
    _warn_if_no_cuda()

    # --- Validate observed_copy first: it is free to check and does not
    # depend on the data, so reject typos before any extraction work. ---
    if observed_copy not in _VALID_COPY_MODES:
        raise ValueError(
            f"observed_copy must be one of {_VALID_COPY_MODES!r}; got {observed_copy!r}."
        )

    # --- Extract arrays from observed ---
    extract = extract_fit_arrays(observed)
    require_feature_columns(extract)
    K = len(observed.variants)

    # --- Missing-level guard (before config construction) ---
    # np.unique returns sorted unique values; .tolist() yields Python ints.
    observed_levels = np.unique(extract.Z.astype(int)).tolist()
    expected_range = list(range(K))
    if observed_levels != expected_range:
        raise ValueError(
            f"Every treatment level 0..{K - 1} must appear in the observed data. "
            f"Observed levels: {observed_levels}; expected range: 0..{K - 1}. "
            "A missing level means a variant contributed zero rows — "
            "re-index or remove empty variants."
        )

    # --- Binary outcome guard (after extraction) ---
    y_float = extract.Y.astype(float)
    unique_vals = set(np.unique(y_float).tolist())
    if not unique_vals.issubset({0.0, 1.0}):
        raise ValueError(
            f"fit_binary_bcf requires a binary outcome: all extracted Y values "
            f"must be in {{0, 1}}. Got values outside this set "
            f"(unique values: {sorted(unique_vals)!r}). "
            "Use fit_continuous_bcf for continuous outcomes or fit_hurdle_bcf "
            "for zero-inflated (hurdle) outcomes."
        )

    # --- Multi-arm guard ---
    # After the missing-level guard, K equals max(Z) + 1, which is exactly
    # what the raw-array core checks. Raising here avoids building the config
    # and freezing (possibly deep-copying) the observed data first.
    if K >= 3:
        raise NotImplementedError(
            "fit_binary_bcf supports only binary treatment (Z in {0, 1}); "
            f"got K={K} treatment levels. Multi-arm (K>=3) is available via "
            "fit_hurdle_bcf(pooling='joint')."
        )

    # --- Build config ---
    config = GPUBCFConfig(random_seed=seed, **kwargs)

    # --- Freeze observed snapshot ---
    # observed_copy was validated above, so it is already one of the literals.
    frozen = freeze_observed(observed, observed_copy)

    # --- Dispatch to private raw-array core ---
    result = _fit_binary_bcf(extract.X, extract.Z, extract.Y, extract.propensity, config)

    # --- Attach stash ---
    result = dataclasses.replace(result, observed=frozen)
    if calibration is not None:
        result = result.apply_calibration(calibration)
    return result
# ---------------------------------------------------------------------------
# Phase 2: Binary (probit) BCF — private raw-array core
# ---------------------------------------------------------------------------


def _fit_binary_bcf(
    X: np.ndarray,
    Z: np.ndarray,
    y_binary: np.ndarray,
    propensity: np.ndarray,
    config: GPUBCFConfig,
) -> BinaryBCFResult:
    """Fit binary BCF (Albert-Chib probit) on GPU via bartz.

    Private raw-array core. Public surface is :func:`fit_binary_bcf`
    (observed-based wrapper). Builds the treatment basis and binned
    covariates, initializes the mu (prognostic) and tau (treatment-effect)
    forests on a latent ``z``, runs the JIT'd two-forest probit sampler and
    returns the retained draws together with the elapsed wall-clock time.

    Parameters
    ----------
    X : (n, p) covariate matrix
    Z : (n,) treatment indicator (0/1)
    y_binary : (n,) binary outcome (bool or 0/1); cast to ``bool`` here
    propensity : (n,) propensity scores (reserved for interface compat;
        not read by this function)
    config : GPUBCFConfig
        Supplies ``num_cuts``, forest sizes and tree priors, ``num_burnin``,
        ``num_mcmc`` and ``random_seed``.

    Returns
    -------
    BinaryBCFResult
        ``mu_samples`` / ``tau_samples`` as produced by
        ``_run_binary_bcf_loop`` plus ``wall_clock_seconds``.

    Raises
    ------
    NotImplementedError
        If ``int(max(Z)) + 1 >= 3`` (multi-arm treatment). Note that an empty
        ``Z`` makes ``.max()`` raise numpy's ``ValueError`` instead.
    """
    start = time.time()

    # Number of treatment levels inferred from the largest label.
    K = int(np.asarray(Z).max()) + 1
    if K >= 3:
        raise NotImplementedError(
            "_fit_binary_bcf supports only binary treatment (Z in {0, 1}); "
            f"got K={K} treatment levels. Multi-arm (K>=3) is available via "
            "fit_hurdle_bcf(pooling='joint')."
        )

    basis = _compute_basis(Z)
    X_binned, max_split, _splits = _preprocess_covariates(X, config.num_cuts)
    y_bool = np.asarray(y_binary, dtype=bool)

    # Initialize latent z for probit: +0.5 for y=True, -0.5 for y=False
    z_init = np.where(y_bool, 0.5, -0.5).astype(np.float32)

    # Initialize mu and tau forests as continuous (float32 z)
    # We handle probit z-sampling ourselves in the custom loop.
    # Copy inputs — init() donates buffers.
    # Each forest therefore gets its own device copy of the binned X and
    # max_split, so donating one does not invalidate the other.
    X_mu, ms_mu = jnp.array(X_binned), jnp.array(max_split)
    X_tau, ms_tau = jnp.array(X_binned), jnp.array(max_split)
    mu_state = _init_mu_forest_binary(X_mu, ms_mu, z_init, config)
    tau_state = _init_tau_forest_binary(X_tau, ms_tau, z_init, basis, config)

    # Run probit BCF loop
    mu_samples, tau_samples = _run_binary_bcf_loop(
        mu_state,
        tau_state,
        y_bool,
        basis,
        config,
    )

    elapsed = time.time() - start
    return BinaryBCFResult(
        mu_samples=mu_samples,
        tau_samples=tau_samples,
        wall_clock_seconds=elapsed,
    )


# ---------------------------------------------------------------------------
# Forest initialization helpers
# ---------------------------------------------------------------------------


def _init_mu_forest_binary(
    X_binned: jax.Array,
    max_split: jax.Array,
    z_init: np.ndarray,
    config: GPUBCFConfig,
):
    """Initialize mu forest for binary (probit) outcome.

    Uses float32 y (latent z), not bool: bartz does not allow
    ``error_scale`` on bool outcomes, so both forests are initialized as
    continuous models on the latent z and the probit z-sampling is done in
    the custom loop (``_binary_bcf_scan``). Neither forest passes
    ``error_scale`` (see ``_init_tau_forest_binary``).

    Leaf prior precision is ``num_trees_mu``. ``error_cov_df`` /
    ``error_cov_scale`` are passed as 3.0, but the custom loop overwrites
    ``error_cov_inv`` with a fixed 1.0 before every sweep.
    NOTE(review): presumably these two only matter for bartz's own
    error-variance update, which this loop never calls — confirm.

    Returns the bartz state object produced by ``bartz.mcmcstep.init``.
    """
    return init(
        X=X_binned,
        y=jnp.asarray(z_init, dtype=jnp.float32),
        outcome_type=OutcomeType.continuous,
        offset=jnp.float32(0.0),
        max_split=max_split,
        num_trees=config.num_trees_mu,
        p_nonterminal=make_p_nonterminal(config.max_depth, config.alpha_mu, config.beta_mu),
        leaf_prior_cov_inv=jnp.float32(config.num_trees_mu),
        error_cov_df=jnp.float32(3.0),
        error_cov_scale=jnp.float32(3.0),
    )


def _init_tau_forest_binary(
    X_binned: jax.Array,
    max_split: jax.Array,
    z_init: np.ndarray,
    basis: np.ndarray,
    config: GPUBCFConfig,
):
    """Initialize tau forest for binary (probit) outcome.

    Option A: no error_scale (prec_scale stays None = 1.0). Since bartz
    asserts error_scale is None for bool y, and we use float32 z as a
    workaround, skipping error_scale is simplest.

    Without prec_scale=0.25, data gets 4× more relative weight vs prior.
    Compensate by increasing leaf_prior_cov_inv by 4× to maintain the same
    prior-data balance as the continuous tau forest.

    ``basis`` is accepted for signature parity but is not read here; the
    basis enters only through the tau residual built in ``_binary_bcf_scan``.

    NOTE(review): assuming bartz performs the standard conjugate-normal leaf
    update and |basis| = 0.5, scaling both the prior precision and the data
    precision by 4× keeps the leaf *posterior mean* unchanged but makes the
    leaf *posterior precision* 4× larger (16m + n vs 4m + 0.25n). It also
    makes the data-fit term of the grow/prune marginal likelihood 4× larger.
    The tau residual (z - mu)/basis actually has variance 1/basis² = 4, so
    tau draws may be under-dispersed. A candidate fix is
    ``error_cov_inv = 0.25`` on the tau sweep with ``leaf_prior_cov_inv =
    4 * num_trees_tau`` — verify against ``_compute_basis`` and bartz
    internals before changing.
    """
    # Continuous tau: leaf_prior_inv = num_trees_tau * 4.0, prec_scale = 0.25
    # Binary tau: leaf_prior_inv = num_trees_tau * 16.0, prec_scale = None (1.0)
    # These give the same effective regularization ratio.
    leaf_prior_inv = float(config.num_trees_tau) * 16.0
    return init(
        X=X_binned,
        y=jnp.asarray(z_init, dtype=jnp.float32),
        outcome_type=OutcomeType.continuous,
        offset=jnp.float32(0.0),
        max_split=max_split,
        num_trees=config.num_trees_tau,
        p_nonterminal=make_p_nonterminal(config.max_depth, config.alpha_tau, config.beta_tau),
        leaf_prior_cov_inv=jnp.float32(leaf_prior_inv),
        error_cov_df=jnp.float32(3.0),
        error_cov_scale=jnp.float32(3.0),
    )


# ---------------------------------------------------------------------------
# MCMC loops — JIT'd via lax.scan for GPU performance
# ---------------------------------------------------------------------------


def _run_binary_bcf_loop(
    mu_state,
    tau_state,
    y_bool: np.ndarray,
    basis: np.ndarray,
    config: GPUBCFConfig,
) -> tuple[np.ndarray, np.ndarray]:
    """Wrapper that dispatches to JIT'd binary BCF loop.

    Moves ``y_bool`` / ``basis`` to JAX arrays, seeds a PRNG key from
    ``config.random_seed`` and calls ``_binary_bcf_scan`` with
    ``num_burnin`` / ``num_mcmc`` as static arguments. ``mu_state`` and
    ``tau_state`` are donated to the JIT'd call and must not be reused by
    the caller afterwards.

    Returns host (numpy) copies of the retained mu and tau predictions,
    transposed relative to the scan's stacked output. Presumably the result
    is (n, num_mcmc), with one column per retained draw; TODO confirm the
    per-draw shape of ``evaluate_forest``.
    """
    basis_jax = jnp.asarray(basis, dtype=jnp.float32)
    y_jax = jnp.asarray(y_bool, dtype=bool)
    key = random.key(config.random_seed)

    mu_samples, tau_samples = _binary_bcf_scan(
        mu_state,
        tau_state,
        y_jax,
        basis_jax,
        key,
        config.num_burnin,
        config.num_mcmc,
    )
    return np.array(mu_samples).T, np.array(tau_samples).T


@partial(jax.jit, static_argnums=(5, 6), donate_argnums=(0, 1))
def _binary_bcf_scan(mu_state, tau_state, y_bool, basis, key, n_burnin, n_mcmc):
    """JIT'd two-forest MCMC for binary (probit) BCF via lax.scan.

    Same as continuous, but error_cov_inv is fixed at 1.0 (probit unit
    variance) and latent z is sampled via truncated normal after each sweep.

    Each iteration performs, in order:

    1. a mu sweep on ``z - tau*basis``;
    2. a tau sweep on ``(z - mu_new) / basis``;
    3. a latent-z redraw centred on ``mu_new + tau_new * basis``.

    ``n_burnin`` iterations are run and discarded, then ``n_mcmc``
    iterations whose post-sweep forest predictions are stacked along a new
    leading axis. ``n_burnin`` / ``n_mcmc`` are static (they fix the scan
    lengths), and ``mu_state`` / ``tau_state`` buffers are donated.

    Returns ``(mu_samples, tau_samples)``: the stacked mu and tau forest
    predictions from each retained iteration.
    """
    sigma2_inv = jnp.float32(1.0)  # fixed for probit
    # Same +/-0.5 starting point as z_init in _fit_binary_bcf.
    z = jnp.where(y_bool, jnp.float32(0.5), jnp.float32(-0.5))

    def step(mu_s, tau_s, z, key, y, basis):
        # The first split output becomes the carried key for the next
        # iteration; the other three are consumed once below.
        key, k_mu, k_tau, k_z = random.split(key, 4)
        mu_pred = evaluate_forest(mu_s.X, mu_s.forest, sum_batch_axis=0)
        tau_pred = evaluate_forest(tau_s.X, tau_s.forest, sum_batch_axis=0)

        # 1. Mu sweep: resid = (z - tau*basis) - mu_sum
        mu_resid = z - tau_pred * basis - mu_pred
        mu_s = replace(mu_s, resid=mu_resid, error_cov_inv=sigma2_inv)
        mu_s = step_trees(k_mu, mu_s)

        # 2. Tau sweep: resid = (z - mu_new)/basis - tau_sum
        # Backfitting: uses the freshly updated mu forest.
        mu_pred_new = evaluate_forest(mu_s.X, mu_s.forest, sum_batch_axis=0)
        tau_resid = (z - mu_pred_new) / basis - tau_pred
        # NOTE(review): error_cov_inv is 1.0 here too, although the tau
        # residual is scaled by 1/basis — see _init_tau_forest_binary.
        tau_s = replace(tau_s, resid=tau_resid, error_cov_inv=sigma2_inv)
        tau_s = step_trees(k_tau, tau_s)

        # 3. Sample latent z via truncated normal
        # Assumes truncated_normal_onesided(key, shape, upper, bound) draws a
        # standard normal truncated above `bound` where `upper` is True and
        # below it otherwise. With upper = ~y and bound = -pred that gives
        # z < 0 for y=False and z > 0 for y=True (Albert-Chib). Confirm
        # against pytyche.bcf._bartz_compat.
        tau_pred_new = evaluate_forest(tau_s.X, tau_s.forest, sum_batch_axis=0)
        combined_pred = mu_pred_new + tau_pred_new * basis
        resid = truncated_normal_onesided(k_z, (), ~y, -combined_pred)
        z = combined_pred + resid

        return mu_s, tau_s, z, key, mu_pred_new, tau_pred_new

    # Burnin phase
    def burnin_body(carry, _):
        mu_s, tau_s, z, k = carry
        mu_s, tau_s, z, k, _, _ = step(mu_s, tau_s, z, k, y_bool, basis)
        return (mu_s, tau_s, z, k), None

    carry = (mu_state, tau_state, z, key)
    carry, _ = lax.scan(burnin_body, carry, None, length=n_burnin)

    # MCMC phase
    def mcmc_body(carry, _):
        mu_s, tau_s, z, k = carry
        mu_s, tau_s, z, k, mu_pred, tau_pred = step(
            mu_s,
            tau_s,
            z,
            k,
            y_bool,
            basis,
        )
        # Per-iteration outputs are stacked by lax.scan along axis 0.
        return (mu_s, tau_s, z, k), (mu_pred, tau_pred)

    carry, (mu_samples, tau_samples) = lax.scan(
        mcmc_body,
        carry,
        None,
        length=n_mcmc,
    )
    return mu_samples, tau_samples