
glmax

glmax is a JAX-based generalized linear modeling library.

import jax.numpy as jnp
import glmax

X = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
y = jnp.array([0.0, 1.0, 1.0, 2.0])

fitted    = glmax.fit(glmax.Poisson(), X, y)
mu_hat    = glmax.predict(fitted.family, fitted.params, X)
result    = glmax.infer(fitted)
residuals = glmax.check(fitted, diagnostic=glmax.PearsonResidual())

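Four verbs (fit, predict, infer, and check) cover the full modeling workflow. Each takes explicit inputs and returns an explicit result, and no hidden state is threaded between calls. For example, predict receives the family and parameters from the fitted result as arguments rather than reading them from a stored model object.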
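To make the "explicit inputs, explicit results" pattern concrete, here is a minimal sketch in plain NumPy of the fit, predict, and check steps for a log-link Poisson model, using the same toy data. It is not glmax's implementation: `PoissonFit`, `fit_poisson`, `predict_poisson`, and `pearson_residuals` are hypothetical names, and iteratively reweighted least squares (IRLS) is assumed here as a standard GLM fitting algorithm. The infer step is left out to keep the sketch short. Each function returns a value instead of mutating an object, so every output is passed explicitly to the next step.

```python
from typing import NamedTuple

import numpy as np


class PoissonFit(NamedTuple):
    """Immutable fit result (hypothetical): later steps receive it explicitly."""
    params: np.ndarray  # coefficient vector beta
    n_iter: int         # number of IRLS iterations performed


def fit_poisson(X, y, max_iter=50, tol=1e-10):
    """Fit a log-link Poisson GLM by IRLS (a sketch, not glmax's algorithm)."""
    beta = np.zeros(X.shape[1])
    for i in range(1, max_iter + 1):
        eta = X @ beta            # linear predictor
        mu = np.exp(eta)          # inverse of the log link
        z = eta + (y - mu) / mu   # working response
        w = mu                    # working weights for Poisson with log link
        XtW = X.T * w             # X^T W, with W = diag(w)
        beta_new = np.linalg.solve(XtW @ X, XtW @ z)
        if np.max(np.abs(beta_new - beta)) < tol:
            return PoissonFit(beta_new, i)
        beta = beta_new
    return PoissonFit(beta, max_iter)


def predict_poisson(params, X):
    """Mean response exp(X @ beta); depends only on its arguments."""
    return np.exp(X @ params)


def pearson_residuals(y, mu):
    """Poisson Pearson residuals (y - mu) / sqrt(mu)."""
    return (y - mu) / np.sqrt(mu)


# Usage on the same toy data as above: each result is passed on explicitly.
X = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
y = np.array([0.0, 1.0, 1.0, 2.0])

fitted = fit_poisson(X, y)
mu_hat = predict_poisson(fitted.params, X)
residuals = pearson_residuals(y, mu_hat)
```

Because nothing is cached between calls, running `fit_poisson(X, y)` twice yields identical results, and `predict_poisson` can be applied to any parameter vector, not just one produced by a fit.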

See the "Overview" page to get started, or the "From statsmodels" page if you are migrating from statsmodels.