"""Phase 1 — independent continuous BCF (bartz-based JIT scan).
Two-forest Bayesian Causal Forest for continuous outcomes, executed on GPU
via bartz's JAX BART sampler. Maintains a prognostic forest (``mu``) and a
treatment-effect forest (``tau``) updated by interleaved backfitting sweeps,
with a basis-weighted ``error_scale`` on the tau forest to give Σ(r·h),
Σ(h²) sufficient statistics for adaptive coding (b0=-0.5, b1=+0.5).
The MCMC loop is compiled as a single XLA program via ``lax.scan`` — no
Python dispatch between iterations.
Public API
----------
- ``fit_continuous_bcf`` — observed-based entry point returning ``ContinuousBCFResult``.
Internal helpers
----------------
- ``_fit_continuous_bcf`` — raw-array private core (called by the public
wrapper and by internal callers such as
``pytyche.bcf.hurdle.compose``).
- ``_sigma2_prior_params`` — inverse-gamma prior for σ² (df=3, scale=df).
- ``_init_mu_forest`` — initialize prognostic forest state.
- ``_init_tau_forest`` — initialize treatment-effect forest state with
basis-weighted error scale.
- ``_run_continuous_bcf_loop`` — host-side wrapper that prepares inputs
and dispatches to the JIT'd scan.
- ``_continuous_bcf_scan`` — JIT'd two-forest MCMC for continuous BCF.
Import graph
------------
``pytyche.bcf.continuous`` imports configuration types from
``pytyche.bcf.config`` and the shared preprocessing helpers
``_compute_basis`` / ``_preprocess_covariates`` from ``pytyche.bcf.preprocess``.
After the preprocess extraction at Stage C, the lazy-import dance is gone:
both helpers are imported at module top and there is no circular dependency.
"""
from __future__ import annotations
import dataclasses
import time
from dataclasses import replace
from functools import partial
from typing import TYPE_CHECKING, Literal
import jax
import jax.numpy as jnp
import numpy as np
from bartz.grove import evaluate_forest
from bartz.mcmcstep import (
OutcomeType,
init,
make_p_nonterminal,
)
from jax import lax, random
from pytyche._internal.extraction import extract_fit_arrays, require_feature_columns
from pytyche._internal.observed import freeze_observed
from pytyche.bcf._bartz_compat import step_error_cov_inv, step_trees
from pytyche.bcf.config import ContinuousBCFResult, GPUBCFConfig
from pytyche.bcf.preprocess import _compute_basis, _preprocess_covariates
from pytyche.setup import _warn_if_no_cuda
if TYPE_CHECKING:
from pytyche.calibrate.artifact import Calibration
from pytyche.contracts import ObservedExperimentData
_VALID_COPY_MODES: tuple[str, ...] = ("view", "deep", "ref")
# ---------------------------------------------------------------------------
# Phase 1: Continuous BCF — public observed-based wrapper
# ---------------------------------------------------------------------------
[docs]
def fit_continuous_bcf(
    observed: ObservedExperimentData,
    *,
    observed_copy: Literal["view", "deep", "ref"] = "view",
    calibration: Calibration | None = None,
    seed: int = 0,
    **kwargs,
) -> ContinuousBCFResult:
    """Fit continuous-outcome BCF on observed experiment data.

    Accepts an :class:`~pytyche.contracts.ObservedExperimentData` and
    dispatches to the private raw-array core ``_fit_continuous_bcf``.

    Args:
        observed: Observed experiment data.
        observed_copy: Copy mode for stashing the observed data on the result.
            ``'view'`` (default) — zero-copy read-only view; ``'deep'`` —
            independent deep copy; ``'ref'`` — identity (no copy).
        calibration: Optional SBC-fitted artifact applied post-fit —
            sugar for ``.apply_calibration(calibration)`` on the returned
            result (one application path; the kwarg adds no numerics).
            ``None`` (default) returns the raw posterior.
        seed: Random seed forwarded to :class:`~pytyche.bcf.config.GPUBCFConfig`
            as ``random_seed``.
        **kwargs: Additional keyword arguments forwarded to
            :class:`~pytyche.bcf.config.GPUBCFConfig`.

    Returns:
        :class:`~pytyche.bcf.config.ContinuousBCFResult` with ``observed``
        set to the stashed snapshot; calibrated iff ``calibration`` was
        supplied.

    Raises:
        ValueError: If ``observed_copy`` is not one of ``'view'``, ``'deep'``,
            ``'ref'``; if any variant contributes zero rows (missing treatment
            level); or — propagated from ``apply_calibration`` — if a
            supplied ``calibration``'s fitted regime does not match the
            observed data.
        NotImplementedError: If K >= 3 (use ``fit_hurdle_bcf`` for multi-arm).
        TypeError: If ``**kwargs`` contains unknown
            :class:`~pytyche.bcf.config.GPUBCFConfig` field names.
    """
    _warn_if_no_cuda()

    # --- Validate cheap scalar arguments before touching the data ---
    if observed_copy not in _VALID_COPY_MODES:
        raise ValueError(
            f"observed_copy must be one of {_VALID_COPY_MODES!r}; got {observed_copy!r}."
        )

    # --- Extract arrays from observed ---
    extract = extract_fit_arrays(observed)
    require_feature_columns(extract)
    K = len(observed.variants)

    # --- Missing-level guard (before config construction) ---
    # np.unique already returns its values in ascending order, so the list is
    # directly comparable to range(K).
    observed_levels = [int(v) for v in np.unique(extract.Z.astype(int)).tolist()]
    expected_range = list(range(K))
    if observed_levels != expected_range:
        raise ValueError(
            f"Every treatment level 0..{K - 1} must appear in the observed data. "
            f"Observed levels: {observed_levels}; expected range: 0..{K - 1}. "
            "A missing level means a variant contributed zero rows — "
            "re-index or remove empty variants."
        )

    # --- Build config (raises TypeError on unknown field names) ---
    config = GPUBCFConfig(random_seed=seed, **kwargs)

    # --- Freeze observed snapshot ---
    # observed_copy is already typed as the Literal and validated above, so it
    # can be passed through without re-binding or type-ignore comments.
    frozen = freeze_observed(observed, observed_copy)

    # --- Dispatch to private raw-array core ---
    result = _fit_continuous_bcf(extract.X, extract.Z, extract.Y, extract.propensity, config)

    # --- Attach stash, then (optionally) calibrate ---
    result = dataclasses.replace(result, observed=frozen)
    if calibration is not None:
        result = result.apply_calibration(calibration)
    return result
# ---------------------------------------------------------------------------
# Phase 1: Continuous BCF — private raw-array core
# ---------------------------------------------------------------------------
def _fit_continuous_bcf(
    X: np.ndarray,
    Z: np.ndarray,
    y: np.ndarray,
    propensity: np.ndarray,
    config: GPUBCFConfig,
) -> ContinuousBCFResult:
    """Fit continuous-outcome BCF on GPU via bartz.

    Private raw-array core. Public surface is
    :func:`fit_continuous_bcf` (observed-based wrapper).

    Two-forest MCMC: mu (prognostic) + tau (treatment effect) with
    basis-weighted sufficient statistics via prec_scale.

    Parameters
    ----------
    X : (n, p) covariate matrix
    Z : (n,) treatment indicator (0/1)
    y : (n,) continuous outcome
    propensity : (n,) propensity scores (unused by bartz, reserved for interface compat)
    config : GPUBCFConfig

    Returns
    -------
    ContinuousBCFResult
        Draws fitted on the standardized outcome ``(y - y_bar) / y_std``:
        ``mu_samples`` and ``tau_samples`` of shape ``(n, num_mcmc)``,
        ``sigma2_samples`` of shape ``(num_mcmc,)``. ``y_bar`` / ``y_std`` are
        stored so callers can map back to the original scale, and
        ``wall_clock_seconds`` holds the elapsed fit time.

    Raises
    ------
    NotImplementedError
        If ``Z`` encodes three or more treatment levels.

    Notes
    -----
    A constant outcome (all ``y`` equal) is standardized with ``y_std = 1.0``.
    Dividing by its zero (or float-noise) standard deviation would fill the
    sampler input with NaN/inf or amplified rounding noise.
    """
    # perf_counter is monotonic, so the elapsed time is immune to system clock
    # adjustments during a long GPU fit.
    start = time.perf_counter()

    K = int(np.asarray(Z).max()) + 1
    if K >= 3:
        raise NotImplementedError(
            "_fit_continuous_bcf supports only binary treatment (Z in {0, 1}); "
            f"got K={K} treatment levels. Multi-arm (K>=3) is available via "
            "fit_hurdle_bcf(pooling='joint')."
        )
    basis = _compute_basis(Z)

    # Standardize outcome.
    y_arr = np.asarray(y)
    y_bar = float(y_arr.mean())
    if y_arr.min() == y_arr.max():
        # Constant outcome: its std is 0 (or ~1e-17 float noise from the mean),
        # so use unit scale instead of dividing by it.
        y_std = 1.0
    else:
        y_std = float(y_arr.std())
    y_scaled = ((y_arr - y_bar) / y_std).astype(np.float32)

    # Preprocess covariates
    X_binned, max_split, _splits = _preprocess_covariates(X, config.num_cuts)

    # Initialize mu and tau forests (copy inputs — init() donates buffers)
    X_mu, ms_mu = jnp.array(X_binned), jnp.array(max_split)
    X_tau, ms_tau = jnp.array(X_binned), jnp.array(max_split)
    mu_state = _init_mu_forest(X_mu, ms_mu, y_scaled, config)
    tau_state = _init_tau_forest(X_tau, ms_tau, y_scaled, basis, config)

    # Run custom two-forest MCMC loop
    mu_samples, tau_samples, sigma2_samples = _run_continuous_bcf_loop(
        mu_state, tau_state, y_scaled, basis, config,
    )

    elapsed = time.perf_counter() - start
    return ContinuousBCFResult(
        mu_samples=mu_samples,
        tau_samples=tau_samples,
        sigma2_samples=sigma2_samples,
        y_bar=y_bar,
        y_std=y_std,
        wall_clock_seconds=elapsed,
    )
# ---------------------------------------------------------------------------
# Forest initialization helpers
# ---------------------------------------------------------------------------
def _sigma2_prior_params(
y_scaled: np.ndarray,
) -> tuple[float, float]:
"""Compute inverse-gamma prior params for error variance.
Uses the BART default: place 90% prior mass below the marginal variance.
df=3, scale chosen so P(sigma² < var(y)) ≈ 0.9.
"""
df = 3.0
# With standardized y, var(y) ≈ 1.0
# scale = 2 * beta where IG(alpha, beta) with alpha = df/2
# For standardized y: scale ≈ df (places prior mean at 1.0)
scale = df
return df, scale
def _init_mu_forest(
    X_binned: jax.Array,
    max_split: jax.Array,
    y_scaled: np.ndarray,
    config: GPUBCFConfig,
):
    """Build the bartz MCMC state for the prognostic (mu) forest.

    The mu forest targets the standardized outcome directly. It uses:

    - a zero offset;
    - ``config.num_trees_mu`` trees;
    - a tree-depth prior from ``max_depth`` / ``alpha_mu`` / ``beta_mu``;
    - a leaf prior precision equal to the number of trees;
    - the shared inverse-gamma error-variance prior from
      ``_sigma2_prior_params``.
    """
    n_trees = config.num_trees_mu
    sigma_df, sigma_scale = _sigma2_prior_params(y_scaled)
    depth_prior = make_p_nonterminal(config.max_depth, config.alpha_mu, config.beta_mu)
    target = jnp.asarray(y_scaled, dtype=jnp.float32)

    return init(
        X=X_binned,
        y=target,
        outcome_type=OutcomeType.continuous,
        offset=jnp.float32(0.0),
        max_split=max_split,
        num_trees=n_trees,
        p_nonterminal=depth_prior,
        leaf_prior_cov_inv=jnp.float32(n_trees),
        error_cov_df=jnp.float32(sigma_df),
        error_cov_scale=jnp.float32(sigma_scale),
    )
def _init_tau_forest(
    X_binned: jax.Array,
    max_split: jax.Array,
    y_scaled: np.ndarray,
    basis: np.ndarray,
    config: GPUBCFConfig,
):
    """Initialize treatment effect (tau) forest state for continuous outcome.

    Uses ``error_scale = 1/|basis|`` so prec_scale = basis², enabling
    basis-weighted sufficient statistics in the tree stepping.

    Args:
        X_binned: Binned covariate matrix.
        max_split: Per-feature maximum split index from covariate preprocessing.
        y_scaled: Standardized outcome. Passed to ``init`` only as a
            shape placeholder and to ``_sigma2_prior_params``.
        basis: Per-observation treatment basis; every entry must be non-zero.
        config: Sampler configuration (``num_trees_tau``, ``max_depth``,
            ``alpha_tau``, ``beta_tau``).

    Returns:
        The bartz MCMC state returned by ``init``.

    Raises:
        ValueError: If any entry of ``basis`` is zero. That would make
            ``error_scale`` infinite and silently poison the chain.
    """
    df, scale = _sigma2_prior_params(y_scaled)

    abs_basis = np.abs(np.asarray(basis))
    if np.any(abs_basis == 0):
        raise ValueError(
            "basis must be non-zero for every observation: the tau forest uses "
            "error_scale = 1/|basis| and its residual divides by basis."
        )
    # error_scale = 1/|basis| makes bartz weight each observation by basis²,
    # which yields the Σ(r·h), Σ(h²) sufficient statistics for the tau leaves.
    error_scale = (1.0 / abs_basis).astype(np.float32)

    # Tau leaf prior: tighter than mu (more regularization on treatment effects).
    # leaf_prior_cov_inv is the per-leaf prior precision, so each leaf has prior
    # variance 1 / (4 * num_trees_tau). Assuming independent leaf priors across
    # trees (standard BART), the forest sum then has prior variance 1/4 (sd 0.5
    # on the standardized scale). The mu forest uses num_trees_mu as its leaf
    # precision, so its forest sum has prior variance 1.
    leaf_prior_inv = float(config.num_trees_tau) * 4.0

    return init(
        X=X_binned,
        y=jnp.asarray(y_scaled, dtype=jnp.float32),  # placeholder for shape
        outcome_type=OutcomeType.continuous,
        offset=jnp.float32(0.0),
        max_split=max_split,
        num_trees=config.num_trees_tau,
        p_nonterminal=make_p_nonterminal(config.max_depth, config.alpha_tau, config.beta_tau),
        leaf_prior_cov_inv=jnp.float32(leaf_prior_inv),
        error_cov_df=jnp.float32(df),
        error_cov_scale=jnp.float32(scale),
        error_scale=jnp.asarray(error_scale),
    )
# ---------------------------------------------------------------------------
# MCMC loops — JIT'd via lax.scan for GPU performance
# ---------------------------------------------------------------------------
def _run_continuous_bcf_loop(
    mu_state,
    tau_state,
    y_scaled: np.ndarray,
    basis: np.ndarray,
    config: GPUBCFConfig,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Host-side driver for the JIT-compiled two-forest sampler.

    Steps:

    1. Moves the standardized outcome and the basis to float32 JAX arrays.
    2. Seeds a PRNG key from ``config.random_seed``.
    3. Runs ``_continuous_bcf_scan`` for ``config.num_burnin`` discarded and
       ``config.num_mcmc`` kept sweeps.
    4. Copies the kept draws back to host NumPy arrays.

    ``mu_state`` and ``tau_state`` are donated to the scan (its
    ``donate_argnums=(0, 1)``), so callers must not reuse them afterwards.

    Returns:
        ``(mu, tau, sigma2)``: ``mu`` and ``tau`` transposed to
        ``(n, num_mcmc)``; ``sigma2`` with shape ``(num_mcmc,)``.
    """
    device_draws = _continuous_bcf_scan(
        mu_state,
        tau_state,
        jnp.asarray(y_scaled, dtype=jnp.float32),
        jnp.asarray(basis, dtype=jnp.float32),
        random.key(config.random_seed),
        config.num_burnin,
        config.num_mcmc,
    )
    mu_host, tau_host, sigma2_host = (np.array(draw) for draw in device_draws)
    # The scan stacks draws along axis 0: (num_mcmc, n) -> (n, num_mcmc).
    return mu_host.T, tau_host.T, sigma2_host
@partial(jax.jit, static_argnums=(5, 6), donate_argnums=(0, 1))
def _continuous_bcf_scan(mu_state, tau_state, y, basis, key, n_burnin, n_mcmc):
    """JIT'd two-forest MCMC for continuous BCF via lax.scan.

    Compiles the entire loop as a single XLA program — no Python dispatch
    overhead between iterations. This is the same pattern bartz's run_mcmc
    uses internally (lax.while_loop over mcmcstep.step).

    Each sweep (``step``) runs, in order:

    1. a backfitting pass over the mu forest on ``y - tau * basis``;
    2. a backfitting pass over the tau forest on ``(y - mu) / basis``. The
       tau state's ``error_scale`` supplies the per-observation weighting;
    3. a draw of the error precision from the full residual
       ``y - mu - tau * basis``, performed on the mu state.

    ``n_burnin`` sweeps are run and discarded. Then ``n_mcmc`` sweeps are run
    and their outputs stacked along axis 0.

    Args:
        mu_state: bartz MCMC state for the prognostic forest (donated).
        tau_state: bartz MCMC state for the treatment-effect forest (donated).
        y: Outcome vector, float32.
        basis: Per-observation treatment basis, float32.
        key: JAX PRNG key; split once per sweep.
        n_burnin: Number of discarded sweeps (static; a new value recompiles).
        n_mcmc: Number of kept sweeps (static; a new value recompiles).

    Returns:
        ``(mu_samples, tau_samples, sigma2_samples)`` with shapes
        ``(n_mcmc, n)``, ``(n_mcmc, n)``, ``(n_mcmc,)``. ``tau_samples`` is
        the raw tau forest output, not multiplied by ``basis``.
        ``sigma2_samples`` is ``1 / error_cov_inv`` after each kept sweep.
    """
    # Starting error precision for the first sweep; step 3 overwrites it
    # every sweep.
    sigma2_inv = jnp.float32(1.0)

    def step(mu_s, tau_s, sigma2_inv, key, y, basis):
        # One key is carried to the next sweep; three are consumed here
        # (mu trees, tau trees, error precision).
        key, k_mu, k_tau, k_sig = random.split(key, 4)
        # step_trees does backfitting: for each tree, adds back old prediction
        # then subtracts new. So resid must be partial_y - current_forest_sum.
        # NOTE(review): these two evaluations recompute exactly what the
        # previous sweep returned as mu_pred_new / tau_pred_new (nothing
        # changes the forests in between). Carrying them in the scan carry
        # would save two forest evaluations per sweep.
        mu_pred = evaluate_forest(mu_s.X, mu_s.forest, sum_batch_axis=0)
        tau_pred = evaluate_forest(tau_s.X, tau_s.forest, sum_batch_axis=0)
        # 1. Mu sweep: partial_y = y - tau*basis; resid = partial_y - mu_sum
        mu_resid = y - tau_pred * basis - mu_pred
        mu_s = replace(mu_s, resid=mu_resid, error_cov_inv=sigma2_inv)
        mu_s = step_trees(k_mu, mu_s)
        # 2. Tau sweep: partial_y = (y - mu_new)/basis; resid = partial_y - tau_sum
        # The tau forest is unchanged since the top of the sweep, so tau_pred
        # is still its current sum.
        mu_pred_new = evaluate_forest(mu_s.X, mu_s.forest, sum_batch_axis=0)
        tau_resid = (y - mu_pred_new) / basis - tau_pred
        tau_s = replace(tau_s, resid=tau_resid, error_cov_inv=sigma2_inv)
        tau_s = step_trees(k_tau, tau_s)
        # 3. Sample sigma² from combined residuals
        # The mu state (which has no basis weighting) is reused as the carrier:
        # its resid is overwritten with the full-model residual before the
        # precision draw.
        tau_pred_new = evaluate_forest(tau_s.X, tau_s.forest, sum_batch_axis=0)
        combined_resid = y - mu_pred_new - tau_pred_new * basis
        mu_s = replace(mu_s, resid=combined_resid)
        mu_s = step_error_cov_inv(k_sig, mu_s)
        sigma2_inv = mu_s.error_cov_inv
        return mu_s, tau_s, sigma2_inv, key, mu_pred_new, tau_pred_new

    # Burnin phase (no output saved)
    def burnin_body(carry, _):
        mu_s, tau_s, s2inv, k = carry
        mu_s, tau_s, s2inv, k, _, _ = step(mu_s, tau_s, s2inv, k, y, basis)
        return (mu_s, tau_s, s2inv, k), None

    carry = (mu_state, tau_state, sigma2_inv, key)
    carry, _ = lax.scan(burnin_body, carry, None, length=n_burnin)

    # MCMC phase (save samples)
    def mcmc_body(carry, _):
        mu_s, tau_s, s2inv, k = carry
        mu_s, tau_s, s2inv, k, mu_pred, tau_pred = step(
            mu_s, tau_s, s2inv, k, y, basis,
        )
        # Store the variance rather than the precision.
        sigma2 = jnp.float32(1.0) / s2inv
        return (mu_s, tau_s, s2inv, k), (mu_pred, tau_pred, sigma2)

    carry, (mu_samples, tau_samples, sigma2_samples) = lax.scan(
        mcmc_body, carry, None, length=n_mcmc,
    )
    # mu_samples: (n_mcmc, n), tau_samples: (n_mcmc, n), sigma2_samples: (n_mcmc,)
    return mu_samples, tau_samples, sigma2_samples