JAX Transformations
All four glmax verbs — fit, predict,
infer, and check — are
@eqx.filter_jit-wrapped at the public boundary and return JAX-compatible
pytrees.
The examples below cover batched fits, simulation studies, bootstrap workflows,
differentiable preprocessing, and prediction sensitivity.
Use infer for ordinary GLM standard errors and p-values; the transformations shown here do not replace it.
Batched Fitting
Use eqx.filter_vmap when fitting many independent models with the same
structure. It is Equinox's filtered analogue of jax.vmap, and it handles pytrees
with non-array leaves such as families, fitters, and fitted result nouns.
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import glmax
key = jr.key(0)
BATCH, N = 200, 100
X = jnp.column_stack([jnp.ones(N), jr.normal(key, (N,))])
ys = jr.normal(jr.fold_in(key, 1), (BATCH, N))
fitted_batch = eqx.filter_vmap(
    lambda y: glmax.fit(glmax.Gaussian(), X, y)
)(ys)
fitted_batch.params.beta.shape  # (200, 2)
fitted_batch.converged.shape  # (200,)
The returned FittedGLM is itself a batched pytree. Every
array field gains a leading batch dimension.
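The same leading-axis behaviour can be seen with plain jax.vmap on a small pytree. The sketch below does not call glmax: `ols_summary` is a hypothetical stand-in that returns a dictionary of arrays in place of a fitted result, and the closed-form least-squares solve is only there to keep the example self-contained.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def ols_summary(X, y):
    # Hypothetical stand-in for a fitted result: a small pytree of arrays.
    beta = jnp.linalg.solve(X.T @ X, X.T @ y)
    resid = y - X @ beta
    return {"beta": beta, "rss": jnp.sum(resid ** 2)}


key = jr.key(0)
X = jnp.column_stack([jnp.ones(100), jr.normal(key, (100,))])
ys = jr.normal(jr.fold_in(key, 1), (200, 100))

# Share X across the batch (in_axes=None) and map over the rows of ys.
batch = jax.vmap(ols_summary, in_axes=(None, 0))(X, ys)
batch["beta"].shape  # (200, 2)
batch["rss"].shape   # (200,)
```

Every leaf of the returned pytree picks up the batch axis, including leaves that were scalars for a single fit.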
You can also map over (X, y) jointly, which is useful for cross-validation
folds or simulation studies:
# X_folds: (K, n, p) and y_folds: (K, n), stacked along a leading fold axis.
fitted_folds = eqx.filter_vmap(
    lambda Xi, yi: glmax.fit(glmax.Gaussian(), Xi, yi)
)(X_folds, y_folds)
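The joint mapping needs X_folds and y_folds to be stacked with a leading fold axis, so every fold must have the same shape. One way to build such arrays is sketched below, assuming the sample size divides evenly by the number of folds. The helper name `make_train_folds` and the leave-one-fold-out layout are illustrative choices, not part of glmax.

```python
import jax.numpy as jnp
import jax.random as jr


def make_train_folds(X, y, K):
    """Stack the K leave-one-fold-out training sets along a leading axis.

    Assumes X.shape[0] is divisible by K so that all folds share one shape.
    """
    n = X.shape[0]
    blocks = jnp.arange(n).reshape(K, n // K)  # row indices of each held-out block
    train_idx = jnp.stack([
        jnp.concatenate([blocks[:k].ravel(), blocks[k + 1:].ravel()])
        for k in range(K)
    ])  # shape (K, n - n // K)
    return X[train_idx], y[train_idx]


key = jr.key(0)
N, K = 100, 5
X = jnp.column_stack([jnp.ones(N), jr.normal(key, (N,))])
y = jr.normal(jr.fold_in(key, 1), (N,))

X_folds, y_folds = make_train_folds(X, y, K)
X_folds.shape  # (5, 80, 2)
y_folds.shape  # (5, 80)
```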
Bootstrap Fits
Bootstrap workflows are another natural filter_vmap use case. Each resample
fits the same model to a different row sample, and the batched coefficient array
can be summarized directly.
def bootstrap_betas(key, X, y, B=200):
    n = y.shape[0]
    keys = jr.split(key, B)

    def one_resample(k):
        # Draw n row indices with replacement, then refit on the resampled rows.
        idx = jr.choice(k, n, shape=(n,), replace=True)
        return glmax.fit(glmax.Poisson(), X[idx], y[idx]).params.beta

    return eqx.filter_vmap(one_resample)(keys)

# X, y: a design matrix and a count response with matching rows.
betas = bootstrap_betas(jr.key(42), X, y)
betas.shape  # (200, p)
betas.std(axis=0)  # bootstrap standard errors
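Percentile intervals follow from the same batched array. In this minimal sketch, `betas` is a random stand-in with the same (B, p) layout rather than real glmax output, and the 95% level is an arbitrary choice for illustration.

```python
import jax.numpy as jnp
import jax.random as jr

# Stand-in for the bootstrap output: B=200 resamples of p=2 coefficients.
betas = jnp.array([0.5, -1.0]) + 0.1 * jr.normal(jr.key(0), (200, 2))

se = betas.std(axis=0, ddof=1)  # bootstrap standard errors, one per coefficient
# 2.5% and 97.5% quantiles across resamples give a 95% percentile interval.
ci_low, ci_high = jnp.quantile(betas, jnp.array([0.025, 0.975]), axis=0)
```

Both `ci_low` and `ci_high` have shape (p,), matching the coefficient vector.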
Differentiating predict
predict is differentiable without special handling. It computes the linear
predictor and then asks the family for the response-scale mean. For most
families this is the inverse link \(g^{-1}(\eta)\); for grouped Binomial it also
converts probability to expected count.
import jax
from glmax import Params

# `fitted` is the result of an earlier glmax.fit(...) call on X.
def total_predicted(beta):
    params = Params(beta=beta, disp=fitted.params.disp, aux=fitted.params.aux)
    return glmax.predict(fitted.family, params, X).sum()

dpred_dbeta = jax.grad(total_predicted)(fitted.params.beta)
This pattern suits prediction sensitivity analysis: derivatives of predictions with respect to coefficients, design variables, offsets, or continuous preprocessing parameters.
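For intuition, the coefficient gradient can be checked by hand for a log-link model, where the mean is \(\exp(X\beta)\). The sketch below does not call glmax: `mean_log_link` is a hypothetical stand-in for the linear predictor followed by the inverse link.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])
beta = jnp.array([0.1, 0.3])


def mean_log_link(beta, X):
    # Hypothetical stand-in: linear predictor followed by the inverse log link.
    return jnp.exp(X @ beta)


# Gradient of the summed prediction with respect to the coefficients.
auto = jax.grad(lambda b: mean_log_link(b, X).sum())(beta)
# By hand: d/dbeta sum_i exp(x_i . beta) = sum_i mu_i * x_i = X^T mu.
by_hand = X.T @ mean_log_link(beta, X)

# The same function differentiates with respect to the design matrix too.
d_design = jax.grad(lambda X_: mean_log_link(beta, X_).sum())(X)
```

Here `auto` and `by_hand` agree up to floating-point error, and `d_design` has the same shape as X.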
Differentiating Through fit
fit supports both forward-mode (jax.jvp) and reverse-mode (jax.grad)
automatic differentiation. The cleanest use case is a scalar downstream loss
that depends on a fitted model, where the differentiated argument is a continuous
upstream quantity.
def held_out_poisson_nll(scale, x_train, y_train, x_test, y_test):
    X_train = jnp.column_stack([jnp.ones_like(x_train), scale * x_train])
    X_test = jnp.column_stack([jnp.ones_like(x_test), scale * x_test])
    fitted = glmax.fit(glmax.Poisson(), X_train, y_train)
    mu_test = glmax.predict(fitted.family, fitted.params, X_test)
    # Poisson negative log-likelihood, dropping the log(y!) constant.
    return jnp.sum(mu_test - y_test * jnp.log(mu_test))

d_loss_dscale = jax.grad(held_out_poisson_nll)(
    1.0,
    x_train,
    y_train,
    x_test,
    y_test,
)
Here scale is a differentiable preprocessing parameter. The gradient flows
through the training fit, the fitted coefficients, and the held-out prediction
loss.
Do not use AD through fit as GLM inference
Gradients through fit are useful for sensitivity analysis and
differentiable software built on top of glmax. They are not replacements for
standard errors, p-values, or diagnostic summaries. Use infer(...) and
check(...) for those quantities.
JIT Compilation
All public verbs are already JIT-compiled. If you call fit inside a larger
JIT-compiled function, you do not need to wrap fit again.
@jax.jit
def pipeline(X, y, X_test):
    fitted = glmax.fit(glmax.Poisson(), X, y)
    return glmax.predict(fitted.family, fitted.params, X_test)
Tracing and recompilation
The first call traces and compiles the computation. Later calls with the same array shapes are fast. Changing the family or fitter changes static structure and retraces. Changing array shapes also retraces; changing array values does not.
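The shape rule can be observed directly with a Python side effect, which runs only while JAX traces a function. This sketch uses plain jax.jit on a toy function rather than glmax; `total` and `trace_log` are illustrative names.

```python
import jax
import jax.numpy as jnp

trace_log = []


@jax.jit
def total(x):
    trace_log.append(x.shape)  # Python side effect: runs only during tracing
    return jnp.sum(x ** 2)


total(jnp.ones(3))   # first call: trace and compile for shape (3,)
total(jnp.zeros(3))  # same shape and dtype, new values: reuses compiled code
total(jnp.ones(4))   # new shape: retrace
trace_log  # [(3,), (4,)]
```

Three calls produce only two traces, one per distinct input shape.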
Advanced
How fit Differentiates
glmax registers a custom JVP rule based on the Implicit Function Theorem. It does not differentiate through every solver iteration.
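At the maximum likelihood estimate, the score is zero: \(s(\hat\beta, \theta) = 0\), where \(\theta\) stands for the input being differentiated. Differentiating that condition gives

\[
H \, \dot{\hat\beta} = \frac{\partial s}{\partial \theta} \, \dot\theta
\]

where \(H = X^\top W X\) is the Fisher information at the converged fit and dots denote tangents. glmax solves this linear system to get the coefficient tangent.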
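To make the idea concrete, here is a minimal pure-JAX sketch of an implicit-function-theorem JVP for a log-link Poisson regression. It is not glmax's implementation: `_solve_poisson`, `fit_beta`, and the fixed iteration count are assumptions made for illustration. The custom rule solves one linear system in \(X^\top W X\) instead of differentiating through the Newton iterations.

```python
import jax
import jax.numpy as jnp


def _solve_poisson(X, y, iters=25):
    """Newton/IRLS steps for a log-link Poisson model (illustrative solver)."""
    def step(beta, _):
        mu = jnp.exp(X @ beta)
        H = X.T @ (mu[:, None] * X)  # Fisher information X^T W X with W = diag(mu)
        score = X.T @ (y - mu)
        return beta + jnp.linalg.solve(H, score), None

    beta, _ = jax.lax.scan(step, jnp.zeros(X.shape[1]), None, length=iters)
    return beta


@jax.custom_jvp
def fit_beta(X, y):
    return _solve_poisson(X, y)


@fit_beta.defjvp
def fit_beta_jvp(primals, tangents):
    X, y = primals
    dX, dy = tangents
    beta = fit_beta(X, y)
    mu = jnp.exp(X @ beta)
    H = X.T @ (mu[:, None] * X)
    # Tangent of the score X^T (y - mu) in the data directions, with beta held fixed.
    score_at_beta = lambda X_, y_: X_.T @ (y_ - jnp.exp(X_ @ beta))
    _, ds = jax.jvp(score_at_beta, (X, y), (dX, dy))
    # Implicit function theorem: solve H @ dbeta = ds for the coefficient tangent.
    return beta, jnp.linalg.solve(H, ds)


X = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
y = jnp.array([1.0, 2.0, 3.0, 4.0])

J_ift = jax.jacfwd(fit_beta, argnums=1)(X, y)             # uses the custom rule
J_unrolled = jax.jacfwd(_solve_poisson, argnums=1)(X, y)  # differentiates every iteration
```

Once the solver has converged, the two Jacobians should agree up to floating-point error, while the custom rule avoids carrying tangents through every iteration.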
Data Sensitivities
You can use jax.jvp to ask how fitted coefficients move under a small data
perturbation. This is a sensitivity calculation, not ordinary coefficient
inference.
X = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
y = jnp.array([1.0, 2.0, 3.0, 4.0])
dy = jnp.array([1.0, 0.0, 0.0, 0.0])
_, dbeta = jax.jvp(
    lambda y_: glmax.fit(glmax.Gaussian(), X, y_).params.beta,
    (y,),
    (dy,),
)
For discrete responses, derivatives with respect to y should be read as a
continuous relaxation. They can be useful for debugging or influence-style
sensitivity checks, but they are not a standard GLM estimand.
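For the Gaussian example above there is a closed form to compare against. Assuming the Gaussian family uses the identity link, so that the fit reduces to ordinary least squares, the coefficients are linear in y and the tangent is the same linear map applied to dy. The sketch below uses only jax.numpy and does not call glmax.

```python
import jax.numpy as jnp

X = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
dy = jnp.array([1.0, 0.0, 0.0, 0.0])

# Least squares: beta = (X^T X)^{-1} X^T y, so dbeta = (X^T X)^{-1} X^T dy.
dbeta_closed_form = jnp.linalg.solve(X.T @ X, X.T @ dy)
dbeta_closed_form  # approximately [0.7, -0.3]
```

Raising the first response by one unit raises the intercept by about 0.7 and lowers the slope by about 0.3.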
Failure Modes
Derivative paths assume a stable local fit
Automatic differentiation through a fit assumes the local fitted solution is well behaved. Gradients are not meaningful if the fit did not converge, the design is rank deficient, or the optimum is near a boundary where the local linearization is unstable.
For derivative-heavy workflows, check fitted.converged, keep design
matrices well conditioned, and prefer continuous upstream parameters over
discrete branching or row-selection logic inside the differentiated
function.
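These checks can be written as small guards around a derivative computation. The helpers below are hypothetical and not part of glmax. The condition-number threshold is an arbitrary illustrative value, and `converged` stands for a boolean flag such as fitted.converged.

```python
import jax.numpy as jnp


def design_is_well_conditioned(X, max_cond=1e4):
    """Hypothetical pre-flight check: True when cond(X) is below max_cond."""
    return jnp.linalg.cond(X) < max_cond


def mask_unconverged(grad, converged):
    """Replace a gradient with NaN when the fit it came from did not converge."""
    return jnp.where(converged, grad, jnp.nan)


X_ok = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])
X_bad = jnp.array([[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]])  # second column is 2x the first

design_is_well_conditioned(X_ok)   # True
design_is_well_conditioned(X_bad)  # False: the design is rank deficient

mask_unconverged(jnp.array([0.3, -0.1]), converged=False)  # all NaN
```

Returning NaN rather than a silently wrong gradient makes a failed fit visible to downstream code.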