"""Phase 2 — independent binary/probit BCF (Albert-Chib data augmentation).
Two-forest Bayesian Causal Forest for binary outcomes, executed on GPU via
bartz's JAX BART sampler with Albert-Chib probit data augmentation. Maintains
a prognostic forest (``mu``) and a treatment-effect forest (``tau``) updated
by interleaved backfitting sweeps, with the latent ``z`` resampled from a
one-sided truncated normal conditional on the binary observation after each
mu/tau sweep.
The MCMC loop is compiled as a single XLA program via ``lax.scan`` — no
Python dispatch between iterations. Error variance is fixed at 1 (probit
unit-variance identification); the only stochastic update beyond the trees
is the latent-z step.
Public API
----------
``fit_binary_bcf`` — observed-based entry point returning ``BinaryBCFResult``.
Internal helpers
----------------
``_fit_binary_bcf`` — raw-array private core (called by the public
wrapper and by internal callers such as
``pytyche.bcf.hurdle.compose``).
``_init_mu_forest_binary`` — initialize prognostic forest state on float32
latent z (bartz disallows ``error_scale`` on
bool y, so we drive z ourselves).
``_init_tau_forest_binary`` — initialize treatment-effect forest with
compensating ``leaf_prior_cov_inv`` (4×) since
we skip the basis-weighted ``error_scale`` on
the binary path.
``_run_binary_bcf_loop`` — host-side wrapper that prepares inputs and
dispatches to the JIT'd scan.
``_binary_bcf_scan`` — JIT'd two-forest MCMC with embedded
Albert-Chib z-update via
``truncated_normal_onesided``.
Import graph
------------
``pytyche.bcf.binary`` imports configuration types from ``pytyche.bcf.config``
and the shared preprocessing helpers ``_compute_basis`` /
``_preprocess_covariates`` from ``pytyche.bcf.preprocess``. After the preprocess
extraction at Stage C, the lazy-import dance is gone: helpers are imported at
module top with no circular dependency.
"""
from __future__ import annotations
import dataclasses
import time
from dataclasses import replace
from functools import partial
from typing import TYPE_CHECKING, Literal
import jax
import jax.numpy as jnp
import numpy as np
from bartz.grove import evaluate_forest
from bartz.mcmcstep import (
OutcomeType,
init,
make_p_nonterminal,
)
from jax import lax, random
from pytyche._internal.extraction import extract_fit_arrays, require_feature_columns
from pytyche._internal.observed import freeze_observed
from pytyche.bcf._bartz_compat import step_trees, truncated_normal_onesided
from pytyche.bcf.config import BinaryBCFResult, GPUBCFConfig
from pytyche.bcf.preprocess import _compute_basis, _preprocess_covariates
from pytyche.setup import _warn_if_no_cuda
if TYPE_CHECKING:
from pytyche.calibrate.artifact import Calibration
from pytyche.contracts import ObservedExperimentData
# Accepted values for the ``observed_copy`` argument of ``fit_binary_bcf``;
# that function raises ``ValueError`` for anything outside this tuple.
_VALID_COPY_MODES: tuple[str, ...] = ("view", "deep", "ref")
# ---------------------------------------------------------------------------
# Phase 2: Binary (probit) BCF — public observed-based wrapper
# ---------------------------------------------------------------------------
[docs]
def fit_binary_bcf(
    observed: ObservedExperimentData,
    *,
    observed_copy: Literal["view", "deep", "ref"] = "view",
    calibration: Calibration | None = None,
    seed: int = 0,
    **kwargs,
) -> BinaryBCFResult:
    """Fit the probit (Albert-Chib) binary BCF from an observed-data object.

    Public, observed-based entry point. Arrays are pulled out of
    ``observed``, validated, and handed to the private raw-array core
    ``_fit_binary_bcf``; the frozen observed snapshot is then attached to
    the result.

    Args:
        observed: Observed experiment data. The extracted treatment column
            must contain every level ``0..K-1`` (``K = len(observed.variants)``)
            and the extracted outcome must contain only 0/1 values once cast
            to float.
        observed_copy: How the observed data is stashed on the result; must be
            one of ``'view'`` (default), ``'deep'`` or ``'ref'``. The value is
            forwarded unchanged to ``freeze_observed``.
        calibration: Optional calibration artifact. When given, the returned
            result is ``result.apply_calibration(calibration)``; when ``None``
            (default) the raw posterior result is returned.
        seed: Passed to :class:`~pytyche.bcf.config.GPUBCFConfig` as
            ``random_seed``.
        **kwargs: Extra keyword arguments forwarded verbatim to
            :class:`~pytyche.bcf.config.GPUBCFConfig`.

    Returns:
        :class:`~pytyche.bcf.config.BinaryBCFResult` whose ``observed`` field
        holds the frozen snapshot; calibrated iff ``calibration`` was given.

    Raises:
        ValueError: If some treatment level in ``0..K-1`` has no rows, if the
            outcome holds values outside ``{0, 1}``, if ``observed_copy`` is
            not a recognised mode, or whatever ``apply_calibration`` raises
            for an incompatible ``calibration``.
        NotImplementedError: Raised by ``_fit_binary_bcf`` when there are
            three or more treatment levels.
        TypeError: If ``**kwargs`` names fields that
            :class:`~pytyche.bcf.config.GPUBCFConfig` does not accept.
    """
    _warn_if_no_cuda()

    # Pull the raw arrays out of the observed container.
    extract = extract_fit_arrays(observed)
    require_feature_columns(extract)
    K = len(observed.variants)

    # Guard 1: every declared variant must contribute at least one row.
    # This runs before the config is built, so data errors surface first.
    level_codes = np.unique(extract.Z.astype(int)).tolist()
    observed_levels = sorted(map(int, level_codes))
    if observed_levels != list(range(K)):
        raise ValueError(
            f"Every treatment level 0..{K - 1} must appear in the observed data. "
            f"Observed levels: {observed_levels}; expected range: 0..{K - 1}. "
            "A missing level means a variant contributed zero rows — "
            "re-index or remove empty variants."
        )

    # Guard 2: the outcome must be strictly 0/1 (NaN also fails this check,
    # because it is not a member of the allowed set).
    unique_vals = set(np.unique(extract.Y.astype(float)).tolist())
    if not unique_vals <= {0.0, 1.0}:
        raise ValueError(
            f"fit_binary_bcf requires a binary outcome: all extracted Y values "
            f"must be in {{0, 1}}. Got values outside this set "
            f"(unique values: {sorted(unique_vals)!r}). "
            "Use fit_continuous_bcf for continuous outcomes or fit_hurdle_bcf "
            "for zero-inflated (hurdle) outcomes."
        )

    # Guard 3: copy mode must be one of the supported literals.
    if observed_copy not in _VALID_COPY_MODES:
        raise ValueError(
            f"observed_copy must be one of {_VALID_COPY_MODES!r}; got {observed_copy!r}."
        )

    # Config first (may raise on unknown kwargs), then the snapshot, then the fit.
    config = GPUBCFConfig(random_seed=seed, **kwargs)
    frozen = freeze_observed(observed, observed_copy)
    fitted = _fit_binary_bcf(
        extract.X, extract.Z, extract.Y, extract.propensity, config
    )

    # Stash the frozen observed data on the result.
    fitted = dataclasses.replace(fitted, observed=frozen)

    if calibration is None:
        return fitted
    return fitted.apply_calibration(calibration)
# ---------------------------------------------------------------------------
# Phase 2: Binary (probit) BCF — private raw-array core
# ---------------------------------------------------------------------------
def _fit_binary_bcf(
    X: np.ndarray,
    Z: np.ndarray,
    y_binary: np.ndarray,
    propensity: np.ndarray,
    config: GPUBCFConfig,
) -> BinaryBCFResult:
    """Fit binary BCF (Albert-Chib probit) from raw arrays.

    Private raw-array core. The public surface is :func:`fit_binary_bcf`
    (observed-based wrapper).

    Parameters
    ----------
    X : (n, p) covariate matrix
    Z : (n,) treatment indicator (0/1)
    y_binary : (n,) binary outcome (bool or 0/1); cast to ``bool`` here
    propensity : (n,) propensity scores. Not read by this function; kept in
        the signature for interface compatibility.
    config : GPUBCFConfig
        Supplies ``num_cuts`` for covariate binning; the forest initializers
        and the sampling loop read their own settings from it.

    Returns
    -------
    BinaryBCFResult
        Carries ``mu_samples``, ``tau_samples`` and ``wall_clock_seconds``
        (elapsed time of this call).

    Raises
    ------
    NotImplementedError
        If ``max(Z) + 1 >= 3``, i.e. more than two treatment levels.
    """
    # perf_counter() is monotonic, so the reported duration cannot be skewed
    # (or go negative) if the system clock is adjusted while the fit runs.
    start = time.perf_counter()

    # Number of treatment levels is inferred from the largest level code.
    K = int(np.asarray(Z).max()) + 1
    if K >= 3:
        raise NotImplementedError(
            "_fit_binary_bcf supports only binary treatment (Z in {0, 1}); "
            f"got K={K} treatment levels. Multi-arm (K>=3) is available via "
            "fit_hurdle_bcf(pooling='joint')."
        )

    basis = _compute_basis(Z)
    # The third return value (the split grid) is not needed in this function.
    X_binned, max_split, _splits = _preprocess_covariates(X, config.num_cuts)
    y_bool = np.asarray(y_binary, dtype=bool)

    # Initialize latent z for probit: +0.5 for y=True, -0.5 for y=False.
    z_init = np.where(y_bool, 0.5, -0.5).astype(np.float32)

    # Both forests are initialized as continuous-outcome forests on float32 z;
    # the probit z-sampling is handled by the custom loop in this module.
    # Each forest gets its own device copy of the binned inputs — per the
    # original note, init() donates its input buffers.
    X_mu, ms_mu = jnp.array(X_binned), jnp.array(max_split)
    X_tau, ms_tau = jnp.array(X_binned), jnp.array(max_split)
    mu_state = _init_mu_forest_binary(X_mu, ms_mu, z_init, config)
    tau_state = _init_tau_forest_binary(X_tau, ms_tau, z_init, basis, config)

    # Run the probit BCF sampling loop.
    mu_samples, tau_samples = _run_binary_bcf_loop(
        mu_state, tau_state, y_bool, basis, config,
    )

    elapsed = time.perf_counter() - start
    return BinaryBCFResult(
        mu_samples=mu_samples,
        tau_samples=tau_samples,
        wall_clock_seconds=elapsed,
    )
# ---------------------------------------------------------------------------
# Forest initialization helpers
# ---------------------------------------------------------------------------
def _init_mu_forest_binary(
    X_binned: jax.Array,
    max_split: jax.Array,
    z_init: np.ndarray,
    config: GPUBCFConfig,
):
    """Build the initial bartz state for the prognostic (mu) forest.

    The forest is set up as a *continuous*-outcome forest whose response is
    the float32 latent ``z`` rather than the boolean outcome; the probit
    z-sampling is done by the custom loop in this module, not by bartz.

    Parameters
    ----------
    X_binned : binned covariates passed to bartz as ``X``
    max_split : per-feature split limits passed to bartz as ``max_split``
    z_init : initial latent values, converted to a float32 JAX array
    config : GPUBCFConfig
        Reads ``num_trees_mu``, ``max_depth``, ``alpha_mu`` and ``beta_mu``.

    Returns
    -------
    Whatever ``bartz.mcmcstep.init`` returns (the forest state object).
    """
    n_trees = config.num_trees_mu
    depth_prior = make_p_nonterminal(
        config.max_depth, config.alpha_mu, config.beta_mu
    )
    latent_response = jnp.asarray(z_init, dtype=jnp.float32)

    return init(
        X=X_binned,
        y=latent_response,
        outcome_type=OutcomeType.continuous,
        offset=jnp.float32(0.0),
        max_split=max_split,
        num_trees=n_trees,
        p_nonterminal=depth_prior,
        # Leaf prior precision equals the number of mu trees.
        leaf_prior_cov_inv=jnp.float32(n_trees),
        # NOTE(review): the scan in this module overwrites error_cov_inv with
        # a fixed 1.0 on every sweep, so these error-variance prior settings
        # are presumably only needed to satisfy init() — confirm.
        error_cov_df=jnp.float32(3.0),
        error_cov_scale=jnp.float32(3.0),
    )
def _init_tau_forest_binary(
    X_binned: jax.Array,
    max_split: jax.Array,
    z_init: np.ndarray,
    basis: np.ndarray,
    config: GPUBCFConfig,
):
    """Build the initial bartz state for the treatment-effect (tau) forest.

    "Option A": no ``error_scale`` is passed to bartz, so the per-observation
    precision scale stays at its default. The design note carried with this
    function states that the continuous tau forest uses
    ``leaf_prior_cov_inv = 4 * num_trees_tau`` together with a precision
    scale of 0.25; dropping that scale gives the data 4x more weight relative
    to the prior, so the leaf prior precision here is raised by the same
    factor (``16 * num_trees_tau``) to keep the prior/data balance.

    Parameters
    ----------
    X_binned : binned covariates passed to bartz as ``X``
    max_split : per-feature split limits passed to bartz as ``max_split``
    z_init : initial latent values, converted to a float32 JAX array
    basis : treatment basis. Accepted for signature parity but not used by
        this initializer.
    config : GPUBCFConfig
        Reads ``num_trees_tau``, ``max_depth``, ``alpha_tau`` and ``beta_tau``.

    Returns
    -------
    Whatever ``bartz.mcmcstep.init`` returns (the forest state object).
    """
    n_trees = config.num_trees_tau
    # 4x (continuous-path value) times 4x (compensation for no prec_scale).
    leaf_prior_inv = float(n_trees) * 16.0
    depth_prior = make_p_nonterminal(
        config.max_depth, config.alpha_tau, config.beta_tau
    )
    latent_response = jnp.asarray(z_init, dtype=jnp.float32)

    return init(
        X=X_binned,
        y=latent_response,
        outcome_type=OutcomeType.continuous,
        offset=jnp.float32(0.0),
        max_split=max_split,
        num_trees=n_trees,
        p_nonterminal=depth_prior,
        leaf_prior_cov_inv=jnp.float32(leaf_prior_inv),
        error_cov_df=jnp.float32(3.0),
        error_cov_scale=jnp.float32(3.0),
    )
# ---------------------------------------------------------------------------
# MCMC loops — JIT'd via lax.scan for GPU performance
# ---------------------------------------------------------------------------
def _run_binary_bcf_loop(
    mu_state,
    tau_state,
    y_bool: np.ndarray,
    basis: np.ndarray,
    config: GPUBCFConfig,
) -> tuple[np.ndarray, np.ndarray]:
    """Prepare device inputs and run the JIT-compiled binary BCF sampler.

    Converts the outcome to a boolean JAX array and the basis to float32,
    derives the PRNG key from ``config.random_seed``, and calls
    ``_binary_bcf_scan`` with ``config.num_burnin`` / ``config.num_mcmc``.

    Returns
    -------
    (mu_samples, tau_samples) as NumPy arrays, each the transpose of what the
    scan returns. The scan stacks draws along the leading axis, so the result
    is presumably laid out as (observation, draw) — confirm against callers.
    """
    rng = random.key(config.random_seed)
    outcome_dev = jnp.asarray(y_bool, dtype=bool)
    basis_dev = jnp.asarray(basis, dtype=jnp.float32)

    mu_draws, tau_draws = _binary_bcf_scan(
        mu_state,
        tau_state,
        outcome_dev,
        basis_dev,
        rng,
        config.num_burnin,
        config.num_mcmc,
    )

    # Copy back to host memory and flip the draw axis to the end.
    mu_host = np.array(mu_draws).T
    tau_host = np.array(tau_draws).T
    return mu_host, tau_host
@partial(jax.jit, static_argnums=(5, 6), donate_argnums=(0, 1))
def _binary_bcf_scan(mu_state, tau_state, y_bool, basis, key, n_burnin, n_mcmc):
    """Run the two-forest probit BCF sampler as two ``lax.scan`` phases.

    Compiled with ``jax.jit``: ``n_burnin`` and ``n_mcmc`` are static
    arguments, and the ``mu_state`` / ``tau_state`` buffers are donated.

    Each sweep, in order:
      1. evaluates both forests,
      2. updates the mu forest on the residual ``z - tau*basis - mu``,
      3. updates the tau forest on ``(z - mu_new)/basis - tau``,
      4. redraws the latent ``z`` as the new linear predictor plus a draw
         from ``truncated_normal_onesided``.
    ``error_cov_inv`` is overwritten with a fixed 1.0 before every tree
    update (probit unit variance).

    Returns
    -------
    (mu_samples, tau_samples): the per-sweep forest sums from the MCMC phase
    only, stacked by ``lax.scan`` along the leading axis. Burn-in sweeps are
    executed but their outputs are discarded.
    """
    unit_precision = jnp.float32(1.0)  # fixed for probit
    # Starting latent values: +0.5 where y is True, -0.5 otherwise.
    latent0 = jnp.where(y_bool, jnp.float32(0.5), jnp.float32(-0.5))

    def forest_sum(state):
        # Sum of all trees' predictions for the given forest state.
        return evaluate_forest(state.X, state.forest, sum_batch_axis=0)

    def sweep(carry):
        mu_s, tau_s, latent, rng = carry
        rng, rng_mu, rng_tau, rng_z = random.split(rng, 4)

        mu_fit = forest_sum(mu_s)
        tau_fit = forest_sum(tau_s)

        # 1. Mu sweep on the partial residual with the tau contribution removed.
        mu_s = replace(
            mu_s,
            resid=latent - tau_fit * basis - mu_fit,
            error_cov_inv=unit_precision,
        )
        mu_s = step_trees(rng_mu, mu_s)

        # 2. Tau sweep: uses the refreshed mu fit but the pre-sweep tau fit,
        #    with the residual expressed on the basis-divided scale.
        mu_fit = forest_sum(mu_s)
        tau_s = replace(
            tau_s,
            resid=(latent - mu_fit) / basis - tau_fit,
            error_cov_inv=unit_precision,
        )
        tau_s = step_trees(rng_tau, tau_s)

        # 3. Redraw latent z around the updated linear predictor.
        #    NOTE(review): `~y_bool` is passed as the helper's one-sided flag
        #    and `-linear_pred` as its bound; presumably this keeps z > 0 for
        #    y=True and z < 0 for y=False — confirm against the helper.
        tau_fit = forest_sum(tau_s)
        linear_pred = mu_fit + tau_fit * basis
        noise = truncated_normal_onesided(rng_z, (), ~y_bool, -linear_pred)

        return (mu_s, tau_s, linear_pred + noise, rng), (mu_fit, tau_fit)

    def burnin_body(carry, _):
        # Advance the chain; nothing is recorded during burn-in.
        carry, _discarded = sweep(carry)
        return carry, None

    def mcmc_body(carry, _):
        # Advance the chain and emit this sweep's (mu, tau) forest sums.
        return sweep(carry)

    carry = (mu_state, tau_state, latent0, key)
    carry, _ = lax.scan(burnin_body, carry, None, length=n_burnin)
    _, (mu_samples, tau_samples) = lax.scan(
        mcmc_body, carry, None, length=n_mcmc,
    )
    return mu_samples, tau_samples