---
title: Multi-arm hurdle BCF
review-state: drafting
last-human-review: "2026-06-11"
depends-on:
  - src/pytyche/bcf/hurdle/__init__.py
  - src/pytyche/bcf/hurdle/model.py
  - src/pytyche/bcf/hurdle/compose.py
  - src/pytyche/bcf/preprocess.py
owner: tradcliffe
quadrant: concept
---

# Multi-arm hurdle BCF

Pytyche's joint hurdle estimator (`fit_hurdle_bcf(pooling="joint")`) began as a binary-arm estimator (one treatment vs. one control). This page explains how the estimator generalizes to K = 3..10 arms, why the generalization takes the form it does, and what the result shapes look like.

> **Naming note.** The joint estimator is `fit_hurdle_bcf(pooling="joint")`
> returning `HurdleBCFResult`. Earlier pre-release names
> (`fit_joint_hurdle_bcf`, `JointGPUBCFResult`) are retired; the old
> fit name raises an `ImportError` pointing here.

## How a multi-arm fit is structured

A K-arm fit uses a **single shared prognostic forest** (the μ forest, common to every arm) and **one treatment forest** whose leaves carry a `(K−1)` vector of contrast-basis coefficients. Every posterior draw produces a *joint* sample over all K−1 treatment-vs-control contrasts for each visitor: the per-arm outcomes are evaluated from a single multivariate leaf draw and differenced against control, so the contrasts stay correlated rather than being estimated in isolation.

That joint structure is the defining property of the estimator, and it is what makes the multi-arm posterior usable for adaptive targeting. Three things follow from it.

### Cross-arm prognostic sharing

The conversion and severity baseline (the μ forest) is fit once, across all arms. A visitor who received one treatment still informs the baseline used for every other arm, because the prognostic structure is shared; the treatment forest is left to capture only the *differences* from that baseline. This is a real efficiency gain in the regime adaptive experiments live in, where per-arm sample sizes are small.
(Fitting each treatment-vs-control contrast as a separate model would relearn the baseline K−1 times from disjoint slices of the data.)

### A selection-bias-free "best arm"

Choosing the best arm means taking an argmax over all K arms, which amounts to comparing the K−1 contrasts with each other and with control's implicit zero. Computed within each correlated joint draw, that argmax already accounts for the noise in the comparison, so the quantity the recommendation engine reads, `P(arm_k best | segment)`, is calibrated. Estimating each contrast in isolation and then taking the max instead would reintroduce the classic winner's-curse over-estimation of the selected arm, an effect that grows with K.

### Correct joint uncertainty

Every contrast is measured against the same control arm, so the K−1 contrasts share structure. The joint posterior carries that correlation, which is what keeps credible-interval widths and Thompson allocation probabilities honest; treating the contrasts as independent would understate the shared uncertainty.

## K = 2 dispatch: binary-arm path stays intact

The K = 2 binary-arm path is the validated reference implementation with pinned fingerprint tests. Rather than unifying the code paths, the entry point dispatches on K:

```python
K = int(Z.max()) + 1
if K == 2:
    ...  # unchanged binary-arm path (fingerprint-tested)
else:
    ...  # multivariate (K−1)-vector path
```

This makes binary-arm fingerprint preservation a structural guarantee: K = 2 inputs never reach the multivariate code. Downstream code that passes binary treatment vectors sees no change. The unification of both paths under the multivariate body is deferred to v0.3+, when the multi-arm path has been the production path long enough to become the reference implementation itself.

## Contrast coding

For K ≥ 3, `_compute_basis(Z)` returns an `(n, K−1)` matrix with **treatment-vs-control reference coding** using the same ±0.5 magnitude as the binary case:

- Column j is `+0.5` if `Z = j+1` (treatment arm j+1).
- Column j is `−0.5` if `Z = 0` (control).
- Column j is `0` otherwise.
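The coding rules can be made concrete with a short NumPy sketch. It is an illustration under stated assumptions, not pytyche's `_compute_basis`: `contrast_basis` and `arm_predictions` are hypothetical names, K is inferred with the same `int(Z.max()) + 1` rule as the dispatch, and the per-arm predictor is the linear term `μ + basis · τ` for a single shared leaf vector.

```python
import numpy as np


def contrast_basis(Z):
    """Hypothetical re-implementation of treatment-vs-control reference coding.

    Returns an (n, K-1) matrix: column j is +0.5 for arm j+1, -0.5 for
    control (arm 0), and 0 for every other arm.
    """
    Z = np.asarray(Z, dtype=np.int64)
    K = int(Z.max()) + 1                          # same K rule as the dispatch
    basis = np.zeros((Z.shape[0], K - 1))
    basis[Z == 0, :] = -0.5                       # control loads on every column
    rows = np.nonzero(Z > 0)[0]
    basis[rows, Z[rows] - 1] = 0.5                # arm j+1 loads on column j only
    return basis


def arm_predictions(mu, tau):
    """Hypothetical helper: linear predictor for each of the K arms from one leaf vector."""
    K = tau.shape[0] + 1
    per_arm_basis = contrast_basis(np.arange(K))  # row k is arm k's basis row
    return mu + per_arm_basis @ tau               # control gets mu - 0.5 * sum(tau)


# Example: three arms sharing one (K-1) = 2 leaf vector.
tau = np.array([0.2, 0.6])
pred = arm_predictions(1.0, tau)                  # control, arm 1, arm 2
contrasts = pred[1:] - pred[0]                    # equals 0.5 * tau + 0.5 * tau.sum()
```

With one column, the same function reduces to the binary `−0.5 / +0.5` coding, which is why the leaf component and the contrast coincide only at K = 2.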
Control is the baseline, and it loads on **every** contrast column: a control visitor's predicted response is `μ − 0.5·Σ_j τ_j`, while a treatment-j+1 visitor's is `μ + 0.5·τ_j`. Two consequences are worth stating precisely:

- The **leaf component** `τ_j` is *not* itself the treatment-j+1-vs-control effect (except at K = 2, where the coding collapses to the binary case). The model-implied contrast is `0.5·τ_j + 0.5·Σ_k τ_k`; the leaf coordinates are coupled.
- The **composed outputs are genuine contrasts**: the composition kernel evaluates each arm's potential outcome from the same draw (control at `μ − 0.5·Στ`, arm k at `μ + 0.5·τ_{k−1}`) and differences them, so `rpv_cate_samples[:, :, j]` is exactly the treatment-j+1-vs-control RPV CATE, which is what the downstream recommendation engine (Thompson allocation, graduation-candidate detection) operates on.

The prognostic forest (μ) keeps a scalar `ones(n)` basis at every K; at K ≥ 3 the tau forest contributes the `−0.5·Στ` term to control predictions.

## The multivariate Normal-Normal conjugate at Stage 1b

The tau forest has two leaf draws per tree, one per channel:

- **Alpha channel (conversion / probit):** error scale `σ²_conv = 1.0`, per-tree leaf prior variance `τ_conv = 1 / num_trees_tau`.
- **Beta channel (severity / log-revenue NN):** error scale `σ²_sev = 1 / τ_0`, per-tree leaf prior variance `τ_sev = 1 / num_trees_tau`.

At K ≥ 3, each leaf draw generalizes from scalar to a **(K−1)-vector multivariate Normal-Normal conjugate**. Per leaf, per channel:

- `sum_hr` becomes a `(K−1,)` vector: `Σ basis_visitor^T · residual_visitor` over the relevant visitors assigned to the leaf (all visitors for conversion; converters only for severity).
- `sum_h2` becomes a `(K−1, K−1)` matrix: `Σ basis_visitor^T · basis_visitor` over the same visitors.
- **Posterior mean:** `P^{-1} @ (sum_hr / σ²)` where `P = sum_h2 / σ² + (1 / τ_chan) · I`.
- **Posterior sample:** Cholesky solve of P, then a correlated multivariate-normal draw via `jnp.linalg`.
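The per-leaf update can be sketched in NumPy (`jax.numpy.linalg` exposes the same `cholesky` and `solve` calls). This is a minimal illustration, not pytyche's kernel: `draw_leaf_vector` is a hypothetical name, it handles one leaf and one channel outside any `lax.scan`, and it assumes `τ_chan` is the leaf prior variance so that the prior precision is `(1/τ_chan) · I`, matching the formula for `P`.

```python
import numpy as np


def draw_leaf_vector(basis_leaf, resid_leaf, sigma2, tau_chan, rng):
    """Hypothetical helper: one conjugate (K-1)-vector leaf draw for one channel.

    basis_leaf: (m, K-1) contrast-basis rows of the leaf's relevant visitors
    resid_leaf: (m,) partial residuals for the same visitors
    sigma2:     channel error variance (1.0 for conversion, 1 / tau_0 for severity)
    tau_chan:   leaf prior variance, so the prior precision is (1 / tau_chan) * I
    rng:        a numpy.random.Generator
    """
    d = basis_leaf.shape[1]
    sum_hr = basis_leaf.T @ resid_leaf                  # (K-1,)
    sum_h2 = basis_leaf.T @ basis_leaf                  # (K-1, K-1), a full matrix
    P = sum_h2 / sigma2 + np.eye(d) / tau_chan          # posterior precision
    L = np.linalg.cholesky(P)                           # P = L @ L.T
    # Posterior mean P^{-1} (sum_hr / sigma2) via two solves against the factor.
    mean = np.linalg.solve(L.T, np.linalg.solve(L, sum_hr / sigma2))
    # If z ~ N(0, I), then L^{-T} z has covariance (L L^T)^{-1} = P^{-1}.
    z = rng.standard_normal(d)
    draw = mean + np.linalg.solve(L.T, z)
    return draw, mean, P
```

For the severity channel a caller would pass only the converters' rows and `sigma2 = 1 / τ_0`; the division of `sum_h2` by `sigma2` inside the function is the scaling that must not be dropped. `np.linalg.solve` is a general solver; a dedicated triangular solve would be cheaper but gives the same result.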
The `σ²_sev` scaling on the severity `sum_h2` matrix must be carried through: dropping it (treating both channels as `σ² = 1`) is a silent miscalibration bug.

The prior precision is fixed and diagonal, `(1/τ_chan) · I` per channel, on the leaf **components**. Because the contrasts mix the components (control loads on every column), the implied prior on the treatment-vs-control contrasts grows with K and is positively correlated across contrasts: per-contrast prior variance `0.25·(K+2)·τ_chan` (equal to the binary-arm prior at K = 2; 3× it at K = 10) with prior correlation `(K+1)/(K+2)` (0.8 at K = 3, about 0.92 at K = 10). This is a shared-effect shrinkage structure, in which contrasts pool toward each other a priori, and it intensifies at large K. The prior is not sampled; there is no Inverse-Wishart hyperprior on the leaf covariance in v0.2.

The **mu (prognostic) forest stays scalar at every K.** Only the tau forest's `(n, K−1)` basis triggers the multivariate leaf body. The dispatch is on `basis.ndim`: scalar for the mu forest, matrix for the tau forest.

### Why the binary fast-path shortcut does not survive

The binary-arm fast path is efficient because every visitor's basis is `±0.5`, so `basis² = 0.25` is a constant and each leaf's precision reduces to `count × 0.25`, a value that depends only on how many visitors land in the leaf, never on which arm they came from. This lets the structure-only factors be precomputed in parallel from counts alone.

With treatment-vs-control coding, a control visitor contributes `0.25` to every entry of the `(K−1, K−1)` matrix; a treatment-j visitor contributes `0.25` to a single diagonal entry. Per-leaf `sum_h2` is a genuine matrix that depends on the arm composition of the leaf, not a count. The count-based shortcut does not apply.

v0.2 ships the correctness-first path: the tau-forest multivariate leaf draw runs inside the `lax.scan`, computing the per-leaf Cholesky solve + correlated draw without hoisting structure-only factors to the parallel stage.
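The arm-composition dependence of `sum_h2` can be checked with a small NumPy sketch. `leaf_sum_h2` is a hypothetical helper written for this page, not a pytyche function; it rebuilds the treatment-vs-control basis for one leaf's visitors and sums the outer products.

```python
import numpy as np


def leaf_sum_h2(arms_in_leaf, K):
    """Hypothetical helper: sum_h2 for one leaf under treatment-vs-control coding.

    arms_in_leaf: integer arm codes (0 = control) of the leaf's visitors.
    Returns the (K-1, K-1) sum of basis outer products.
    """
    arms = np.asarray(arms_in_leaf, dtype=np.int64)
    basis = np.zeros((arms.shape[0], K - 1))
    basis[arms == 0, :] = -0.5              # control: adds 0.25 to every entry
    rows = np.nonzero(arms > 0)[0]
    basis[rows, arms[rows] - 1] = 0.5       # treatment: adds 0.25 to one diagonal entry
    return basis.T @ basis


# Two K = 3 leaves with the same visitor count but different arm mixes.
mixed = leaf_sum_h2([0, 0, 1, 2], K=3)         # equals [[0.75, 0.5], [0.5, 0.75]]
treated_only = leaf_sum_h2([1, 1, 2, 2], K=3)  # equals [[0.5, 0.0], [0.0, 0.5]]

# At K = 2 any four-visitor leaf collapses to count * 0.25.
binary = leaf_sum_h2([0, 1, 1, 0], K=2)        # equals [[1.0]]
```

Both K = 3 leaves hold four visitors, yet their matrices differ, so no function of the count alone can stand in for `sum_h2`; at K = 2 the arm mix drops out, which is the property the binary fast path relies on.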
Hoisting is a **profile-driven follow-up** gated on the K = 10 performance characterization (marked `# perf-follow-up` in the code).

## Severity precision stays a single scalar

The global severity precision `τ_0` is the precision of the **scalar** severity residual per visitor. Each visitor is assigned to exactly one arm and observes exactly one log-revenue value, so its severity residual `y_log − μ_sev − (τ_sev · basis)` is a scalar: the treatment term is a per-visitor dot product of two `(K−1)`-vectors. There is no `(K−1)`-vector of severity residuals, because potential outcomes are not jointly observed.

So `τ_0` stays a scalar Gamma draw at every K: `tau0_samples` is `(S_total,)` and `sigma2_samples` is its reciprocal property, with no matrix precision, no Wishart sampler, and no positive-definite-matrix invariant. The genuine `(K−1, K−1)` coupling in the model lives one level up, in the Stage-1b *leaf-vector* posterior covariance, which is full because control visitors load on every contrast column; it does not live in the severity precision.

## The compose-kernel seam and selection-bias-free argmax

After the MCMC chains run, `_compose_rpv_cate_jax` in `compose.py` assembles the final result arrays. The key output for the recommendation engine is:

```
rpv_cate_samples    (n, S_total, K−1)
```

This is the **jointly-sampled** contrast posterior for every visitor across all posterior draws (chains concatenated, `S_total = (num_mcmc / thin_factor) · num_chains`). The K−1 contrast values per draw per visitor are correlated, because they came from a single multivariate leaf draw; they are not independent. The recommendation engine reads `P(arm_k is best | segment)` from this posterior.
Because the K−1 values are treatment-vs-control *contrasts*, the best arm over all K arms (control included) must be resolved per draw with control handled explicitly: **control wins a draw when every contrast is negative; otherwise the treatment with the largest positive contrast wins.** A bare `argmax` over the K−1 contrasts would never select control and would answer "best *treatment*," not "best *arm*." Counting the correct rule across draws gives the per-visitor probability each arm is best:

```python
import jax.numpy as jnp

# rpv_cate_samples: (n, S_total, K-1), contrasts vs control
K = rpv_cate_samples.shape[-1] + 1
best_contrast = jnp.argmax(rpv_cate_samples, axis=-1)          # (n, S_total)
any_treatment_wins = jnp.any(rpv_cate_samples > 0, axis=-1)    # (n, S_total)

# P(control best): no contrast is positive in the draw
p_control_best = jnp.mean(~any_treatment_wins, axis=1)         # (n,)

# P(treatment j+1 best): it is the argmax contrast AND beats control
p_treatment_best = jnp.stack([
    jnp.mean(any_treatment_wins & (best_contrast == j), axis=1)
    for j in range(K - 1)
], axis=-1)                                                     # (n, K-1)
```

The argmax is computed *within each joint draw* (the K−1 contrasts came from a single multivariate leaf draw and are correlated), not across independent posteriors, so the winner's-curse bias described under *A selection-bias-free "best arm"* is eliminated. That joint structure is what makes `P(arm_k best)` calibrated.

## Result shapes

Sample arrays follow the visitor-major, chains-concatenated convention at every K.
At K ≥ 3, `HurdleBCFResult` carries:

```
rpv_cate_samples   (n, S_total, K−1)   float32, CPU; always populated (the jointly-sampled contrast vector)
p_samples          (n, S_total, K)     jax.Array; None unless retain_channel_samples=True
sev_samples        (n, S_total, K)     jax.Array; None unless retain_channel_samples=True
tau0_samples       (S_total,)          float32; scalar at every K
sigma2_samples     property: 1 / tau0_samples (scalar at every K)
```

At K = 2, the legacy field names are populated instead: `p0_samples`, `p1_samples`, `sev0_samples`, `sev1_samples` (all `(n, S_total)`), and `rpv_cate_samples (n, S_total)` (scalar contrast). The new per-arm fields are `None` at K = 2. The legacy paired fields are `None` at K ≥ 3. `rpv_cate_samples` is always populated, with a trailing contrast axis of length K−1 added at K ≥ 3. `tau0_samples` and `sigma2_samples` are **shape-invariant in K**: scalar `(S_total,)` precision and its reciprocal property at every K.

## K = 3 worked example

This is the **executable acceptance target** for the K-arm path: a complete three-arm fit (control + two treatment variants) whose assertions encode the K = 3 contract. It is registered `xfail(strict)` in the docs test suite (`tests/test_docs/test_executable_examples.py`) until the K ≥ 3 path lands. The shape assertions fail against today's binary-collapse behavior (a K = 3 `Z` currently routes to the binary path, so `rpv_cate_samples` comes back `(n, S)` rather than `(n, S, 2)`) and flip the suite green the moment the multi-arm path is implemented, at which point the `xfail` entry is removed. The example *is* the contract the implementation must satisfy.

```{testcode}
import numpy as np
import pandas as pd

from pytyche import fit_hurdle_bcf
from pytyche.contracts import ObservedExperimentData, VariantData

rng = np.random.default_rng(42)
n = 300

# Three arms: 0 = control, 1 = free_shipping, 2 = promo_10pct.
# K is the number of variants in the observed data; no K argument is passed.
Z = rng.integers(0, 3, size=n)
X = rng.standard_normal((n, 5)).astype(np.float32)

# Zero-inflated revenue: arm-specific conversion + log-normal severity.
logit = -0.5 + 0.4 * X[:, 0] + np.array([0.0, 0.5, 1.0])[Z]   # treatments lift conversion
p_convert = 1.0 / (1.0 + np.exp(-logit))
converted = rng.random(n) < p_convert
log_rev = np.array([3.4, 3.5, 3.6])[Z] + 0.3 * rng.standard_normal(n)  # treatments also lift severity
Y = np.where(converted, np.exp(log_rev), 0.0).astype(np.float32)

# Package the arrays as observed experiment data: one VariantData per arm,
# visitor frames carrying the schema columns plus the five features.
# (Assignment is uniform, so the adapter's constant 1/K multi-arm
# propensity placeholder matches the design.)
arm_names = ("control", "free_shipping", "promo_10pct")
variants = []
for code, name in enumerate(arm_names):
    mask = Z == code
    m = int(mask.sum())
    frame = pd.DataFrame({
        "visitor_id": [f"{name}_v{i}" for i in range(m)],
        "experiment_id": "multi-arm-demo",
        "variant": name,
        "converted": converted[mask],
        "revenue": Y[mask].astype(np.float64),
        "orders_count": converted[mask].astype(np.int64),
        "sessions_count": np.ones(m, dtype=np.int64),
    })
    for j in range(X.shape[1]):
        frame[f"f{j}"] = X[mask, j].astype(np.float64)
    variants.append(VariantData(
        name=name,
        visitors=frame,
        n_visitors=m,
        n_conversions=int(converted[mask].sum()),
        total_revenue=float(Y[mask].sum()),
    ))

observed = ObservedExperimentData(
    experiment_id="multi-arm-demo",
    metric="revenue_per_visitor",
    variants=variants,
)

result = fit_hurdle_bcf(
    observed,
    num_trees_mu=20,
    num_trees_tau=10,
    num_gfr_sweeps=2,
    num_burnin=20,
    num_mcmc=30,
    retain_channel_samples=True,
)
S = result.rpv_cate_samples.shape[1]   # S_total = (num_mcmc / thin) * num_chains

# --- The K = 3 contract (these assertions are the acceptance target) ---

# rpv_cate_samples carries a trailing (K-1) = 2 contrast axis at K = 3.
assert result.rpv_cate_samples.shape == (n, S, 2)

# Per-arm channel samples carry a trailing K = 3 arm axis.
assert result.p_samples.shape == (n, S, 3)
assert result.sev_samples.shape == (n, S, 3)

# Severity precision stays scalar at every K.
assert result.tau0_samples.shape == (S,)

p = np.asarray(result.p_samples)
sev = np.asarray(result.sev_samples)
rpv = np.asarray(result.rpv_cate_samples)

# Conversion probabilities live in [0, 1].
assert p.min() >= 0.0 and p.max() <= 1.0

# RPV decomposition identity, per contrast j, per draw (visitor-major):
#   rpv[:, :, j] == p[:, :, j+1] * sev[:, :, j+1] - p[:, :, 0] * sev[:, :, 0]
for j in range(2):
    recon = p[:, :, j + 1] * sev[:, :, j + 1] - p[:, :, 0] * sev[:, :, 0]
    assert np.allclose(rpv[:, :, j], recon, atol=1e-3)

# P(arm_k best | visitor), control included: control wins a draw iff every
# contrast is negative; otherwise the largest positive contrast wins.
best_contrast = np.argmax(rpv, axis=-1)                  # (n, S)
any_treatment_wins = np.any(rpv > 0, axis=-1)            # (n, S)
p_control_best = np.mean(~any_treatment_wins, axis=1)    # (n,)
p_treatment_best = np.stack(
    [np.mean(any_treatment_wins & (best_contrast == j), axis=1) for j in range(2)],
    axis=-1,
)                                                        # (n, 2)
assert p_control_best.shape == (n,)
assert p_treatment_best.shape == (n, 2)

# Each visitor's arm-best probabilities (control + 2 treatments) sum to 1.
assert np.allclose(p_control_best + p_treatment_best.sum(axis=1), 1.0)
```

### Why five segments are often enough

A core empirical finding supporting the segment-based approach: Zhang & Misra 2022 ("Coarse Personalization," arXiv 2204.05793, EC '24) found that in a food-delivery promotions RCT, **five discrete segments recover 99.5% of expected incremental profits** relative to full personalization. In that setting, the policy tree acted as a near-lossless compression of the CATE surface. This result was obtained with continuous outcomes and a single experiment; transfer to zero-inflated multi-arm settings has not been directly studied, but it informs the coarse-segment framing of pytyche's recommendation engine (the `sequential-experiment-api` change).
## GPU GFR warm-start

`fit_hurdle_bcf(pooling="joint")` runs five GPU GFR warm-start sweeps before MCMC by default (`num_gfr_sweeps=5`, `gfr_backend="gpu"`). The GPU GFR path is the production warm-start; skipping it (`num_gfr_sweeps=0`) causes a real degradation in posterior quality, not a graceful fallback.

The GPU GFR tau sweep receives the same (K−1)-vector generalization as the MCMC Stage-1b path: the same per-leaf multivariate Normal-Normal conjugate, the same Cholesky solve + correlated draw, the same fixed diagonal prior. It reuses the shared `_log_marginal_leaves` and `_hurdle_sufficient_stats*` helpers.

The StochTree CPU GFR backend (`gfr_backend="cpu"`, `cpu_gfr.py`) stays at K = 2 only. Passing K ≥ 3 with `gfr_backend="cpu"` raises `NotImplementedError`. The production `gfr_backend="gpu"` path carries multi-arm.

## Performance characteristics

Measured on a Quadro RTX 5000 (16 GB) in a rough characterization pass (300 burn-in + 300 MCMC × 2 chains, 50 mu + 50 tau trees, 5 GFR sweeps, wall-clock including one-time XLA compile):

| K | n | Wall-clock | vs. K = 2 |
| --- | --- | --- | --- |
| 2 (binary fast-path) | 10 000 | 62.4 s | reference |
| 3 (multivariate) | 10 000 | 78.7 s | **1.26×** |
| 10 (multivariate) | 5 000 | 81.4 s | 1.30× at half the `n` (≈ 2.6× per visitor) |

- **K = 2:** no performance change (dispatch routes to the unchanged binary-arm code).
- **K = 3:** **1.26×** the K = 2 wall-clock, *better* than the original ≈1.5× estimate. The `(K−1, K−1) = (2, 2)` matrix Cholesky inside the `lax.scan` is cheap at this contrast count.
- **K = 10:** ≈ 2.6× per-visitor cost (the `(9, 9)` per-leaf Cholesky; the per-visitor figure naively assumes cost linear in `n` and includes the fixed compile time), but an 81 s 10-arm fit is comfortably tractable. **The in-scan per-leaf factorization does not dominate at this scale**, so the performance follow-up (hoist the structure-only `sum_h2` matrices to the parallel stage; the code marks the site `# perf-follow-up`) is *not* triggered by this evidence. File it only if a larger-scale K = 10 run shows it dominating.
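The ratios in the table follow from simple arithmetic on the measured wall-clock times. The sketch below recomputes them; `scaling_ratios` is a hypothetical helper, and the per-visitor ratio is a naive normalization that assumes cost linear in `n` and ignores that each timing includes a one-time XLA compile.

```python
# Timings copied from the performance table: K label -> (n, wall-clock seconds).
TIMINGS = {
    "K=2": (10_000, 62.4),
    "K=3": (10_000, 78.7),
    "K=10": (5_000, 81.4),
}


def scaling_ratios(timings, reference="K=2"):
    """Wall-clock ratio and naive per-visitor ratio against the reference row.

    The per-visitor ratio assumes cost is linear in n and ignores the
    one-time compile included in every wall-clock figure.
    """
    n_ref, t_ref = timings[reference]
    ratios = {}
    for label, (n, t) in timings.items():
        wall = t / t_ref
        per_visitor = (t / n) / (t_ref / n_ref)
        ratios[label] = (wall, per_visitor)
    return ratios


for label, (wall, per_visitor) in scaling_ratios(TIMINGS).items():
    print(f"{label:>5}: {wall:.2f}x wall-clock, {per_visitor:.2f}x per visitor")
```

Rounded to two decimals this reproduces the 1.26× and 1.30× wall-clock figures and the ≈ 2.6× (2.61×) per-visitor figure for K = 10.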
The numbers above are a rough pass; a production-scale benchmark (`n = 50 000`, `num_chains = 4`, `num_mcmc = 2000`) with GPU-memory capture is deferred. The raw artifact lives at `bench/multi_arm_baseline.json`.

## Related concepts

- {doc}`overview`: what pytyche does and who it's for
- {doc}`result-objects`: the full result-object hierarchy (`HurdleBCFResult`, `AnalysisResult`, `Experiment`)
- {doc}`sequential-targeting`: the segment-based enrichment loop that consumes the multi-arm posterior
- {doc}`bcf-calibration-at-scale`: calibration methodology and the documented negative result for `compute_channel_correction`
- {doc}`glossary`: definitions of CATE, GATE, RPV, tau forest, and other terms used here