pytyche.bcf.continuous

Phase 1 — independent continuous BCF (bartz-based JIT scan).

Two-forest Bayesian Causal Forest (BCF) for continuous outcomes, run on GPU via bartz’s JAX BART sampler. It maintains a prognostic forest (mu) and a treatment-effect forest (tau), updated by interleaved backfitting sweeps. The tau forest uses a basis-weighted error_scale so that its leaf updates receive the sufficient statistics Σ(r·h) and Σ(h²) that adaptive coding requires, where r is the partial residual and h is each row’s basis value (b0=-0.5 for control, b1=+0.5 for treated).
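A minimal sketch of why a per-row error scale can deliver the Σ(r·h) and Σ(h²) statistics, assuming a Gaussian leaf prior and that the forest is fed the pseudo-response r/h (the module may arrange this differently). It computes the conjugate posterior of one tau leaf two ways and checks they agree. The function names are hypothetical, and reading bartz’s error_scale as a multiplier on the error standard deviation is an assumption to verify against bartz itself.

```python
import numpy as np


def leaf_posterior_basis(r, h, sigma2, leaf_var):
    """Conjugate posterior for one tau leaf under r_i = h_i * tau + e_i,
    e_i ~ N(0, sigma2), tau ~ N(0, leaf_var). Uses sum(r*h) and sum(h*h)."""
    prec = 1.0 / leaf_var + np.sum(h * h) / sigma2
    mean = (np.sum(r * h) / sigma2) / prec
    return mean, 1.0 / prec


def leaf_posterior_weighted(pseudo_r, error_scale, sigma2, leaf_var):
    """Same leaf as a heteroscedastic model: pseudo_r_i = tau + e_i,
    e_i ~ N(0, (error_scale_i * sigma)^2). error_scale multiplies the error SD."""
    w = 1.0 / (error_scale**2 * sigma2)  # per-row precision
    prec = 1.0 / leaf_var + np.sum(w)
    mean = np.sum(w * pseudo_r) / prec
    return mean, 1.0 / prec


rng = np.random.default_rng(0)
b0, b1 = -0.5, 0.5
z = rng.integers(0, 2, size=50)               # 0 = control, 1 = treated
h = np.where(z == 1, b1, b0)                  # adaptive-coding basis per row
r = 1.3 * h + rng.normal(scale=0.7, size=50)  # partial residual with mu(x) removed

m1, v1 = leaf_posterior_basis(r, h, sigma2=0.49, leaf_var=1.0)
# Pseudo-response r/h with error scale 1/|h| gives per-row precision h^2 / sigma2,
# so sum(w * r/h) = sum(r*h)/sigma2 and sum(w) = sum(h*h)/sigma2.
m2, v2 = leaf_posterior_weighted(r / h, 1.0 / np.abs(h), sigma2=0.49, leaf_var=1.0)
```

Because the weighted form reproduces Σ(r·h)/σ² and Σ(h²)/σ² exactly, an unmodified heteroscedastic leaf update can serve the tau forest. In this sketch that requires every h to be nonzero, which holds for b0=-0.5 and b1=+0.5.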

The MCMC loop is compiled as a single XLA program via jax.lax.scan, so no Python dispatch occurs between iterations.
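A toy illustration of the same compilation pattern, assuming nothing about the module’s internals: a two-block Gibbs sampler for a normal mean and variance whose whole loop is one jax.lax.scan inside jax.jit, so each iteration is a traced step rather than a Python call. The model, the InvGamma(1.5, 1.5) prior and run_gibbs are illustrative only; the real scan carries bartz forest state rather than two scalars.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="n_iter")
def run_gibbs(key, y, n_iter):
    n = y.shape[0]
    ybar = jnp.mean(y)

    def step(state, k):
        theta, sigma2 = state
        k1, k2 = jax.random.split(k)
        # theta | sigma2, y ~ N(ybar, sigma2 / n)   (flat prior on theta)
        theta = ybar + jnp.sqrt(sigma2 / n) * jax.random.normal(k1)
        # sigma2 | theta, y ~ InvGamma(1.5 + n/2, 1.5 + SSR/2)
        a = 1.5 + n / 2
        b = 1.5 + 0.5 * jnp.sum((y - theta) ** 2)
        sigma2 = b / jax.random.gamma(k2, a)  # b / Gamma(a, 1) ~ InvGamma(a, b)
        return (theta, sigma2), (theta, sigma2)

    keys = jax.random.split(key, n_iter)  # one PRNG key per iteration
    init = (jnp.zeros((), dtype=y.dtype), jnp.ones((), dtype=y.dtype))
    _, draws = jax.lax.scan(step, init, keys)  # whole chain in one XLA program
    return draws


y = 2.0 + 0.5 * jax.random.normal(jax.random.PRNGKey(1), (200,))
thetas, sigma2s = run_gibbs(jax.random.PRNGKey(0), y, n_iter=500)
```

Because n_iter fixes the number of scan steps, it is a static argument here: changing it triggers a recompile, while new data of the same shape and dtype reuses the compiled program.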

Public API

  • fit_continuous_bcf — public entry point: takes an ObservedExperimentData and returns a ContinuousBCFResult.

Internal helpers

  • _fit_continuous_bcf — raw-array private core (called by the public wrapper and by internal callers such as pytyche.bcf.hurdle.compose).

  • _sigma2_prior_params — inverse-gamma prior for σ² (df=3, scale=df).

  • _init_mu_forest — initialize prognostic forest state.

  • _init_tau_forest — initialize treatment-effect forest state with basis-weighted error scale.

  • _run_continuous_bcf_loop — host-side wrapper that prepares inputs and dispatches to the JIT’d scan.

  • _continuous_bcf_scan — JIT’d two-forest MCMC for continuous BCF.
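For the _sigma2_prior_params entry above, a minimal sketch of the standard BART-style parameterization, assuming the helper follows it: a scaled inverse-χ² prior σ² ~ νλ/χ²_ν is InvGamma(ν/2, νλ/2), and the Gibbs update given residuals adds n/2 to the shape and Σr²/2 to the scale. How the documented (df=3, scale=df) settings map onto λ is not stated here, so λ=1.0 below is a placeholder, and all function names are hypothetical.

```python
import numpy as np


def invgamma_from_scaled_inv_chi2(df, lam):
    """sigma2 ~ df * lam / chi2_df  <=>  sigma2 ~ InvGamma(df / 2, df * lam / 2)."""
    return df / 2.0, df * lam / 2.0


def sigma2_conditional(alpha, beta, resid):
    """Conjugate update given full-model residuals r_i ~ N(0, sigma2)."""
    return alpha + resid.size / 2.0, beta + 0.5 * np.sum(resid**2)


def draw_sigma2(rng, alpha, beta):
    # If G ~ Gamma(alpha, scale=1), then beta / G ~ InvGamma(alpha, beta).
    return beta / rng.gamma(alpha)


alpha0, beta0 = invgamma_from_scaled_inv_chi2(df=3.0, lam=1.0)  # lam is a placeholder
rng = np.random.default_rng(0)
resid = rng.normal(scale=0.8, size=1000)  # stand-in for y - mu(x) - h * tau(x)
alpha_n, beta_n = sigma2_conditional(alpha0, beta0, resid)
sigma2 = draw_sigma2(rng, alpha_n, beta_n)
```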

Import graph

pytyche.bcf.continuous imports configuration types from pytyche.bcf.config and the shared preprocessing helpers _compute_basis and _preprocess_covariates from pytyche.bcf.preprocess. Since those helpers were extracted into pytyche.bcf.preprocess at Stage C, the lazy imports previously needed to avoid a circular dependency are gone: both helpers are imported at module top level.

Functions

  • fit_continuous_bcf(observed, *[, ...]) – Fit continuous-outcome BCF on observed experiment data.

pytyche.bcf.continuous.fit_continuous_bcf(observed, *, observed_copy='view', calibration=None, seed=0, **kwargs)

Fit continuous-outcome BCF on observed experiment data.

Accepts an ObservedExperimentData and dispatches to the private raw-array core _fit_continuous_bcf.

Parameters:
  • observed (ObservedExperimentData) – Observed experiment data.

  • observed_copy (Literal['view', 'deep', 'ref']) – Copy mode for stashing the observed data on the result: 'view' (default), a zero-copy read-only view; 'deep', an independent deep copy; 'ref', the original object itself (no copy).

  • calibration (Calibration | None) – Optional SBC-fitted artifact applied after fitting. Passing it is shorthand for calling .apply_calibration(calibration) on the returned result: there is one application path, and the keyword adds no numerics of its own. None (default) returns the raw posterior.

  • seed (int) – Random seed forwarded to GPUBCFConfig as random_seed.

  • **kwargs – Additional keyword arguments forwarded to GPUBCFConfig.
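The three observed_copy modes correspond to familiar NumPy idioms. The sketch below assumes the observed data can be treated as a dict of arrays; stash_observed and that dict layout are hypothetical stand-ins, not the real ObservedExperimentData.

```python
import copy

import numpy as np


def stash_observed(arrays, mode="view"):
    """Hypothetical illustration of the three copy modes on a dict of arrays."""
    if mode == "ref":
        return arrays  # identity: the very same object, nothing copied
    if mode == "deep":
        return copy.deepcopy(arrays)  # independent buffers
    if mode == "view":
        out = {}
        for name, a in arrays.items():
            v = a.view()  # shares the caller's buffer, no data copied
            v.flags.writeable = False  # but cannot be mutated through the view
            out[name] = v
        return out
    raise ValueError(f"unknown observed_copy mode: {mode!r}")


data = {"y": np.arange(4.0), "z": np.array([0, 1, 0, 1])}
v = stash_observed(data, "view")
d = stash_observed(data, "deep")
r = stash_observed(data, "ref")
```

In this sketch a 'view' stash cannot be written through, but it still shares memory with the caller’s arrays, so later in-place edits to those arrays show up in the stash; only 'deep' isolates the result from them.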

Return type:

ContinuousBCFResult

Returns:

ContinuousBCFResult whose observed attribute holds the snapshot stashed according to observed_copy; the result is calibrated if and only if calibration was supplied.
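A mock-only sketch of the "one application path" wiring for the calibration keyword. ResultSketch, its calibrated flag, fit_sketch and the additive shift used as a "calibration" are all hypothetical placeholders; the real SBC calibration numerics are not described here.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class ResultSketch:
    """Hypothetical stand-in for ContinuousBCFResult."""
    tau_draws: tuple
    calibrated: bool = False

    def apply_calibration(self, calibration: dict) -> "ResultSketch":
        # The only place calibration numerics run in this sketch;
        # the additive shift is a placeholder, not SBC calibration.
        shifted = tuple(t + calibration["shift"] for t in self.tau_draws)
        return replace(self, tau_draws=shifted, calibrated=True)


def _fit_core_sketch(seed: int) -> ResultSketch:
    # Placeholder for the raw-array core; returns fixed "posterior draws".
    return ResultSketch(tau_draws=(1.0, 2.0, 3.0))


def fit_sketch(seed: int = 0, calibration: Optional[dict] = None) -> ResultSketch:
    result = _fit_core_sketch(seed)
    # The keyword only forwards to apply_calibration; it adds no numerics.
    return result if calibration is None else result.apply_calibration(calibration)


raw = fit_sketch()
cal = fit_sketch(calibration={"shift": 0.5})
```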

Raises:
  • ValueError – If any variant contributes zero rows (a missing treatment level), or, propagated from apply_calibration, if the regime a supplied calibration was fitted on does not match the observed data.

  • NotImplementedError – If the number of variants K is 3 or more (use fit_hurdle_bcf for multi-arm experiments).

  • TypeError – If **kwargs contains unknown GPUBCFConfig field names.
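The documented failure modes can be sketched as a pre-flight check. Everything here is hypothetical: ConfigSketch stands in for GPUBCFConfig with made-up field names, variant codes are assumed to be integers 0..K-1 with K supplied by the caller, and the order of the checks is a guess.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass(frozen=True)
class ConfigSketch:
    """Hypothetical stand-in for GPUBCFConfig; field names are made up."""
    random_seed: int = 0
    num_burnin: int = 100
    num_samples: int = 100


def validate_and_configure(variant, n_variants, seed=0, **kwargs):
    """Pre-flight checks mirroring the documented Raises list (order assumed)."""
    if n_variants >= 3:
        raise NotImplementedError("K >= 3 variants: use fit_hurdle_bcf for multi-arm")
    # Rows per variant code 0..K-1; a zero count means a missing treatment level.
    counts = np.bincount(variant, minlength=n_variants)
    empty = np.flatnonzero(counts == 0)
    if empty.size:
        raise ValueError(f"variants with zero rows: {empty.tolist()}")
    # Reject unknown field names explicitly, as a TypeError.
    known = {f.name for f in fields(ConfigSketch)}
    unknown = sorted(set(kwargs) - known)
    if unknown:
        raise TypeError(f"unknown config fields: {unknown}")
    # seed is forwarded as random_seed; also passing random_seed in kwargs
    # would make this call raise TypeError (duplicate keyword argument).
    return ConfigSketch(random_seed=seed, **kwargs)


cfg = validate_and_configure(np.array([0, 1, 1, 0]), n_variants=2, seed=7, num_samples=500)
```

Rejecting unknown keys explicitly gives a clearer message than relying on the dataclass constructor, which would also raise TypeError for an unexpected keyword.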