Getting Started
glmax uses four verbs — fit, predict, infer, and check — that pass explicit nouns between them. This page walks through each verb in turn.
Design philosophy
glmax is built around a simple idea: the API should be a grammar of verbs that operate on nouns. Verbs are plain functions (fit, predict, infer, check). Nouns are immutable containers for results (FittedGLM, InferenceResult, and so on). Nothing is hidden inside a model object. You can inspect, pass, store, or vmap over any noun directly.
If you're coming from statsmodels, the From statsmodels page maps the two APIs side by side.
Fitting a model
You always start with fit. It takes a family, a design matrix X, and a response vector y, minimises the negative log-likelihood until the fitter converges, and returns a FittedGLM that holds everything produced during fitting.
import jax.numpy as jnp
import glmax
X = jnp.array([[1.0, 0.3],
               [1.0, 1.1],
               [1.0, 2.4],
               [1.0, 3.0]])
y = jnp.array([0.5, 1.8, 2.3, 3.1])
fitted = glmax.fit(glmax.Gaussian(), X, y)
You can access fit artifacts directly from the returned noun:
fitted.params.beta # coefficient vector, shape (p,)
fitted.params.disp # dispersion φ
fitted.eta # linear predictor Xβ, shape (n,)
fitted.mu # fitted means E[y | X] on the response scale, shape (n,)
fitted.converged # True if the fitter converged within tolerance
fitted.num_iters # number of iterations taken
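For the Gaussian family with an identity link, minimising the negative log-likelihood is ordinary least squares, so the artifacts above can be reproduced by hand. A minimal sketch in plain jax.numpy, assuming the default Gaussian link is the identity and that disp is estimated as RSS / (n − p); glmax may use a different dispersion estimator, so treat that line as an assumption.

```python
import jax.numpy as jnp

X = jnp.array([[1.0, 0.3],
               [1.0, 1.1],
               [1.0, 2.4],
               [1.0, 3.0]])
y = jnp.array([0.5, 1.8, 2.3, 3.1])

# Gaussian + identity link: the maximum-likelihood beta is the least-squares solution.
beta, *_ = jnp.linalg.lstsq(X, y)

eta = X @ beta   # linear predictor X beta
mu = eta         # identity link, so the mean equals the linear predictor

# Assumed dispersion estimator (not necessarily glmax's): RSS / residual df.
n, p = X.shape
disp = jnp.sum((y - mu) ** 2) / (n - p)
```

At the solution the residuals are orthogonal to every column of X, which is the first-order condition of the least-squares problem.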
No automatic intercept
glmax doesn't add an intercept automatically. Include a column of ones in X if you want one, as in the example above. This is intentional — it keeps the design matrix explicit and avoids surprises when you're controlling exactly which covariates are included.
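If your covariates arrive without an intercept column, prepend one yourself before calling fit. A minimal sketch; add_intercept is a hypothetical helper, not part of glmax. Applied to the single covariate from the fitting example, it rebuilds the X used there.

```python
import jax.numpy as jnp

def add_intercept(Z):
    """Hypothetical helper (not a glmax function): prepend a column of ones."""
    ones = jnp.ones((Z.shape[0], 1), dtype=Z.dtype)
    return jnp.hstack([ones, Z])

covariates = jnp.array([[0.3], [1.1], [2.4], [3.0]])
X = add_intercept(covariates)  # shape (4, 2): each row is [1, x]
```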
Choosing a family and link
The family encodes the response distribution and variance function. The link
function g maps the family parameter used by the model to the linear predictor,
\(g(\theta) = \eta\); its inverse maps \(\eta\) back to that parameter.
For most families, the inverse-link value is already the response mean. For
grouped Binomial counts, the inverse link gives a probability, and the response
mean is n_trials * probability. glmax gives every family a sensible default
link, but you can override it.
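To make the direction of each map concrete, here are three common links as (g, g⁻¹) pairs. These are plain jax.numpy and jax.scipy stand-ins written for illustration, not glmax's link classes. The last lines show the grouped-Binomial case, where the inverse link gives a probability and the response-scale mean is n_trials times that probability.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit
from jax.scipy.stats import norm

# Illustrative (g, g_inv) pairs, not glmax link classes: g(parameter) = eta.
links = {
    "log": (jnp.log, jnp.exp),        # Poisson default
    "logit": (logit, expit),          # Binomial default
    "probit": (norm.ppf, norm.cdf),   # alternative Binomial link
}

eta = jnp.array([-1.0, 0.0, 2.0])

# Round trip: applying g after g_inv recovers eta for every link.
round_trips = {name: g(g_inv(eta)) for name, (g, g_inv) in links.items()}

# Grouped Binomial with n_trials = 10: probability first, then expected successes.
n_trials = 10
p = expit(eta)
expected_successes = n_trials * p
```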
# X_count/y_count, X_bin/y_bin, X_pos/y_pos stand for your own design matrices and responses
# Poisson regression: log link by default
fitted = glmax.fit(glmax.Poisson(), X_count, y_count)
# Binomial with probit link instead of the default logit
fitted = glmax.fit(glmax.Binomial(glmax.ProbitLink()), X_bin, y_bin)
# Gamma regression — inverse link by default
fitted = glmax.fit(glmax.Gamma(), X_pos, y_pos)
# Negative Binomial — overdispersion α is estimated and stored in fitted.params.aux
fitted = glmax.fit(glmax.NegativeBinomial(), X_count, y_count)
alpha = fitted.params.aux
See Families & Links for the full list.
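Each family also fixes a variance function V(μ), with Var[Y] = φ·V(μ). The sketch below writes out the textbook variance functions for the families shown above. The Negative Binomial line assumes the common NB2 form μ + αμ², which is an assumption about how glmax parameterises the α stored in fitted.params.aux.

```python
import jax.numpy as jnp

# Textbook variance functions V(mu), with Var[Y] = phi * V(mu).
def var_gaussian(mu):
    return jnp.ones_like(mu)

def var_poisson(mu):
    return mu

def var_binomial(mu):  # mu is a success probability here
    return mu * (1.0 - mu)

def var_gamma(mu):
    return mu ** 2

def var_negbin(mu, alpha):  # assumed NB2 form; alpha is the overdispersion
    return mu + alpha * mu ** 2

mu = jnp.array([0.5, 2.0, 10.0])
```

With α = 0 the NB2 variance collapses to the Poisson variance, which is why α is read as extra variance beyond the Poisson model.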
Making predictions
Once you have a fitted model, predict applies it to any design matrix. It doesn't need the full FittedGLM, only the family and the parameters, so you can also use it with hand-constructed coefficients or with parameters saved from an earlier fit.
predict and fitted.mu are response-scale means
predict(...) returns the same kind of mean stored in fitted.mu:
\(\mathrm{E}[Y \mid X]\) on the same scale as y. For Poisson with a log
link, this is \(\exp(\eta)\). For grouped Binomial with
Binomial(n_trials=10), the inverse link gives the success probability
\(p = g^{-1}(\eta)\), but predict(...) returns the expected number of successes,
10 * p.
# In-sample fitted means (same as fitted.mu)
mu_hat = glmax.predict(fitted.family, fitted.params, X)
# Out-of-sample predictions
X_new = jnp.array([[1.0, 1.5], [1.0, 2.0]])
mu_new = glmax.predict(fitted.family, fitted.params, X_new)
Pass offset if your model has an exposure or other additive term in the linear predictor:
# log_exposure_new: log of each new row's exposure, shape (n_new,)
mu_new = glmax.predict(fitted.family, fitted.params, X_new, offset=log_exposure_new)
For log-link count models, this prediction is an expected count for the supplied exposure. Divide by exposure if you want a rate.
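A response-scale prediction for a log-link model is the inverse link applied to Xβ plus any offset. A minimal sketch; predict_log_link is a hypothetical stand-in rather than glmax.predict, and the coefficients and exposures are made-up illustration values. It shows why dividing by exposure recovers the rate exp(Xβ).

```python
import jax.numpy as jnp

def predict_log_link(beta, X, offset=None):
    """Hypothetical stand-in (not glmax.predict): mu = exp(X @ beta + offset)."""
    eta = X @ beta
    if offset is not None:
        eta = eta + offset
    return jnp.exp(eta)

beta = jnp.array([0.2, 0.5])                  # illustrative coefficients
X_new = jnp.array([[1.0, 1.5], [1.0, 2.0]])
exposure_new = jnp.array([2.0, 10.0])         # illustrative exposure per row

mu_new = predict_log_link(beta, X_new, offset=jnp.log(exposure_new))  # expected counts
rate_new = mu_new / exposure_new              # expected count per unit exposure
```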
Inference on coefficients
infer takes the fitted noun and returns an InferenceResult: coefficient estimates, standard errors, test statistics, and p-values. No refitting happens.
result = glmax.infer(fitted)
result.params.beta # same as fitted.params.beta
result.se # standard errors, shape (p,)
result.stat # test statistics, shape (p,)
result.p # two-sided p-values, shape (p,)
The default is a Wald test with Fisher information standard errors. You can swap either component independently:
# Score test with sandwich (Huber) standard errors
result = glmax.infer(
fitted,
inferrer=glmax.ScoreTest(),
stderr=glmax.HuberError(),
)
Huber standard errors are useful when you're uncertain about the variance function or have overdispersion you don't want to model explicitly. See Inference strategies for the full set of options.
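For intuition about what these components compute, here is the textbook version for a Poisson log-link model evaluated at a given β. The Fisher information is XᵀWX with W = diag(μ), Wald statistics are β / se, and two-sided p-values come from the standard normal. The sandwich (Huber) form keeps (XᵀWX)⁻¹ as the "bread" but replaces the model-based middle with the empirical covariance of the per-observation scores. This is a sketch of the standard definitions, not glmax's implementation, which may differ in details such as small-sample corrections.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def wald_poisson_log(X, y, beta):
    """Textbook Wald test for a Poisson log-link GLM (sketch, not glmax's code)."""
    mu = jnp.exp(X @ beta)
    fisher = X.T @ (mu[:, None] * X)          # X^T W X with W = diag(mu)
    bread = jnp.linalg.inv(fisher)

    # Model-based (Fisher information) standard errors.
    se_fisher = jnp.sqrt(jnp.diag(bread))

    # Sandwich (Huber) standard errors: bread @ meat @ bread, where the meat
    # sums the outer products of per-observation scores x_i * (y_i - mu_i).
    scores = X * (y - mu)[:, None]
    meat = scores.T @ scores
    se_huber = jnp.sqrt(jnp.diag(bread @ meat @ bread))

    stat = beta / se_fisher
    p = 2.0 * norm.cdf(-jnp.abs(stat))        # two-sided p-values
    return se_fisher, se_huber, stat, p
```

At the maximum-likelihood β, the two kinds of standard error are close when the Poisson variance assumption holds; under overdispersion the sandwich errors are typically larger.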
Diagnosing the fit
check applies a diagnostic to the fitted noun and returns a typed result. You choose the diagnostic explicitly rather than getting a bundle of everything at once.
Residual diagnostics return an array of the same shape as y:
pearson = glmax.check(fitted, diagnostic=glmax.PearsonResidual())
deviance = glmax.check(fitted, diagnostic=glmax.DevianceResidual())
quantile = glmax.check(fitted, diagnostic=glmax.QuantileResidual())
Quantile residuals are randomised probability integral transform residuals. They're the right choice for discrete families like Poisson and Binomial, where Pearson and deviance residuals inherit the discreteness of the response and can be far from any clean reference distribution when counts are small.
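The three residual types can be written directly from their textbook definitions. A sketch for a Poisson model, not taken from glmax: Pearson residuals divide by √V(μ), deviance residuals are the signed square root of each observation's deviance contribution, and randomised quantile residuals draw u uniformly between F(y − 1) and F(y) and map it through Φ⁻¹. The Poisson CDF is computed as a regularised upper incomplete gamma function, P(Y ≤ k) = Q(k + 1, μ).

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaincc, xlogy
from jax.scipy.stats import norm

def poisson_cdf(k, mu):
    # P(Y <= k) = Q(k + 1, mu); defined as 0 for k < 0.
    return jnp.where(k < 0, 0.0, gammaincc(jnp.maximum(k, 0.0) + 1.0, mu))

def poisson_residuals(y, mu, key):
    """Textbook Poisson residuals (sketch, not glmax's implementation)."""
    pearson = (y - mu) / jnp.sqrt(mu)                          # V(mu) = mu
    # Unit deviance 2 * (y log(y/mu) - (y - mu)); xlogy makes y log y = 0 at y = 0.
    unit_dev = 2.0 * (xlogy(y, y) - xlogy(y, mu) - (y - mu))
    deviance = jnp.sign(y - mu) * jnp.sqrt(jnp.maximum(unit_dev, 0.0))
    # Randomised quantile residuals: u ~ Uniform(F(y - 1), F(y)), then Phi^{-1}(u).
    lo = poisson_cdf(y - 1.0, mu)
    hi = poisson_cdf(y, mu)
    u = lo + (hi - lo) * jax.random.uniform(key, y.shape)
    quantile = norm.ppf(u)
    return pearson, deviance, quantile
```

Because of the random draw, quantile residuals change with the key; that randomisation is what smooths out the discreteness of the response.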
GoodnessOfFit and Influence return structured result nouns:
gof = glmax.check(fitted, diagnostic=glmax.GoodnessOfFit())
gof.pearson_chi2 # Pearson χ² statistic
gof.deviance # residual deviance
gof.df_resid # residual degrees of freedom
gof.aic # Akaike information criterion
influence = glmax.check(fitted, diagnostic=glmax.Influence())
influence.hat # leverage values (diagonal of hat matrix)
influence.cooks_d # Cook's distance
influence.dffits # DFFITS
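The summary statistics above have standard textbook definitions. A sketch for a Poisson log-link fit, given X, y, and fitted means mu: the Pearson χ² and deviance are sums of squared residuals, the residual df is n − p, AIC is −2·log-likelihood + 2p, leverage is the diagonal of W^{1/2}X(XᵀWX)⁻¹XᵀW^{1/2}, and Cook's distance combines Pearson residuals with leverage. glmax's exact conventions (for example, how AIC counts a dispersion parameter in other families) are not visible here and may differ; DFFITS is omitted.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy

def poisson_gof_and_influence(X, y, mu):
    """Textbook fit statistics for a Poisson log-link model (sketch, not glmax's code)."""
    n, p = X.shape
    pearson = (y - mu) / jnp.sqrt(mu)
    unit_dev = 2.0 * (xlogy(y, y) - xlogy(y, mu) - (y - mu))

    pearson_chi2 = jnp.sum(pearson ** 2)
    deviance = jnp.sum(unit_dev)
    df_resid = n - p
    loglik = jnp.sum(xlogy(y, mu) - mu - gammaln(y + 1.0))
    aic = -2.0 * loglik + 2.0 * p

    # Leverage: diag of W^{1/2} X (X^T W X)^{-1} X^T W^{1/2}, with W = diag(mu).
    Xw = X * jnp.sqrt(mu)[:, None]
    hat = jnp.sum(Xw * jnp.linalg.solve(Xw.T @ Xw, Xw.T).T, axis=1)

    # Cook's distance with the Poisson dispersion fixed at 1.
    cooks_d = pearson ** 2 * hat / (p * (1.0 - hat) ** 2)
    return pearson_chi2, deviance, df_resid, aic, hat, cooks_d
```

A quick sanity check on any such implementation: the leverages sum to p, the number of coefficients.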
Switching fit strategies
fit defaults to IRLSFitter, which solves a sequence of weighted least-squares problems. For problems where IRLS overshoots, such as non-canonical links or means near the boundary of their allowed range, NewtonFitter may converge more reliably. It uses a backtracking Armijo line search to control the step size.
# Fisher scoring Newton with automatic line search
fitted = glmax.fit(glmax.Poisson(), X, y, fitter=glmax.NewtonFitter())
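The loop behind an IRLS fitter is short. A minimal sketch for Poisson with the canonical log link, written in plain jax.numpy with a Python loop, so it is neither jit-friendly as written nor glmax's IRLSFitter. Each step forms the working response z = (η − offset) + (y − μ)/μ and weights W = diag(μ), then solves (XᵀWX)β = XᵀWz. The cold start uses μ = y + 0.1, as R's glm does for Poisson; the offset and beta0 arguments are illustrative names for the offset and warm-start ideas that appear later on this page.

```python
import jax.numpy as jnp

def irls_poisson(X, y, offset=None, beta0=None, tol=1e-5, max_iter=50):
    """Illustrative IRLS for Poisson regression with a log link (not glmax's fitter)."""
    n, p = X.shape
    offset = jnp.zeros(n) if offset is None else offset
    if beta0 is None:
        eta = jnp.log(y + 0.1)      # cold start from mu = y + 0.1
        beta = None
    else:
        beta = beta0
        eta = X @ beta + offset     # warm start from supplied coefficients
    for it in range(1, max_iter + 1):
        mu = jnp.exp(eta)
        z = (eta - offset) + (y - mu) / mu      # working response for X @ beta
        XtW = X.T * mu                          # X^T W with W = diag(mu)
        beta_new = jnp.linalg.solve(XtW @ X, XtW @ z)
        eta = X @ beta_new + offset
        # tol is loose enough for float32, JAX's default precision.
        if beta is not None and jnp.max(jnp.abs(beta_new - beta)) < tol:
            return beta_new, it
        beta = beta_new
    return beta, max_iter
```

At convergence the score Xᵀ(y − μ) is zero, and restarting from the converged β finishes in a single step, which is the payoff of warm-starting.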
Both strategies return a FittedGLM noun. You can also tune tolerances or swap the underlying linear solver:
import lineax as lx
fitter = glmax.IRLSFitter(solver=lx.QR(), tol=1e-6, max_iter=500)
fitted = glmax.fit(glmax.Gamma(), X, y, fitter=fitter)
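The default lineax.Cholesky() solver is fastest for small-to-medium problems. lineax.QR() is more numerically stable for ill-conditioned designs, but lineax documents its QR solver as requiring a full-rank operator; for genuinely rank-deficient designs, lineax.SVD() is the lineax solver intended for that case.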
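The trade-off comes from how one weighted least-squares step is solved. Cholesky works on the normal equations XᵀWX β = XᵀWz, which is cheap but squares the condition number of the weighted design; QR factorises W^{1/2}X directly and avoids that squaring. A sketch of both in jax.numpy and jax.scipy, illustrating the numerics only; it is not how glmax wires up its lineax solvers.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve, solve_triangular

def wls_cholesky(X, w, z):
    # Normal equations (X^T W X) beta = X^T W z, solved with a Cholesky factor.
    XtW = X.T * w
    factor = cho_factor(XtW @ X)
    return cho_solve(factor, XtW @ z)

def wls_qr(X, w, z):
    # Least squares on sqrt(W) X and sqrt(W) z via a thin QR factorisation.
    sw = jnp.sqrt(w)
    Q, R = jnp.linalg.qr(X * sw[:, None])   # reduced mode: Q is (n, p), R is (p, p)
    return solve_triangular(R, Q.T @ (sw * z), lower=False)
```

On a well-conditioned design both give the same coefficients; the difference only shows up as the columns of X approach collinearity.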
Offsets and warm-starting
An offset is a fixed term added to the linear predictor before converting \(\eta\) to the response-scale mean. The classic use case is rate modeling in Poisson regression, where the offset is the log of exposure time or population size.
import jax.numpy as jnp
# exposure: positive exposure time or population size per row, shape (n,)
# log(exposure) added to the linear predictor: log(μ) = Xβ + offset
fitted = glmax.fit(glmax.Poisson(), X, y, offset=jnp.log(exposure))
With this offset, predict(...) returns \(\mu\), the expected count for the supplied
exposure, as long as you pass the matching offset. To report rates, divide by
exposure.
Exposure is not a weight
Exposure is part of the mean model. It changes the expected count by adding
log(exposure) to the linear predictor. Observation weights would change how
rows contribute to fitting; they are intentionally not part of the core
grammar.
Warm-starting lets you seed the solver with parameters from a previous fit. This is useful when refitting the same model on updated data, or when you want to continue from a partially converged solution.
# First fit
fitted = glmax.fit(glmax.Poisson(), X, y)
# Refit on new data, starting from the previous solution
fitted2 = glmax.fit(glmax.Poisson(), X, y_new, init=fitted.params)
JAX transformations
fit, predict, infer, and check are all JIT-compiled by default and compatible with JAX transforms. See JAX Transformations for batched fitting, gradients through the fit, and other transform patterns.
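A small illustration of why pure functions over explicit arrays compose with JAX transforms. The sketch uses a closed-form Gaussian fit as a stand-in for glmax.fit so that it runs with JAX alone; fit_gaussian is a hypothetical name. jax.vmap batches the fit over several response vectors sharing one design, and jax.grad differentiates a fitted coefficient with respect to the data.

```python
import jax
import jax.numpy as jnp

def fit_gaussian(X, y):
    """Hypothetical stand-in for a Gaussian fit: solve the normal equations."""
    return jnp.linalg.solve(X.T @ X, X.T @ y)

X = jnp.array([[1.0, 0.3],
               [1.0, 1.1],
               [1.0, 2.4],
               [1.0, 3.0]])
Y = jnp.array([[0.5, 1.8, 2.3, 3.1],     # one response vector per row
               [1.0, 1.0, 1.0, 1.0]])

# Batch the fit over rows of Y while sharing the same design X.
betas = jax.vmap(fit_gaussian, in_axes=(None, 0))(X, Y)   # shape (2, 2)

# Sensitivity of the fitted slope to each observation in the first response.
dslope_dy = jax.grad(lambda y: fit_gaussian(X, y)[1])(Y[0])
```

For least squares the slope is linear in y, so its gradient is (xᵢ − x̄) / Σ(xⱼ − x̄)², independent of y.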