Multi-arm hurdle BCF¶
Pytyche’s joint hurdle estimator (fit_hurdle_bcf(pooling="joint"))
began as a binary-arm estimator (one treatment vs. one control). This page explains how the estimator
generalizes to K = 3..10 arms, why the generalization takes the form
it does, and what the resulting result shapes look like.
Naming note. The joint estimator is
fit_hurdle_bcf(pooling="joint") returning HurdleBCFResult. Earlier pre-release names (fit_joint_hurdle_bcf, JointGPUBCFResult) are retired; the old fit name raises an ImportError pointing here.
How a multi-arm fit is structured¶
A K-arm fit uses a single shared prognostic forest (the μ forest, common to
every arm) and one treatment forest whose leaves carry a (K−1) vector of
contrast-basis coefficients. Every posterior draw produces a joint sample over
all K−1 treatment-vs-control contrasts for each visitor: the per-arm outcomes
are evaluated from a single multivariate leaf draw and differenced against
control, so the contrasts stay correlated rather than being estimated in
isolation.
That joint structure is the defining property of the estimator, and it is what makes the multi-arm posterior usable for adaptive targeting. Three things follow from it.
Cross-arm prognostic sharing¶
The conversion and severity baseline (the μ forest) is fit once, across all arms. A visitor who received one treatment still informs the baseline used for every other arm, because the prognostic structure is shared; the treatment forest is left to capture only the differences from that baseline. This is a real efficiency gain in the regime adaptive experiments live in, where per-arm sample sizes are small. (Fitting each treatment-vs-control contrast as a separate model would relearn the baseline K−1 times from disjoint slices of the data.)
A selection-bias-free “best arm”¶
Choosing the best arm means taking an argmax across the K−1 contrasts. Computed
within each correlated joint draw, that argmax already accounts for the noise in
the comparison, so the quantity the recommendation engine reads —
P(arm_k best | segment) — is calibrated. Estimating each contrast in isolation
and then taking the max instead would reintroduce the classic winner’s-curse
over-estimation of the selected arm, an effect that grows with K.
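The winner's-curse effect of the isolated-estimate approach can be illustrated with a small simulation. This is a minimal sketch, not pytyche code: it assumes every true contrast is exactly zero and that each contrast is estimated independently with unit Gaussian noise, and the function name `naive_max_bias` is hypothetical.

```python
import numpy as np

def naive_max_bias(num_contrasts, noise_sd=1.0, num_sims=20_000, seed=0):
    """Mean of max_j(estimate_j) when every true contrast is exactly 0."""
    rng = np.random.default_rng(seed)
    # Independent, unbiased estimates of contrasts whose true value is 0.
    estimates = rng.normal(0.0, noise_sd, size=(num_sims, num_contrasts))
    # Plug-in "value of the selected arm": the largest estimate per simulation.
    return float(estimates.max(axis=1).mean())

for K in (3, 5, 10):
    # The true value of every arm's contrast is 0, so any positive mean is bias.
    print(K, round(naive_max_bias(K - 1), 3))
```

The printed bias is positive for every K and increases with the number of contrasts, which is the growth in K described above.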
Correct joint uncertainty¶
Every contrast is measured against the same control arm, so the K−1 contrasts share structure. The joint posterior carries that correlation, which is what keeps credible-interval widths and Thompson allocation probabilities honest; treating the contrasts as independent would understate the shared uncertainty.
K = 2 dispatch: binary-arm path stays intact¶
The K = 2 binary-arm path is the validated reference implementation with pinned fingerprint tests. Rather than unifying the code paths, the entry point dispatches on K:
K = int(Z.max()) + 1
if K == 2:
    <legacy binary-arm code, bit-identical to pre-change>
else:
    <new multi-arm code path>
This makes binary-arm fingerprint preservation a structural guarantee: K = 2 inputs never reach the multivariate code. Downstream code that passes binary treatment vectors sees no change. The unification of both paths under the multivariate body is deferred to v0.3+, when the multi-arm path has been the production path long enough to become the reference implementation itself.
Contrast coding¶
For K ≥ 3, _compute_basis(Z) returns an (n, K−1) matrix with
treatment-vs-control reference coding using the same ±0.5 magnitude
as the binary case:
- Column j is +0.5 if Z = j+1 (treatment arm j+1).
- Column j is −0.5 if Z = 0 (control).
- Column j is 0 otherwise.
Control is the baseline, and it loads on every contrast column: a control
visitor’s predicted response is μ − 0.5·Σ_j τ_j, while a treatment-j+1
visitor’s is μ + 0.5·τ_j. Two consequences worth being precise about:
- The leaf component τ_j is not itself the treatment-j+1-vs-control effect (except at K = 2, where the coding collapses to the binary case). The model-implied contrast is 0.5·τ_j + 0.5·Σ_k τ_k: the leaf coordinates are coupled.
- The composed outputs are genuine contrasts: the composition kernel evaluates each arm’s potential outcome from the same draw (control at μ − 0.5·Στ, arm k at μ + 0.5·τ_{k−1}) and differences them, so rpv_cate_samples[:, :, j] is exactly the treatment-j+1-vs-control RPV CATE, which is what the downstream recommendation engine (Thompson allocation, graduation-candidate detection) operates on.
The prognostic forest (μ) keeps a scalar ones(n) basis at every K; at K ≥ 3
the tau forest contributes the −0.5·Στ term to control predictions.
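The coding described in this section can be sketched in NumPy as follows. This is an illustrative reconstruction from the description above, not the body of _compute_basis; the function name `reference_basis` and the example values of μ and τ are assumptions.

```python
import numpy as np

def reference_basis(Z, K):
    """(n, K-1) treatment-vs-control basis with +/-0.5 magnitude (illustrative)."""
    Z = np.asarray(Z)
    basis = np.zeros((Z.shape[0], K - 1), dtype=np.float32)
    basis[Z == 0, :] = -0.5                    # control loads on every column
    treated = np.flatnonzero(Z > 0)
    basis[treated, Z[treated] - 1] = 0.5       # arm j+1 loads on column j only
    return basis

Z = np.array([0, 1, 2])                        # one visitor per arm
B = reference_basis(Z, K=3)
tau = np.array([0.2, 0.6], dtype=np.float32)   # example leaf components
mu = 1.0                                       # example prognostic value
pred = mu + B @ tau                            # per-visitor predicted response
contrasts = pred[1:] - pred[0]                 # arm j+1 minus control
```

With these example values the contrasts equal 0.5·τ_j + 0.5·Σ_k τ_k, not τ_j, which shows the coupling between leaf coordinates.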
The multivariate Normal-Normal conjugate at Stage 1b¶
The tau forest has two leaf draws per tree, one per channel:
- Alpha channel (conversion / probit): error scale σ²_conv = 1.0, per-tree prior variance τ_conv = 1 / num_trees_tau.
- Beta channel (severity / log-revenue NN): error scale σ²_sev = 1 / τ_0, per-tree prior variance τ_sev = 1 / num_trees_tau.
The reciprocal 1 / τ_chan is the prior precision that enters the leaf posterior.
At K ≥ 3, each leaf draw generalizes from scalar to a (K−1)-vector multivariate Normal-Normal conjugate. Per leaf, per channel:
- sum_hr becomes a (K−1,) vector: Σ basis_visitor^T · residual_visitor over visitors assigned to the leaf.
- sum_h2 becomes a (K−1, K−1) matrix: Σ basis_visitor^T · basis_visitor over the relevant visitors (all for conversion; converters only for severity).
- Posterior mean: P^{-1} @ (sum_hr / σ²) where P = sum_h2 / σ² + (1 / τ_chan) · I.
- Posterior sample: Cholesky solve of P, then a correlated multivariate-normal draw via jnp.linalg.
The σ²_sev scaling on the severity sum_h2 matrix must be carried
through — dropping it (treating both channels as σ² = 1) is a silent
miscalibration bug.
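The per-leaf update above can be sketched in NumPy for a single leaf and a single channel. This is a minimal illustration of the stated formulas, not the in-scan implementation; the function names `leaf_posterior` and `leaf_draw` and the example inputs are assumptions.

```python
import numpy as np

def leaf_posterior(basis, resid, sigma2, tau_chan):
    """Posterior mean and Cholesky factor of the precision for one leaf."""
    sum_hr = basis.T @ resid                   # (K-1,)
    sum_h2 = basis.T @ basis                   # (K-1, K-1)
    # sigma2 scales both sufficient statistics; 1 / tau_chan is the prior precision.
    P = sum_h2 / sigma2 + np.eye(basis.shape[1]) / tau_chan
    L = np.linalg.cholesky(P)                  # P = L @ L.T
    mean = np.linalg.solve(P, sum_hr / sigma2)
    return mean, L

def leaf_draw(mean, L, rng):
    """One correlated draw whose covariance is the inverse of P."""
    z = rng.standard_normal(mean.shape[0])
    return mean + np.linalg.solve(L.T, z)      # cov(L^{-T} z) = (L L^T)^{-1}

rng = np.random.default_rng(0)
basis = np.array([[-0.5, -0.5], [0.5, 0.0], [0.0, 0.5], [-0.5, -0.5]])
resid = np.array([0.3, -0.1, 0.4, 0.2])
mean, L = leaf_posterior(basis, resid, sigma2=0.5, tau_chan=0.1)
draw = leaf_draw(mean, L, rng)
```

Passing sigma2=1.0 for the severity channel instead of 1 / τ_0 would run without error, which is why the miscalibration is silent.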
The prior precision is fixed and diagonal — (1/τ_chan) · I per channel —
on the leaf components. Because the contrasts mix the components (control
loads on every column), the implied prior on the treatment-vs-control
contrasts grows with K and is positively correlated across contrasts:
per-contrast prior variance 0.25·(K+2)·τ_chan (equal to the binary-arm
prior at K = 2; 3× it at K = 10) with prior correlation (K+1)/(K+2).
This is a shared-effect shrinkage structure — contrasts pool toward each
other a priori — which intensifies at large K. The prior is not sampled;
there is no Inverse-Wishart hyperprior on the leaf covariance in v0.2.
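The implied contrast prior can be checked numerically. The sketch below assumes the leaf components are independent N(0, τ_chan) a priori (the fixed diagonal prior above) and uses the contrast map contrast_j = 0.5·τ_j + 0.5·Σ_k τ_k from the contrast-coding section; `contrast_prior` is a hypothetical helper name.

```python
import numpy as np

def contrast_prior(K, tau_chan):
    """Implied prior covariance of the K-1 treatment-vs-control contrasts."""
    m = K - 1
    # contrast_j = 0.5*tau_j + 0.5*sum_k tau_k, i.e. contrasts = A @ tau
    A = 0.5 * np.eye(m) + 0.5 * np.ones((m, m))
    # Leaf components are iid with variance tau_chan, so Cov = tau_chan * A A^T.
    return tau_chan * (A @ A.T)

for K in (2, 3, 10):
    cov = contrast_prior(K, tau_chan=1.0)
    var = cov[0, 0]
    corr = cov[0, 1] / var if K > 2 else float("nan")  # one contrast at K = 2
    print(K, var, corr)
```

The diagonal reproduces 0.25·(K+2)·τ_chan and the off-diagonal ratio reproduces (K+1)/(K+2) for every K ≥ 3.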
The mu (prognostic) forest stays scalar at every K. Only the tau
forest’s (n, K−1) basis triggers the multivariate leaf body. The
dispatch is on basis.ndim: scalar for the mu forest, matrix for the
tau forest.
Why the binary fast-path shortcut does not survive¶
The binary-arm fast path is efficient because every visitor’s basis is
±0.5, so basis² = 0.25 is a constant and each leaf’s precision
reduces to count × 0.25 — a value that depends only on how many
visitors land in the leaf, never on which arm they came from. This lets
the structure-only factors be precomputed in parallel from counts alone.
With treatment-vs-control coding, a control visitor contributes 0.25
to every entry of the (K−1, K−1) matrix; a treatment-j visitor
contributes 0.25 to a single diagonal entry. Per-leaf sum_h2 is a
genuine matrix that depends on the arm composition of the leaf, not a
count. The count-based shortcut does not apply.
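A small example makes the dependence on arm composition concrete. This is an illustrative sketch with a hypothetical helper `sum_h2`; it builds the matrix from the coding rules stated above, not from pytyche internals.

```python
import numpy as np

def sum_h2(Z_leaf, K):
    """Per-leaf (K-1, K-1) matrix: sum of basis outer products (illustrative)."""
    m = K - 1
    out = np.zeros((m, m))
    for z in Z_leaf:
        # Control row is -0.5 everywhere; treatment row is +0.5 in one column.
        row = np.full(m, -0.5) if z == 0 else 0.5 * np.eye(m)[z - 1]
        out += np.outer(row, row)
    return out

leaf_a = [0, 0, 1, 2]   # two control visitors, one visitor per treatment
leaf_b = [1, 1, 2, 2]   # same count, no control visitors
A = sum_h2(leaf_a, K=3)  # dense: control fills the off-diagonals
B = sum_h2(leaf_b, K=3)  # diagonal: treatments touch one entry each
```

Both leaves hold four visitors, yet the matrices differ. At K = 2 the same helper returns count × 0.25 regardless of composition, which is the property the binary fast path relies on.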
v0.2 ships the correctness-first path: the tau-forest multivariate leaf
draw runs inside the lax.scan, computing the per-leaf Cholesky solve +
correlated draw without hoisting structure-only factors to the parallel
stage. Hoisting is a profile-driven follow-up gated on the K = 10
performance characterization (marked # perf-follow-up in the code).
Severity precision stays a single scalar¶
The global severity precision τ_0 is the precision of the scalar severity
residual per visitor. Each visitor is assigned to exactly one arm and observes
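exactly one log-revenue value, so its severity residual
y_log − μ_sev − (τ_sev · basis) is a scalar: the treatment term is a per-visitor dot product
(K−1) · (K−1) → scalar, with τ_sev here denoting the visitor’s severity leaf vector rather than the prior scale of the same name. There is no (K−1)-vector of severity residuals,
because potential outcomes are not jointly observed.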
So τ_0 stays a scalar Gamma draw at every K: tau0_samples is (S_total,)
and sigma2_samples is its reciprocal property, with no matrix precision, no
Wishart sampler, and no positive-definite-matrix invariant. The genuine
(K−1, K−1) coupling in the model lives one level up — in the Stage-1b
leaf-vector posterior covariance, which is full because control visitors load
on every contrast column — not in the severity precision.
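The scalar residual can be sketched as a row-wise dot product. The arrays below are synthetic stand-ins (the names `mu_sev` and `tau_leaf` and all values are assumptions), and in the model only converters would contribute.

```python
import numpy as np

rng = np.random.default_rng(0)
n, K = 6, 4
Z = np.array([0, 1, 2, 3, 0, 2])

# Treatment-vs-control basis, as described in the contrast-coding section.
basis = np.zeros((n, K - 1))
basis[Z == 0] = -0.5
basis[np.flatnonzero(Z > 0), Z[Z > 0] - 1] = 0.5

y_log = rng.normal(3.5, 0.3, size=n)               # one observed log-revenue each
mu_sev = np.full(n, 3.4)                           # stand-in prognostic severity fit
tau_leaf = rng.normal(0.0, 0.1, size=(n, K - 1))   # stand-in summed leaf vectors

# Row-wise dot product: (K-1) . (K-1) -> one scalar per visitor.
resid = y_log - mu_sev - np.einsum("nk,nk->n", basis, tau_leaf)
```

The result has shape (n,), so a single scalar precision is enough to describe its noise level.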
The compose-kernel seam and selection-bias-free argmax¶
After the MCMC chains run, _compose_rpv_cate_jax in compose.py
assembles the final result arrays. The key output for the recommendation
engine is:
rpv_cate_samples (n, S_total, K−1)
This is the jointly-sampled contrast posterior for every visitor across
all posterior draws (chains concatenated, S_total = (num_mcmc / thin_factor) · num_chains). The K−1 contrast values per draw per visitor are correlated —
they came from a single multivariate leaf draw — not independent.
The recommendation engine reads P(arm_k is best | segment) from this
posterior. Because the K−1 values are treatment-vs-control contrasts, the
best arm over all K arms (control included) must be resolved per draw with
control handled explicitly: control wins a draw when every contrast is
negative; otherwise the treatment with the largest positive contrast wins.
A bare argmax over the K−1 contrasts would never select control and would
answer “best treatment,” not “best arm.” Counting the correct rule across
draws gives the per-visitor probability each arm is best:
import jax.numpy as jnp
# rpv_cate_samples: (n, S_total, K-1), contrasts vs control
best_contrast = jnp.argmax(rpv_cate_samples, axis=-1)  # (n, S_total)
any_treatment_wins = jnp.any(rpv_cate_samples > 0, axis=-1)  # (n, S_total)
# P(control best): no contrast is positive in the draw
p_control_best = jnp.mean(~any_treatment_wins, axis=1)  # (n,)
# P(treatment j+1 best): it is the argmax contrast AND beats control
p_treatment_best = jnp.stack([
    jnp.mean(any_treatment_wins & (best_contrast == j), axis=1)
    for j in range(K - 1)
], axis=-1)  # (n, K-1)
The argmax is computed within each joint draw (the K−1 contrasts came from a
single multivariate leaf draw and are correlated), not across independent
posteriors, so the winner’s-curse bias described under “A selection-bias-free ‘best arm’” above is eliminated.
That joint structure is what makes P(arm_k best) calibrated.
Result shapes¶
Sample arrays follow the visitor-major, chains-concatenated convention at
every K. At K ≥ 3, HurdleBCFResult carries:
- rpv_cate_samples: (n, S_total, K−1), float32, CPU. Always populated; the jointly-sampled contrast vector.
- p_samples: (n, S_total, K), jax.Array. None unless retain_channel_samples=True.
- sev_samples: (n, S_total, K), jax.Array. None unless retain_channel_samples=True.
- tau0_samples: (S_total,), float32. Scalar at every K.
- sigma2_samples: property, 1 / tau0_samples (scalar at every K).
At K = 2, the legacy field names are populated instead:
p0_samples, p1_samples, sev0_samples, sev1_samples (all
(n, S_total)), and rpv_cate_samples (n, S_total) (scalar contrast).
The new per-arm fields are None at K = 2. The legacy paired fields are
None at K ≥ 3. rpv_cate_samples is always populated, with a trailing
arm axis added at K ≥ 3.
tau0_samples and sigma2_samples are shape-invariant in K: scalar
(S_total,) precision and its reciprocal property at every K.
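Downstream code that must handle both K = 2 and K ≥ 3 results may find it convenient to normalize the contrast array to a single layout. The helper below is a hypothetical sketch, not part of pytyche; it relies only on the shapes documented above.

```python
import numpy as np

def contrasts_with_arm_axis(rpv_cate_samples):
    """Return contrasts as (n, S_total, K-1), adding the axis for K = 2 results."""
    arr = np.asarray(rpv_cate_samples)
    if arr.ndim == 2:             # K = 2: (n, S_total) scalar contrast
        return arr[:, :, None]    # becomes (n, S_total, 1)
    if arr.ndim == 3:             # K >= 3: already (n, S_total, K-1)
        return arr
    raise ValueError(f"unexpected rpv_cate_samples shape {arr.shape}")

binary = contrasts_with_arm_axis(np.zeros((5, 8)))      # stand-in K = 2 result
multi = contrasts_with_arm_axis(np.zeros((5, 8, 2)))    # stand-in K = 3 result
```

After normalization, the per-draw best-arm rule shown earlier applies unchanged at every K.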
K = 3 worked example¶
This is the executable acceptance target for the K-arm path: a complete
three-arm fit (control + two treatment variants) whose assertions encode the
K = 3 contract. It is registered xfail(strict) in the docs test suite
(tests/test_docs/test_executable_examples.py) until the K ≥ 3 path lands. The
shape assertions fail against today’s binary-collapse behavior — a K = 3 Z
currently routes to the binary path, so rpv_cate_samples comes back (n, S)
rather than (n, S, 2) — and flip the suite green the moment the multi-arm path
is implemented, at which point the xfail entry is removed. The example is the
contract the implementation must satisfy.
import numpy as np
import pandas as pd
from pytyche import fit_hurdle_bcf
from pytyche.contracts import ObservedExperimentData, VariantData
rng = np.random.default_rng(42)
n = 300
# Three arms: 0 = control, 1 = free_shipping, 2 = promo_10pct.
# K is the number of variants in the observed data — no K argument is passed.
Z = rng.integers(0, 3, size=n)
X = rng.standard_normal((n, 5)).astype(np.float32)
# Zero-inflated revenue: arm-specific conversion + log-normal severity.
logit = -0.5 + 0.4 * X[:, 0] + np.array([0.0, 0.5, 1.0])[Z] # treatments lift conversion
p_convert = 1.0 / (1.0 + np.exp(-logit))
converted = rng.random(n) < p_convert
log_rev = np.array([3.4, 3.5, 3.6])[Z] + 0.3 * rng.standard_normal(n) # + severity
Y = np.where(converted, np.exp(log_rev), 0.0).astype(np.float32)
# Package the arrays as observed experiment data: one VariantData per arm,
# visitor frames carrying the schema columns plus the five features.
# (Assignment is uniform, so the adapter's constant 1/K multi-arm
# propensity placeholder matches the design.)
arm_names = ("control", "free_shipping", "promo_10pct")
variants = []
for code, name in enumerate(arm_names):
    mask = Z == code
    m = int(mask.sum())
    frame = pd.DataFrame({
        "visitor_id": [f"{name}_v{i}" for i in range(m)],
        "experiment_id": "multi-arm-demo",
        "variant": name,
        "converted": converted[mask],
        "revenue": Y[mask].astype(np.float64),
        "orders_count": converted[mask].astype(np.int64),
        "sessions_count": np.ones(m, dtype=np.int64),
    })
    for j in range(X.shape[1]):
        frame[f"f{j}"] = X[mask, j].astype(np.float64)
    variants.append(VariantData(
        name=name,
        visitors=frame,
        n_visitors=m,
        n_conversions=int(converted[mask].sum()),
        total_revenue=float(Y[mask].sum()),
    ))
observed = ObservedExperimentData(
    experiment_id="multi-arm-demo",
    metric="revenue_per_visitor",
    variants=variants,
)
result = fit_hurdle_bcf(
    observed,
    num_trees_mu=20,
    num_trees_tau=10,
    num_gfr_sweeps=2,
    num_burnin=20,
    num_mcmc=30,
    retain_channel_samples=True,
)
S = result.rpv_cate_samples.shape[1] # S_total = (num_mcmc / thin) * num_chains
# --- The K = 3 contract (these assertions are the acceptance target) ---
# rpv_cate_samples carries a trailing (K-1) = 2 contrast axis at K = 3.
assert result.rpv_cate_samples.shape == (n, S, 2)
# Per-arm channel samples carry a trailing K = 3 arm axis.
assert result.p_samples.shape == (n, S, 3)
assert result.sev_samples.shape == (n, S, 3)
# Severity precision stays scalar at every K.
assert result.tau0_samples.shape == (S,)
p = np.asarray(result.p_samples)
sev = np.asarray(result.sev_samples)
rpv = np.asarray(result.rpv_cate_samples)
# Conversion probabilities live in [0, 1].
assert p.min() >= 0.0 and p.max() <= 1.0
# RPV decomposition identity, per contrast j, per draw (visitor-major):
# rpv[:, :, j] == p[:, :, j+1] * sev[:, :, j+1] - p[:, :, 0] * sev[:, :, 0]
for j in range(2):
    recon = p[:, :, j + 1] * sev[:, :, j + 1] - p[:, :, 0] * sev[:, :, 0]
    assert np.allclose(rpv[:, :, j], recon, atol=1e-3)
# P(arm_k best | visitor), control included: control wins a draw iff every
# contrast is negative; otherwise the largest positive contrast wins.
best_contrast = np.argmax(rpv, axis=-1) # (n, S)
any_treatment_wins = np.any(rpv > 0, axis=-1) # (n, S)
p_control_best = np.mean(~any_treatment_wins, axis=1) # (n,)
p_treatment_best = np.stack(
    [np.mean(any_treatment_wins & (best_contrast == j), axis=1) for j in range(2)],
    axis=-1,
)  # (n, 2)
assert p_control_best.shape == (n,)
assert p_treatment_best.shape == (n, 2)
# Each visitor's arm-best probabilities (control + 2 treatments) sum to 1.
assert np.allclose(p_control_best + p_treatment_best.sum(axis=1), 1.0)
Why five segments are often enough¶
A core empirical finding supporting the segment-based approach: Zhang &
Misra 2022 (“Coarse Personalization,” arXiv 2204.05793, EC ’24) found
that in a food-delivery promotions RCT, five discrete segments recover
99.5% of expected incremental profits relative to full personalization.
The policy tree is a near-lossless compression of the CATE surface. This
result was obtained with continuous outcomes and a single experiment;
transfer to zero-inflated multi-arm settings has not been directly studied,
but it informs the coarse-segment framing of pytyche’s recommendation
engine (the sequential-experiment-api change).
GPU GFR warm-start¶
fit_hurdle_bcf(pooling="joint") runs five GPU GFR warm-start sweeps before MCMC
by default (num_gfr_sweeps=5, gfr_backend="gpu"). The GPU GFR path
is the production warm-start — skipping it (num_gfr_sweeps=0) causes
a real degradation in posterior quality, not a graceful fallback.
The GPU GFR tau sweep receives the same (K−1)-vector generalization as
the MCMC Stage-1b path: the same per-leaf multivariate Normal-Normal
conjugate, the same Cholesky solve + correlated draw, the same fixed
diagonal prior. It reuses the shared _log_marginal_leaves and
_hurdle_sufficient_stats* helpers.
The StochTree CPU GFR backend (gfr_backend="cpu", cpu_gfr.py) stays
at K = 2 only. Passing K ≥ 3 with gfr_backend="cpu" raises
NotImplementedError. The production gfr_backend="gpu" path carries
multi-arm.
Performance characteristics¶
Measured on a Quadro RTX 5000 (16 GB) — a rough characterization pass (300 burn-in + 300 MCMC × 2 chains, 50 + 50 trees, 5 GFR sweeps, wall-clock including one-time XLA compile):
| K | n | wall-clock | vs. K = 2 |
|---|---|---|---|
| 2 (binary fast-path) | 10 000 | 62.4 s | (baseline) |
| 3 (multivariate) | 10 000 | 78.7 s | 1.26× |
| 10 (multivariate) | 5 000 | 81.4 s | 1.30× at half the n |
- K = 2: zero performance change (dispatch routes to the unchanged binary-arm code).
- K = 3: 1.26× the K = 2 wall-clock, better than the original ≈1.5× estimate. The (K−1, K−1) = (2, 2) matrix Cholesky inside the lax.scan is cheap at this contrast count.
- K = 10: ≈ 2.6× per-visitor cost (the (9, 9) per-leaf Cholesky), but an 81 s 10-arm fit is comfortably tractable. The in-scan per-leaf factorization does not dominate at this scale, so the performance follow-up (hoist the structure-only sum_h2 matrices to the parallel stage; the code marks the site # perf-follow-up) is not triggered by this evidence; file it only if a larger-scale K = 10 run shows it dominating.
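The ratios quoted above follow directly from the table. The short calculation below reproduces them from the reported wall-clock figures only; because those figures include the one-time XLA compile, the per-visitor ratio is a rough indication rather than a pure sampling cost.

```python
# K -> (n, wall-clock seconds), as reported in the table above.
runs = {2: (10_000, 62.4), 3: (10_000, 78.7), 10: (5_000, 81.4)}
base_n, base_s = runs[2]

def ratios(K):
    """Return (wall-clock ratio, per-visitor ratio) relative to the K = 2 run."""
    n, seconds = runs[K]
    wall_ratio = seconds / base_s                            # raw wall-clock
    per_visitor_ratio = (seconds / n) / (base_s / base_n)    # normalized by n
    return wall_ratio, per_visitor_ratio

for K in runs:
    wall, per_visitor = ratios(K)
    print(f"K={K}: wall-clock {wall:.2f}x, per-visitor {per_visitor:.2f}x")
```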
The numbers above are a rough pass; a production-scale benchmark
(n = 50 000, num_chains = 4, num_mcmc = 2000) with GPU-memory
capture is deferred. The raw artifact lives at bench/multi_arm_baseline.json.