Source code for pytyche.bcf.continuous

"""Phase 1 — independent continuous BCF (bartz-based JIT scan).

Two-forest Bayesian Causal Forest for continuous outcomes, executed on GPU
via bartz's JAX BART sampler. Maintains a prognostic forest (``mu``) and a
treatment-effect forest (``tau``) updated by interleaved backfitting sweeps,
with a basis-weighted ``error_scale`` on the tau forest to give Σ(r·h),
Σ(h²) sufficient statistics for adaptive coding (b0=-0.5, b1=+0.5).

The MCMC loop is compiled as a single XLA program via ``lax.scan`` — no
Python dispatch between iterations.

Public API
----------
- ``fit_continuous_bcf`` — observed-based entry point returning ``ContinuousBCFResult``.

Internal helpers
----------------
- ``_fit_continuous_bcf`` — raw-array private core (called by the public
  wrapper and by internal callers such as
  ``pytyche.bcf.hurdle.compose``).
- ``_sigma2_prior_params`` — inverse-gamma prior for σ² (df=3, scale=df).
- ``_init_mu_forest`` — initialize prognostic forest state.
- ``_init_tau_forest`` — initialize treatment-effect forest state with
  basis-weighted error scale.
- ``_run_continuous_bcf_loop`` — host-side wrapper that prepares inputs
  and dispatches to the JIT'd scan.
- ``_continuous_bcf_scan`` — JIT'd two-forest MCMC for continuous BCF.

Import graph
------------
``pytyche.bcf.continuous`` imports configuration types from
``pytyche.bcf.config`` and the shared preprocessing helpers
``_compute_basis`` / ``_preprocess_covariates`` from ``pytyche.bcf.preprocess``.
After the preprocess extraction at Stage C, the lazy-import dance is gone:
both helpers are imported at module top and there is no circular dependency.
"""

from __future__ import annotations

import dataclasses
import time
from dataclasses import replace
from functools import partial
from typing import TYPE_CHECKING, Literal

import jax
import jax.numpy as jnp
import numpy as np
from bartz.grove import evaluate_forest
from bartz.mcmcstep import (
    OutcomeType,
    init,
    make_p_nonterminal,
)
from jax import lax, random

from pytyche._internal.extraction import extract_fit_arrays, require_feature_columns
from pytyche._internal.observed import freeze_observed
from pytyche.bcf._bartz_compat import step_error_cov_inv, step_trees
from pytyche.bcf.config import ContinuousBCFResult, GPUBCFConfig
from pytyche.bcf.preprocess import _compute_basis, _preprocess_covariates
from pytyche.setup import _warn_if_no_cuda

if TYPE_CHECKING:
    from pytyche.calibrate.artifact import Calibration
    from pytyche.contracts import ObservedExperimentData

_VALID_COPY_MODES: tuple[str, ...] = ("view", "deep", "ref")

# ---------------------------------------------------------------------------
# Phase 1: Continuous BCF — public observed-based wrapper
# ---------------------------------------------------------------------------


def fit_continuous_bcf(
    observed: ObservedExperimentData,
    *,
    observed_copy: Literal["view", "deep", "ref"] = "view",
    calibration: Calibration | None = None,
    seed: int = 0,
    **kwargs,
) -> ContinuousBCFResult:
    """Run the continuous-outcome BCF sampler on an observed experiment.

    Thin public wrapper: pulls raw arrays out of ``observed``, validates
    them, builds a :class:`~pytyche.bcf.config.GPUBCFConfig`, and delegates
    the actual fit to the private raw-array core ``_fit_continuous_bcf``.

    Args:
        observed: Observed experiment data
            (:class:`~pytyche.contracts.ObservedExperimentData`).
        observed_copy: How the observed data is stashed on the result.
            ``'view'`` (default) — zero-copy read-only view; ``'deep'`` —
            independent deep copy; ``'ref'`` — identity (no copy).
        calibration: Optional SBC-fitted artifact. When given, the returned
            result is ``result.apply_calibration(calibration)``; ``None``
            (default) returns the raw posterior.
        seed: Random seed, passed to
            :class:`~pytyche.bcf.config.GPUBCFConfig` as ``random_seed``.
        **kwargs: Extra keyword arguments forwarded verbatim to
            :class:`~pytyche.bcf.config.GPUBCFConfig`.

    Returns:
        :class:`~pytyche.bcf.config.ContinuousBCFResult` whose ``observed``
        field holds the stashed snapshot; calibrated iff ``calibration``
        was supplied.

    Raises:
        ValueError: If the treatment levels present in the data are not
            exactly ``0..K-1`` (a variant contributed zero rows), if
            ``observed_copy`` is not a recognised mode, or — propagated from
            ``apply_calibration`` — if the calibration's fitted regime does
            not match the observed data.
        NotImplementedError: If K >= 3 (use ``fit_hurdle_bcf`` for
            multi-arm).
        TypeError: If ``**kwargs`` contains unknown
            :class:`~pytyche.bcf.config.GPUBCFConfig` field names.
    """
    _warn_if_no_cuda()

    # Pull the raw arrays out of the observed container.
    extract = extract_fit_arrays(observed)
    require_feature_columns(extract)
    K = len(observed.variants)

    # Every declared variant must contribute at least one row; this check
    # runs before the config is constructed.
    unique_levels = np.unique(extract.Z.astype(int)).tolist()
    levels_present = sorted(map(int, unique_levels))
    if levels_present != list(range(K)):
        raise ValueError(
            f"Every treatment level 0..{K - 1} must appear in the observed data. "
            f"Observed levels: {levels_present}; expected range: 0..{K - 1}. "
            "A missing level means a variant contributed zero rows — "
            "re-index or remove empty variants."
        )

    # Reject unknown copy modes.
    if observed_copy not in _VALID_COPY_MODES:
        raise ValueError(
            f"observed_copy must be one of {_VALID_COPY_MODES!r}; got {observed_copy!r}."
        )

    # Config first, then the snapshot — same order as the validation above.
    config = GPUBCFConfig(random_seed=seed, **kwargs)
    snapshot = freeze_observed(observed, observed_copy)

    # Hand the raw arrays to the private core and attach the snapshot.
    fitted = _fit_continuous_bcf(
        extract.X,
        extract.Z,
        extract.Y,
        extract.propensity,
        config,
    )
    fitted = replace(fitted, observed=snapshot)

    if calibration is None:
        return fitted
    return fitted.apply_calibration(calibration)
# ---------------------------------------------------------------------------
# Phase 1: Continuous BCF — private raw-array core
# ---------------------------------------------------------------------------


def _fit_continuous_bcf(
    X: np.ndarray,
    Z: np.ndarray,
    y: np.ndarray,
    propensity: np.ndarray,
    config: GPUBCFConfig,
) -> ContinuousBCFResult:
    """Fit continuous-outcome BCF on GPU via bartz.

    Private raw-array core. Public surface is :func:`fit_continuous_bcf`
    (observed-based wrapper).

    Two-forest MCMC: mu (prognostic) + tau (treatment effect) with
    basis-weighted sufficient statistics via prec_scale.

    Parameters
    ----------
    X : (n, p) covariate matrix
    Z : (n,) treatment indicator (0/1)
    y : (n,) continuous outcome
    propensity : (n,) propensity scores (unused by bartz, reserved for
        interface compat)
    config : GPUBCFConfig

    Returns
    -------
    ContinuousBCFResult
        Holds the mu / tau / sigma2 draws produced by the sampler (on the
        standardized outcome scale), the ``y_bar`` / ``y_std`` used for the
        standardization, and the elapsed wall-clock seconds.

    Raises
    ------
    NotImplementedError
        If ``Z`` contains a treatment level >= 2 (K >= 3 arms).
    ValueError
        If the standard deviation of ``y`` is zero or non-finite, so the
        outcome cannot be standardized.
    """
    # perf_counter is monotonic, so the elapsed time cannot go negative or
    # jump if the system clock is adjusted mid-fit.
    start = time.perf_counter()

    K = int(np.asarray(Z).max()) + 1
    if K >= 3:
        raise NotImplementedError(
            "_fit_continuous_bcf supports only binary treatment (Z in {0, 1}); "
            f"got K={K} treatment levels. Multi-arm (K>=3) is available via "
            "fit_hurdle_bcf(pooling='joint')."
        )

    # Standardize outcome. A constant (or non-finite) outcome would turn the
    # division below into NaN/inf and silently poison every posterior draw,
    # so fail loudly before any GPU work is started.
    y_bar = float(y.mean())
    y_std = float(y.std())
    if not np.isfinite(y_std) or y_std <= 0.0:
        raise ValueError(
            "Cannot standardize outcome y: its standard deviation is "
            f"{y_std!r}. The outcome must be finite and non-constant."
        )
    y_scaled = ((y - y_bar) / y_std).astype(np.float32)

    basis = _compute_basis(Z)

    # Preprocess covariates
    X_binned, max_split, _splits = _preprocess_covariates(X, config.num_cuts)

    # Initialize mu and tau forests (copy inputs — init() donates buffers)
    X_mu, ms_mu = jnp.array(X_binned), jnp.array(max_split)
    X_tau, ms_tau = jnp.array(X_binned), jnp.array(max_split)

    mu_state = _init_mu_forest(X_mu, ms_mu, y_scaled, config)
    tau_state = _init_tau_forest(X_tau, ms_tau, y_scaled, basis, config)

    # Run custom two-forest MCMC loop
    mu_samples, tau_samples, sigma2_samples = _run_continuous_bcf_loop(
        mu_state,
        tau_state,
        y_scaled,
        basis,
        config,
    )

    elapsed = time.perf_counter() - start
    return ContinuousBCFResult(
        mu_samples=mu_samples,
        tau_samples=tau_samples,
        sigma2_samples=sigma2_samples,
        y_bar=y_bar,
        y_std=y_std,
        wall_clock_seconds=elapsed,
    )


# ---------------------------------------------------------------------------
# Forest initialization helpers
# ---------------------------------------------------------------------------


def _sigma2_prior_params(
    y_scaled: np.ndarray,
) -> tuple[float, float]:
    """Return the ``(df, scale)`` prior parameters for the error variance.

    The values are constants: ``df = 3.0`` and ``scale = df``. ``y_scaled``
    is accepted so call sites pass the outcome uniformly, but it is not
    read — the caller is expected to have standardized the outcome already
    (var(y) ≈ 1.0).

    NOTE(review): the classic BART default calibrates ``scale`` so that
    roughly 90% of the prior mass lies below the marginal variance. That
    quantile calibration is NOT performed here; only the fixed
    ``scale = df`` shortcut is used.
    """
    df = 3.0
    # scale = df is the shortcut chosen for a standardized outcome.
    # NOTE(review): presumably bartz maps (df, scale) to IG(df/2, scale/2);
    # where exactly this centers the prior depends on that — confirm.
    scale = df
    return df, scale


def _init_mu_forest(
    X_binned: jax.Array,
    max_split: jax.Array,
    y_scaled: np.ndarray,
    config: GPUBCFConfig,
):
    """Initialize prognostic (mu) forest state for continuous outcome.

    Builds a bartz state via ``init`` with ``config.num_trees_mu`` trees, a
    depth prior from ``config.max_depth`` / ``alpha_mu`` / ``beta_mu``, leaf
    prior precision equal to the number of mu trees, and the error-variance
    prior from ``_sigma2_prior_params``.
    """
    df, scale = _sigma2_prior_params(y_scaled)

    return init(
        X=X_binned,
        y=jnp.asarray(y_scaled, dtype=jnp.float32),
        outcome_type=OutcomeType.continuous,
        offset=jnp.float32(0.0),
        max_split=max_split,
        num_trees=config.num_trees_mu,
        p_nonterminal=make_p_nonterminal(config.max_depth, config.alpha_mu, config.beta_mu),
        leaf_prior_cov_inv=jnp.float32(config.num_trees_mu),
        error_cov_df=jnp.float32(df),
        error_cov_scale=jnp.float32(scale),
    )


def _init_tau_forest(
    X_binned: jax.Array,
    max_split: jax.Array,
    y_scaled: np.ndarray,
    basis: np.ndarray,
    config: GPUBCFConfig,
):
    """Initialize treatment effect (tau) forest state for continuous outcome.

    Uses error_scale = 1/|basis| so prec_scale = basis², enabling
    basis-weighted sufficient statistics in the tree stepping.

    NOTE(review): assumes every entry of ``basis`` is non-zero (the module
    docstring describes a ±0.5 coding); a zero entry would make
    ``error_scale`` infinite.
    """
    df, scale = _sigma2_prior_params(y_scaled)

    error_scale = (1.0 / np.abs(basis)).astype(np.float32)

    # Tau leaf prior: tighter than mu (more regularization on treatment
    # effects). The leaf prior precision is 4x the tau tree count, versus
    # 1x the mu tree count for the mu forest; larger precision = tighter.
    leaf_prior_inv = float(config.num_trees_tau) * 4.0

    return init(
        X=X_binned,
        y=jnp.asarray(y_scaled, dtype=jnp.float32),  # placeholder for shape
        outcome_type=OutcomeType.continuous,
        offset=jnp.float32(0.0),
        max_split=max_split,
        num_trees=config.num_trees_tau,
        p_nonterminal=make_p_nonterminal(config.max_depth, config.alpha_tau, config.beta_tau),
        leaf_prior_cov_inv=jnp.float32(leaf_prior_inv),
        error_cov_df=jnp.float32(df),
        error_cov_scale=jnp.float32(scale),
        error_scale=jnp.asarray(error_scale),
    )


# ---------------------------------------------------------------------------
# MCMC loops — JIT'd via lax.scan for GPU performance
# ---------------------------------------------------------------------------


def _run_continuous_bcf_loop(
    mu_state,
    tau_state,
    y_scaled: np.ndarray,
    basis: np.ndarray,
    config: GPUBCFConfig,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Wrapper that dispatches to JIT'd continuous BCF loop.

    Converts the outcome and basis to float32 JAX arrays, derives the PRNG
    key from ``config.random_seed``, runs ``_continuous_bcf_scan`` with
    ``config.num_burnin`` / ``config.num_mcmc``, and returns NumPy copies of
    the draws with the mu / tau arrays transposed to ``(n, n_mcmc)``.

    ``mu_state`` and ``tau_state`` are donated to the scan and must not be
    reused by the caller afterwards.
    """
    y_jax = jnp.asarray(y_scaled, dtype=jnp.float32)
    basis_jax = jnp.asarray(basis, dtype=jnp.float32)
    key = random.key(config.random_seed)

    mu_samples, tau_samples, sigma2_samples = _continuous_bcf_scan(
        mu_state,
        tau_state,
        y_jax,
        basis_jax,
        key,
        config.num_burnin,
        config.num_mcmc,
    )

    # Convert to numpy (n_mcmc, n) → (n, n_mcmc)
    return np.array(mu_samples).T, np.array(tau_samples).T, np.array(sigma2_samples)


@partial(jax.jit, static_argnums=(5, 6), donate_argnums=(0, 1))
def _continuous_bcf_scan(mu_state, tau_state, y, basis, key, n_burnin, n_mcmc):
    """JIT'd two-forest MCMC for continuous BCF via lax.scan.

    Compiles the entire loop as a single XLA program — no Python dispatch
    overhead between iterations. This is the same pattern bartz's run_mcmc
    uses internally (lax.while_loop over mcmcstep.step).

    ``n_burnin`` and ``n_mcmc`` are static arguments (they set the scan
    lengths); ``mu_state`` and ``tau_state`` are donated. Returns
    ``(mu_samples, tau_samples, sigma2_samples)`` stacked along the leading
    MCMC axis; burn-in iterations are not recorded.
    """
    sigma2_inv = jnp.float32(1.0)

    def step(mu_s, tau_s, sigma2_inv, key, y, basis):
        key, k_mu, k_tau, k_sig = random.split(key, 4)

        # step_trees does backfitting: for each tree, adds back old prediction
        # then subtracts new. So resid must be partial_y - current_forest_sum.
        mu_pred = evaluate_forest(mu_s.X, mu_s.forest, sum_batch_axis=0)
        tau_pred = evaluate_forest(tau_s.X, tau_s.forest, sum_batch_axis=0)

        # 1. Mu sweep: partial_y = y - tau*basis; resid = partial_y - mu_sum
        mu_resid = y - tau_pred * basis - mu_pred
        mu_s = replace(mu_s, resid=mu_resid, error_cov_inv=sigma2_inv)
        mu_s = step_trees(k_mu, mu_s)

        # 2. Tau sweep: partial_y = (y - mu_new)/basis; resid = partial_y - tau_sum
        mu_pred_new = evaluate_forest(mu_s.X, mu_s.forest, sum_batch_axis=0)
        tau_resid = (y - mu_pred_new) / basis - tau_pred
        tau_s = replace(tau_s, resid=tau_resid, error_cov_inv=sigma2_inv)
        tau_s = step_trees(k_tau, tau_s)

        # 3. Sample sigma² from combined residuals
        tau_pred_new = evaluate_forest(tau_s.X, tau_s.forest, sum_batch_axis=0)
        combined_resid = y - mu_pred_new - tau_pred_new * basis
        mu_s = replace(mu_s, resid=combined_resid)
        mu_s = step_error_cov_inv(k_sig, mu_s)
        sigma2_inv = mu_s.error_cov_inv

        return mu_s, tau_s, sigma2_inv, key, mu_pred_new, tau_pred_new

    # Burnin phase (no output saved)
    def burnin_body(carry, _):
        mu_s, tau_s, s2inv, k = carry
        mu_s, tau_s, s2inv, k, _, _ = step(mu_s, tau_s, s2inv, k, y, basis)
        return (mu_s, tau_s, s2inv, k), None

    carry = (mu_state, tau_state, sigma2_inv, key)
    carry, _ = lax.scan(burnin_body, carry, None, length=n_burnin)

    # MCMC phase (save samples)
    def mcmc_body(carry, _):
        mu_s, tau_s, s2inv, k = carry
        mu_s, tau_s, s2inv, k, mu_pred, tau_pred = step(
            mu_s,
            tau_s,
            s2inv,
            k,
            y,
            basis,
        )
        # The carry tracks the precision; the recorded draw is the variance.
        sigma2 = jnp.float32(1.0) / s2inv
        return (mu_s, tau_s, s2inv, k), (mu_pred, tau_pred, sigma2)

    carry, (mu_samples, tau_samples, sigma2_samples) = lax.scan(
        mcmc_body,
        carry,
        None,
        length=n_mcmc,
    )

    # mu_samples: (n_mcmc, n), tau_samples: (n_mcmc, n), sigma2_samples: (n_mcmc,)
    return mu_samples, tau_samples, sigma2_samples