
Model Prediction

glmax.predict(...) applies a model specification (the family) and fitted parameters to data and returns mean predictions on the response scale. The design principle is that prediction stays explicit about both the model and the parameter carrier, rather than hiding state inside a method on a fitted object.

Prediction returns means on the response scale

predict(...) returns \(\mathrm{E}[Y \mid X]\) on the same scale as the response vector y passed to fit(...). For most families this is exactly the inverse link applied to the linear predictor, \(g^{-1}(\eta)\). Grouped-response families may apply one additional family-specific conversion on top of the inverse link.
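
As a minimal sketch of the common case, assuming the canonical links for three familiar families (identity for Gaussian, log for Poisson, logit for Bernoulli), the response-scale mean is just the inverse link applied to \(\eta\). The `INVERSE_LINKS` dictionary and its keys are hypothetical illustrations and are not part of glmax.

```python
import jax
import jax.numpy as jnp

# Hypothetical lookup of inverse links g^{-1}; not glmax's internal representation.
INVERSE_LINKS = {
    "identity": lambda eta: eta,   # Gaussian: mu = eta
    "log": jnp.exp,                # Poisson: mu = exp(eta)
    "logit": jax.nn.sigmoid,       # Bernoulli: mu = 1 / (1 + exp(-eta))
}

eta = jnp.array([-1.0, 0.0, 2.0])  # linear predictor values
means = {name: g_inv(eta) for name, g_inv in INVERSE_LINKS.items()}
```

For these single-trial families no further conversion is needed, so the inverse-link value already is \(\mathrm{E}[Y \mid X]\).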

For example, Binomial(n_trials=10) models success counts. The inverse logit gives the success probability \(p = g^{-1}(\eta)\), while predict(...) returns the expected count \(\mu = 10p\). If you want probabilities from grouped Binomial predictions, divide by n_trials.
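
The grouped Binomial conversion can be checked by hand. This is a sketch in plain jax.numpy, assuming the logit link and n_trials=10 as in the example above; the variable names are illustrative and glmax itself is not called.

```python
import jax
import jax.numpy as jnp

n_trials = 10
eta = jnp.array([-2.0, 0.0, 1.5])   # linear predictor X @ beta (+ offset)

p = jax.nn.sigmoid(eta)             # inverse logit: success probability
mu = n_trials * p                   # expected success count, the scale predict(...) reports
p_recovered = mu / n_trials         # divide by n_trials to get probabilities back
```

At \(\eta = 0\) the success probability is 0.5, so the expected count is 5 out of 10 trials.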

Exposure changes counts, not rates

For log-link count models, use offset=jnp.log(exposure). Then predict(...) returns expected counts for the supplied exposure: \(\mu = \mathrm{exposure} \cdot \exp(X\beta)\). Divide by exposure to get rates.
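
The identity \(\exp(X\beta + \log(\mathrm{exposure})) = \mathrm{exposure} \cdot \exp(X\beta)\) is what makes this work. The sketch below verifies it numerically with made-up data; `X`, `beta`, and `exposure` are hypothetical values, and the log inverse link is written out directly rather than taken from glmax.

```python
import jax.numpy as jnp

X = jnp.array([[1.0, 0.0],
               [1.0, 1.0],
               [1.0, 2.0]])         # intercept plus one covariate, shape (n, p)
beta = jnp.array([0.1, 0.3])        # hypothetical fitted coefficients, shape (p,)
exposure = jnp.array([1.0, 2.5, 4.0])

offset = jnp.log(exposure)          # offset lives on the linear-predictor scale
eta = X @ beta + offset
counts = jnp.exp(eta)               # expected counts for the supplied exposure
rates = counts / exposure           # per-unit-exposure rates
```

The rates no longer depend on exposure: they equal \(\exp(X\beta)\) row by row.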

Offset is already on the linear-predictor scale

Do not pass raw exposure as the offset. Because offset is added to \(\eta\), exposure workflows should pass log(exposure).
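
The cost of passing raw exposure is easy to see numerically: under a log link, offset=exposure multiplies the mean by \(e^{\mathrm{exposure}}\) instead of by exposure. A sketch with hypothetical numbers, computing the log-link mean directly in jax.numpy:

```python
import jax.numpy as jnp

xb = jnp.array([0.2, 0.2])          # X @ beta for two rows with identical covariates
exposure = jnp.array([1.0, 3.0])    # second row observed three times as long

right = jnp.exp(xb + jnp.log(exposure))  # scales with exposure: ratio 3
wrong = jnp.exp(xb + exposure)           # scales with exp(exposure): ratio e^2
```

Tripling the exposure should triple the expected count; the raw-exposure version inflates it by roughly 7.39 instead.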

glmax.predict(family: glmax.ExponentialDispersionFamily, params: glmax.Params, X: ArrayLike, *, offset: ArrayLike | None = None) -> jax.Array

Apply a family and its fitted parameters to new data and return predicted means.

This is the canonical predict verb of the glmax grammar. It is wrapped with @eqx.filter_jit, so it is traced and compiled on first use. Prediction computes the family-specific response-scale mean for \(\eta = X \hat{\beta} + o\), where \(X\) is the design matrix, \(\hat{\beta}\) is the fitted coefficient vector, and \(o\) is the optional offset.
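
The computation described above can be sketched as a small jitted function. This is an illustration only, assuming a log inverse link and treating a missing offset as no offset; `predict_mean` is a hypothetical name, it uses jax.jit rather than eqx.filter_jit to avoid depending on Equinox, and it is not glmax's implementation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def predict_mean(beta, X, offset=None):
    """Hypothetical sketch: response-scale mean under a log link."""
    eta = X @ beta                   # linear predictor X beta_hat, shape (n,)
    if offset is not None:           # None is an empty pytree, so this branch resolves at trace time
        eta = eta + offset           # add o on the linear-predictor scale
    return jnp.exp(eta)              # inverse link g^{-1}(eta) for a log link

X = jnp.ones((4, 3))                 # n = 4 rows, p = 3 covariates
beta = jnp.array([0.5, -0.25, 0.0])
mu = predict_mean(beta, X)                                  # shape (n,) = (4,)
mu_exposed = predict_mean(beta, X, jnp.log(jnp.full(4, 2.0)))  # doubled exposure
```

Each new combination of input shapes, or of offset versus no offset, triggers a fresh trace; repeated calls with the same structure reuse the compiled function.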

Arguments:

  • family: glmax.ExponentialDispersionFamily instance.
  • params: fitted glmax.Params (for example fitted.params from glmax.fit).
  • X: covariate matrix, shape (n, p).
  • offset: optional offset vector added to the linear predictor. Use the same convention as fit(...); for log-link count exposure, pass log(exposure).

Returns:

Predicted mean response vector, shape (n,), on the same scale as y.

Raises:

  • TypeError: if family, params, or X have the wrong types.