Stochastic Trace Estimators¤

Given a square linear operator \(\mathbf{A}\), the trace of \(\mathbf{A}\) is defined as

\[\text{trace}(\mathbf{A}) = \sum_i \mathbf{A}_{ii}.\]

When \(\mathbf{A}\) is represented in memory as an \(n \times n\) matrix, computing the trace is straightforward: summing along the diagonal requires only \(O(n)\) time. In practice, however, \(\mathbf{A}\) is often the result of many operations, and explicitly computing and storing \(\mathbf{A}\) may be prohibitively expensive.

Given this, we may instead represent \(\mathbf{A}\) as a linear operator, a lazy representation of \(\mathbf{A}\) that tracks only the underlying operations needed to calculate its final result. Matrix-vector products between \(\mathbf{A}\) and a vector \(\omega\) can then be obtained by lazily evaluating the chain of composed operations as a sequence of intermediate matrix-vector products (e.g., with lineax).
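As a rough illustration of the lazy representation described above, the sketch below uses a plain JAX closure rather than lineax itself. The helper name make_lazy_matvec and the particular composition \(\mathbf{B}\mathbf{C} + \mathbf{D}\) are assumptions made for the example only.

```python
import jax
import jax.numpy as jnp


def make_lazy_matvec(B, C, D):
    """Return v -> (B @ C + D) @ v without ever forming B @ C + D."""

    def matvec(v):
        # Evaluate right-to-left so only matrix-vector products are needed.
        return B @ (C @ v) + D @ v

    return matvec


key_b, key_c, key_d, key_v = jax.random.split(jax.random.PRNGKey(0), 4)
n, r = 6, 2
B = jax.random.normal(key_b, (n, r))
C = jax.random.normal(key_c, (r, n))
D = jnp.diag(jax.random.normal(key_d, (n,)))
v = jax.random.normal(key_v, (n,))

matvec = make_lazy_matvec(B, C, D)
lazy_result = matvec(v)
dense_result = (B @ C + D) @ v  # explicit matrix, built here for comparison only
```

The two results agree up to floating-point error, but the lazy version never materializes the \(n \times n\) matrix.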

There is a rich history of stochastic trace estimation for matrices, in which the trace of \(\mathbf{A}\) is estimated by averaging quantities computed from multiple matrix-vector products. To see this, observe that

\[\mathbb{E}[\omega^T \mathbf{A} \omega] = \text{trace}(\mathbf{A}),\]

where \(\mathbb{E}[\omega] = 0\) and \(\mathbb{E}[\omega \omega^T] = \mathbf{I}\). Averaging \(\omega^T \mathbf{A} \omega\) over independent draws of \(\omega\) is known as the Girard-Hutchinson estimator. There have been multiple advancements in stochastic trace estimation since. Here, traceax aims to provide an easy-to-use API for stochastic trace estimation that leverages the flexibility of lineax linear operators together with differentiable and performant JAX-based numerics.
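The identity above can be sketched directly in plain JAX. This is a minimal illustration under stated assumptions, not traceax's implementation: the function name hutchinson_trace, the matvec callable, and the "std_err" dictionary key are all hypothetical.

```python
import jax
import jax.numpy as jnp


def hutchinson_trace(key, matvec, n, k):
    """Average k quadratic forms w^T A w using Rademacher test vectors."""
    # Columns of omega are i.i.d. with E[w] = 0 and E[w w^T] = I.
    omega = jax.random.rademacher(key, (n, k), dtype=jnp.float32)
    # Apply the operator to every test vector (k matrix-vector products).
    a_omega = jax.vmap(matvec, in_axes=1, out_axes=1)(omega)
    samples = jnp.sum(omega * a_omega, axis=0)  # w_i^T A w_i for each column i
    estimate = jnp.mean(samples)
    std_err = jnp.std(samples, ddof=1) / jnp.sqrt(k)
    return estimate, {"std_err": std_err}


# Small dense matrix for illustration; its exact trace is 15.5.
A = jnp.diag(jnp.arange(1.0, 6.0)) + 0.1
estimate, stats = hutchinson_trace(
    jax.random.PRNGKey(0), lambda v: A @ v, n=5, k=200
)
```

With Rademacher test vectors the diagonal of \(\mathbf{A}\) contributes exactly, so all of the estimator's variance comes from the off-diagonal entries.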

traceax.AbstractTraceEstimator ¤

Abstract base class for all trace estimators.

estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]] abstractmethod ¤

Estimate the trace of operator.

Example

import jax
import lineax as lx
import traceax as tx

key = jax.random.PRNGKey(...)
operator = lx.MatrixLinearOperator(...)
hutch = tx.HutchinsonEstimator()
result = hutch.estimate(key, operator, k=10)
#  or
result = hutch(key, operator, k=10)

Arguments:

  • key: the PRNG key used as the random key for sampling.
  • operator: the (square) linear operator for which the trace is to be estimated.
  • k: the number of matrix-vector operations to perform for trace estimation.

Returns:

A two-tuple of:

  • The trace estimate.
  • A dictionary of any extra statistics beyond the trace estimate, e.g., the standard error.

__call__(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]] ¤

An alias for estimate.

traceax.HutchinsonEstimator(AbstractTraceEstimator) ¤

Girard-Hutchinson Trace Estimator:

\(\mathbb{E}[\omega^T \mathbf{A} \omega] = \text{trace}(\mathbf{A})\), where \(\mathbb{E}[\omega] = 0\) and \(\mathbb{E}[\omega \omega^T] = \mathbf{I}\).
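To make the two moment conditions concrete, the sketch below draws test vectors from three common distributions and checks \(\mathbb{E}[\omega \omega^T] \approx \mathbf{I}\) by Monte Carlo. The function names are hypothetical helpers written for this example and are not the traceax sampler classes.

```python
import jax
import jax.numpy as jnp


def sample_rademacher(key, n, k):
    """Entries are +1 or -1 with equal probability."""
    return jax.random.rademacher(key, (n, k), dtype=jnp.float32)


def sample_normal(key, n, k):
    """Entries are i.i.d. standard normal."""
    return jax.random.normal(key, (n, k))


def sample_sphere(key, n, k):
    """Columns are uniform on the sphere of radius sqrt(n)."""
    g = jax.random.normal(key, (n, k))
    return jnp.sqrt(n) * g / jnp.linalg.norm(g, axis=0, keepdims=True)


key = jax.random.PRNGKey(0)
n, k = 4, 20_000
for sampler in (sample_rademacher, sample_normal, sample_sphere):
    omega = sampler(key, n, k)
    second_moment = omega @ omega.T / k  # Monte Carlo estimate of E[w w^T]
    print(sampler.__name__, jnp.round(second_moment, 2))
```

Each printed matrix should be close to the identity. The sphere samples additionally have squared norm exactly \(n\), which removes one source of variance.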

__init__(sampler: AbstractSampler = RademacherSampler()) ¤

Arguments:

  • sampler: the sampling distribution for \(\omega\). Default is traceax.RademacherSampler.

traceax.HutchPlusPlusEstimator(AbstractTraceEstimator) ¤

Hutch++ Trace Estimator:

Let \(\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}\) be a low-rank approximation to \(\mathbf{A}\), where \(\mathbf{Q}\) is an orthonormal basis for the range of \(\mathbf{A} \Omega\), for \(\Omega = [\omega_1, \dotsc, \omega_k]\).

Hutch++ improves upon the Girard-Hutchinson estimator by including the trace of the residuals. Namely, Hutch++ estimates \(\text{trace}(\mathbf{A})\) as \(\text{trace}(\hat{\mathbf{A}}) + \text{trace}(\mathbf{A} - \hat{\mathbf{A}})\), where the first term is computed exactly and the residual term is estimated stochastically.

As with the Girard-Hutchinson estimator, it requires \(\mathbb{E}[\omega] = 0\) and \(\mathbb{E}[\omega \omega^T] = \mathbf{I}\).
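The decomposition above can be sketched in plain JAX as follows. This is an illustrative sketch assuming Gaussian test vectors and an even three-way split of the matrix-vector budget; the name hutchpp_trace and the matvec callable are hypothetical, and traceax's own implementation may differ.

```python
import jax
import jax.numpy as jnp


def hutchpp_trace(key, matvec, n, k):
    """Hutch++ sketch: exact trace of a low-rank part plus a stochastic residual."""
    m = k // 3  # split the budget of k matrix-vector products into thirds
    key_sketch, key_resid = jax.random.split(key)
    apply = jax.vmap(matvec, in_axes=1, out_axes=1)  # A applied column-wise

    # 1) Sketch the range of A and orthonormalize: Q spans A @ Omega.
    omega = jax.random.normal(key_sketch, (n, m))
    q, _ = jnp.linalg.qr(apply(omega))

    # 2) trace(Q Q^T A) = trace(Q^T A Q), computed exactly.
    trace_lowrank = jnp.trace(q.T @ apply(q))

    # 3) Girard-Hutchinson on the residual, using test vectors projected
    #    onto the orthogonal complement of Q.
    g = jax.random.normal(key_resid, (n, m))
    g_perp = g - q @ (q.T @ g)
    trace_resid = jnp.sum(g_perp * apply(g_perp)) / m

    return trace_lowrank + trace_resid


key_u, key_est = jax.random.split(jax.random.PRNGKey(0))
u = jax.random.normal(key_u, (50, 4))
A = u @ u.T  # rank-4 matrix, so a sketch with m >= 4 captures its range
estimate = hutchpp_trace(key_est, lambda v: A @ v, n=50, k=12)
```

When the rank of \(\mathbf{A}\) is at most \(k/3\), the residual vanishes and the estimate matches the trace up to floating-point error.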

__init__(sampler: AbstractSampler = RademacherSampler()) ¤

Arguments:

  • sampler: the sampling distribution for \(\omega\). Default is traceax.RademacherSampler.

traceax.XTraceEstimator(AbstractTraceEstimator) ¤

XTrace Trace Estimator:

Let \(\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}\) be the low-rank approximation to \(\mathbf{A}\), where \(\mathbf{Q}\) is an orthonormal basis for the range of \(\mathbf{A} \Omega\), for \(\Omega = [\omega_1, \dotsc, \omega_k]\).

XTrace improves upon the Hutch++ estimator by enforcing exchangeability of the sampled test vectors, which yields a symmetric estimation function with lower variance.

Additionally, the improved XTrace algorithm (i.e., improved = True) ensures that test vectors are orthogonalized against the low-rank approximation \(\mathbf{Q}\mathbf{Q}^* \mathbf{A}\) and renormalized. This improved XTrace approach may provide better empirical results than the non-orthogonalized version.

As with the Girard-Hutchinson estimator, it requires \(\mathbb{E}[\omega] = 0\) and \(\mathbb{E}[\omega \omega^T] = \mathbf{I}\).
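One way to picture the exchangeability idea is a naive leave-one-out loop: each test vector takes a turn as the held-out residual probe while the others build the basis, and the results are averaged. The sketch below assumes Gaussian test vectors and recomputes a QR factorization per held-out vector, so it spends more matrix-vector products than an efficient XTrace implementation would. The name xtrace_naive is hypothetical, and the improved (renormalized) variant is not shown.

```python
import jax
import jax.numpy as jnp


def xtrace_naive(key, matvec, n, k):
    """Naive leave-one-out XTrace sketch (recomputes a QR per test vector)."""
    m = k // 2
    apply = jax.vmap(matvec, in_axes=1, out_axes=1)  # A applied column-wise
    omega = jax.random.normal(key, (n, m))
    y = apply(omega)  # A @ Omega

    estimates = []
    for i in range(m):
        # Basis built from every test vector except omega_i.
        q, _ = jnp.linalg.qr(jnp.delete(y, i, axis=1))
        # omega_i is independent of q, so it can probe the residual.
        w = omega[:, i]
        w_perp = w - q @ (q.T @ w)
        trace_lowrank = jnp.trace(q.T @ apply(q))
        trace_resid = w_perp @ matvec(w_perp)
        estimates.append(trace_lowrank + trace_resid)

    # Averaging over all held-out choices makes the estimate symmetric
    # in the test vectors.
    return jnp.mean(jnp.stack(estimates))


key_u, key_est = jax.random.split(jax.random.PRNGKey(0))
u = jax.random.normal(key_u, (40, 3))
A = u @ u.T + 0.01 * jnp.eye(40)
estimate = xtrace_naive(key_est, lambda v: A @ v, n=40, k=10)
```

Because no test vector is used both to build a basis and to probe its residual, each leave-one-out term stays unbiased while every vector contributes to both roles across the average.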

__init__(sampler: AbstractSampler = SphereSampler(), improved: bool = True) ¤

Arguments:

  • sampler: the sampling distribution for \(\omega\). Default is traceax.SphereSampler.
  • improved: whether to use the improved XTrace estimator, which rescales predicted samples. Default is True (see Notes).

traceax.XNysTraceEstimator(AbstractTraceEstimator) ¤

XNysTrace Trace Estimator:

XNysTrace improves upon the XTrace estimator when \(\mathbf{A}\) is positive- or negative-semidefinite, by performing a Nyström approximation rather than a randomized SVD (i.e., random projection followed by QR decomposition).

As with traceax.XTraceEstimator, the improved XNysTrace algorithm (i.e., improved = True) ensures that test vectors are orthogonalized against the low-rank approximation and renormalized. This improved XNysTrace approach may provide better empirical results than the non-orthogonalized version.

As with the Girard-Hutchinson estimator, it requires \(\mathbb{E}[\omega] = 0\) and \(\mathbb{E}[\omega \omega^T] = \mathbf{I}\).
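A naive leave-one-out version of the Nyström idea can be sketched as follows, assuming a symmetric positive-semidefinite operator and Gaussian test vectors. The Nyström approximation used here is \(\hat{\mathbf{A}} = \mathbf{Y}(\Omega^T \mathbf{Y})^{+}\mathbf{Y}^T\) with \(\mathbf{Y} = \mathbf{A}\Omega\). The name xnystrace_naive is hypothetical; a production implementation would typically add numerical stabilization and the renormalization step, neither of which is shown.

```python
import jax
import jax.numpy as jnp


def xnystrace_naive(key, matvec, n, k):
    """Naive leave-one-out XNysTrace sketch for a symmetric PSD operator."""
    apply = jax.vmap(matvec, in_axes=1, out_axes=1)  # A applied column-wise
    omega = jax.random.normal(key, (n, k))
    y = apply(omega)  # A @ Omega: the only k matrix-vector products used

    estimates = []
    for i in range(k):
        omega_i = jnp.delete(omega, i, axis=1)
        y_i = jnp.delete(y, i, axis=1)
        # Nystrom approximation: A_hat = Y_i (Omega_i^T Y_i)^+ Y_i^T.
        core_pinv = jnp.linalg.pinv(omega_i.T @ y_i)
        trace_nys = jnp.trace(core_pinv @ (y_i.T @ y_i))
        # Residual w^T (A - A_hat) w with the held-out vector w = omega[:, i].
        w, aw = omega[:, i], y[:, i]
        z = y_i.T @ w  # equals Omega_i^T A w because A is symmetric
        trace_resid = w @ aw - z @ core_pinv @ z
        estimates.append(trace_nys + trace_resid)

    return jnp.mean(jnp.stack(estimates))


key_u, key_est = jax.random.split(jax.random.PRNGKey(0))
u = jax.random.normal(key_u, (40, 3))
A = u @ u.T + 0.1 * jnp.eye(40)  # symmetric positive-definite test matrix
estimate = xnystrace_naive(key_est, lambda v: A @ v, n=40, k=6)
```

Unlike the QR-based sketch, the Nyström form reuses \(\mathbf{Y} = \mathbf{A}\Omega\) for both the approximation and the residual, so all \(k\) matrix-vector products go toward the low-rank approximation.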

__init__(sampler: AbstractSampler = SphereSampler(), improved: bool = True) ¤

Arguments:

  • sampler: the sampling distribution for \(\omega\). Default is traceax.SphereSampler.
  • improved: whether to use the improved XNysTrace estimator, which rescales predicted samples. Default is True (see Notes).