# JAX Transformations
All four glmax verbs — `fit`, `predict`, `infer`, and `check` — are `@eqx.filter_jit`-wrapped at the public boundary and return JAX-compatible pytrees, so the full suite of JAX transforms works on them.
## Batched fitting with `filter_vmap`
`eqx.filter_vmap` is equinox's analogue of `jax.vmap`, extended to handle pytrees with non-array leaves. A natural use case is fitting the same model to many response vectors simultaneously — bootstrapping, permutation testing, or multiple phenotypes with a shared design matrix.
```python
import jax.random as jr
import jax.numpy as jnp
import equinox as eqx

import glmax

key = jr.key(0)
BATCH, N = 200, 100
X = jnp.column_stack([jnp.ones(N), jr.normal(key, (N,))])
ys = jr.normal(jr.fold_in(key, 1), (BATCH, N))  # 200 response vectors

fitted_batch = eqx.filter_vmap(
    lambda y: glmax.fit(glmax.Gaussian(), X, y)
)(ys)

fitted_batch.params.beta.shape  # (200, 2)
fitted_batch.converged.shape    # (200,)
```
All 200 fits run in parallel. The returned `FittedGLM` is itself a batched pytree: every array field gains a leading batch dimension.
You can also vmap over (X, y) jointly, which is useful for cross-validation folds or simulation studies:
```python
fitted_folds = eqx.filter_vmap(
    lambda Xi, yi: glmax.fit(glmax.Gaussian(), Xi, yi)
)(X_folds, y_folds)
```
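The joint-vmap call assumes `X_folds` and `y_folds` already carry a leading fold dimension. One way to build them, sketched here with NumPy for clarity (`jnp.reshape` behaves identically), assuming N divides evenly into K folds:

```python
import numpy as np

# Hypothetical fold construction for the joint vmap above: split N
# observations into K equal folds and stack them so that the leading
# index selects one fold's design matrix and response vector.
N, p, K = 120, 2, 4
rng = np.random.default_rng(0)
X = np.column_stack([np.ones(N), rng.normal(size=N)])
y = rng.normal(size=N)

X_folds = X.reshape(K, N // K, p)  # (4, 30, 2)
y_folds = y.reshape(K, N // K)     # (4, 30)
```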
## Differentiating through `fit`
`fit` is differentiable like any other JAX function: both forward mode (`jax.jvp`) and reverse mode (`jax.grad`) work.
Under the hood, glmax registers a custom JVP rule based on the Implicit Function Theorem. Rather than differentiating through the solver iterations, glmax computes the tangent analytically: at the MLE the score is zero, and differentiating that condition gives \(H \, d\hat\beta = \partial_{\text{data}}(\nabla_\beta \ell) \cdot d(\text{data})\), where \(H = X^\top W X\) is the Fisher information (the negated Hessian of the log-likelihood) at the converged fit. This avoids accumulating gradients through potentially hundreds of solver steps.
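To see the implicit-function-theorem rule concretely, consider the Gaussian family with identity link, where the score condition is \(X^\top(y - X\hat\beta) = 0\) and \(W = I\). Differentiating it gives \(X^\top X \, d\hat\beta = X^\top dy\). The following NumPy sketch (independent of glmax) checks that analytic tangent against a finite difference:

```python
import numpy as np

# Gaussian/identity case: the score X^T (y - X beta) = 0 differentiates
# to (X^T X) dbeta = X^T dy, with H = X^T X the Fisher information.
rng = np.random.default_rng(0)
X = np.column_stack([np.ones(50), rng.normal(size=50)])
y = rng.normal(size=50)
dy = rng.normal(size=50)  # perturbation direction for the response

H = X.T @ X                               # Fisher information (W = I)
dbeta_ift = np.linalg.solve(H, X.T @ dy)  # analytic (IFT) tangent

beta_hat = lambda v: np.linalg.lstsq(X, v, rcond=None)[0]
eps = 1e-6
dbeta_fd = (beta_hat(y + eps * dy) - beta_hat(y)) / eps

assert np.allclose(dbeta_ift, dbeta_fd, atol=1e-4)
```

Because the Gaussian MLE is linear in `y`, the finite difference here is exact up to floating-point error; for other families the same identity holds only at the converged fit.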
### Forward-mode: sensitivities of β to data perturbations
Use `jax.jvp` when you want directional derivatives — for example, how \(\hat\beta\) shifts if the response vector moves in some direction:
```python
import jax

X = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
y = jnp.array([1.0, 2.0, 3.0, 4.0])

# Direction: perturb the first observation
dy = jnp.array([1.0, 0.0, 0.0, 0.0])

_, dbeta = jax.jvp(
    lambda y_: glmax.fit(glmax.Gaussian(), X, y_).params.beta,
    (y,),
    (dy,),
)
# dbeta[i] = ∂β̂ᵢ / ∂y₁  (sensitivity of each coefficient to y₁)
```
### Reverse-mode: gradients of a scalar loss
Use `jax.grad` when you have a scalar loss that depends on the fitted coefficients — for instance, a held-out log-likelihood or a regularisation term:
```python
def held_out_loglik(y_train, X_train, y_test, X_test):
    fitted = glmax.fit(glmax.Poisson(), X_train, y_train)
    mu_test = glmax.predict(fitted.family, fitted.params, X_test)
    return -jnp.sum(y_test * jnp.log(mu_test) - mu_test)  # Poisson NLL (up to a constant)

# Differentiate w.r.t. y_train and X_train (argnums 0 and 1)
grad_fn = jax.grad(held_out_loglik, argnums=(0, 1))
g_y, g_X = grad_fn(y_train, X_train, y_test, X_test)
```
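A note on the dropped constant: the full Poisson negative log-likelihood also contains a \(\log y!\) term, but it does not depend on \(\mu\) and therefore contributes nothing to the gradients taken here. A quick NumPy check (independent of glmax):

```python
import numpy as np
from math import lgamma

# The full Poisson NLL adds sum(log(y!)) to the truncated version; that
# term is constant in mu, so gradients with respect to mu are identical.
y = np.array([2.0, 0.0, 5.0])

def nll_dropped(mu):
    return -np.sum(y * np.log(mu) - mu)

def nll_full(mu):
    return nll_dropped(mu) + sum(lgamma(v + 1.0) for v in y)

mu = np.array([1.5, 0.5, 4.0])
eps = 1e-6
g_dropped = np.array([(nll_dropped(mu + eps * e) - nll_dropped(mu - eps * e)) / (2 * eps)
                      for e in np.eye(3)])
g_full = np.array([(nll_full(mu + eps * e) - nll_full(mu - eps * e)) / (2 * eps)
                   for e in np.eye(3)])

assert np.allclose(g_dropped, g_full)
```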
`jax.value_and_grad` also works:
```python
loss, (g_y, g_X) = jax.value_and_grad(
    held_out_loglik, argnums=(0, 1)
)(y_train, X_train, y_test, X_test)
```
## Differentiating `predict`
`predict` is fully differentiable without any special handling — it's just a matrix multiply and a link inverse. Gradients with respect to `beta` or `X` work out of the box:
```python
from glmax import Params

def total_predicted(beta):
    params = Params(beta=beta, disp=fitted.params.disp, aux=fitted.params.aux)
    return glmax.predict(fitted.family, params, X).sum()

dpred_dbeta = jax.grad(total_predicted)(fitted.params.beta)
```
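For a concrete sense of what such a gradient looks like, take a log-link family, where the prediction is \(\mu = \exp(X\beta)\); the chain rule gives \(\nabla_\beta \sum_i \mu_i = X^\top \mu\). A NumPy sketch (independent of glmax) verifying this against central finite differences:

```python
import numpy as np

# Log-link case: mu = exp(X @ beta), so the gradient of sum(mu) with
# respect to beta is X.T @ mu by the chain rule.
rng = np.random.default_rng(1)
X = np.column_stack([np.ones(20), rng.normal(size=20)])
beta = np.array([0.1, -0.3])

total = lambda b: np.exp(X @ b).sum()
grad_analytic = X.T @ np.exp(X @ beta)

eps = 1e-6
grad_fd = np.array([(total(beta + eps * e) - total(beta - eps * e)) / (2 * eps)
                    for e in np.eye(2)])

assert np.allclose(grad_analytic, grad_fd, atol=1e-4)
```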
## JIT compilation
All verbs are JIT-compiled on first call. The family instance and fitter strategy are treated as static structure — changing them forces a retrace, while changing `X`, `y`, or other arrays does not.
> **First-call latency.** The first call to any verb traces and compiles the computation. This can take a few seconds but is a one-time cost: subsequent calls with the same array shapes and dtypes reuse the compiled code. If you're benchmarking, always time the second call.
If you're calling `fit` inside a larger JIT-compiled function, there's no need to double-wrap:
```python
@jax.jit
def pipeline(X, y, X_test):
    fitted = glmax.fit(glmax.Poisson(), X, y)
    return glmax.predict(fitted.family, fitted.params, X_test)
```
The nested `filter_jit` is harmless: under an outer trace it is inlined, so the whole pipeline compiles as a single computation.
## Combining transforms
Transforms compose. For example, bootstrap standard errors can be computed by vmapping a resample-and-refit step over a batch of PRNG keys:
```python
import jax.random as jr

def bootstrap_betas(key, X, y, B=200):
    n = y.shape[0]
    keys = jr.split(key, B)

    def one_resample(k):
        idx = jr.choice(k, n, shape=(n,), replace=True)
        return glmax.fit(glmax.Poisson(), X[idx], y[idx]).params.beta

    return eqx.filter_vmap(one_resample)(keys)

betas = bootstrap_betas(jr.key(42), X, y)
betas.shape         # (200, p)
betas.std(axis=0)   # bootstrap standard errors
```