Stochastic Trace Estimators
Given a square linear operator \(\mathbf{A}\), the trace of \(\mathbf{A}\) is defined as
\[\text{trace}(\mathbf{A}) = \sum_{i} \mathbf{A}_{ii}.\]
When \(\mathbf{A}\) is represented in memory as an \(n \times n\) matrix, computing the trace is straightforward, requiring only \(O(n)\) time to sum along the diagonal. However, in practice, \(\mathbf{A}\) may be the result of many operations, and explicitly computing and representing \(\mathbf{A}\) may be prohibitive.
Given this, we may represent \(\mathbf{A}\) as a linear operator (e.g., with lineax), which can be viewed as a lazy representation of \(\mathbf{A}\) that only tracks the underlying operations needed to calculate its final result. As such, matrix-vector products between \(\mathbf{A}\) and a vector \(\omega\) can be obtained by lazily evaluating the chain of underlying operations through intermediate matrix-vector products.
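As a rough illustration of the lazy idea, the sketch below evaluates \(\mathbf{A}\omega\) for \(\mathbf{A} = \mathbf{B}\mathbf{C}^T + \text{diag}(d)\) without ever forming the \(n \times n\) matrix. It uses plain JAX closures rather than lineax operators; the helper name `make_lazy_matvec` and the particular structure of \(\mathbf{A}\) are assumptions made for this example only.

```python
import jax
import jax.numpy as jnp


def make_lazy_matvec(B, C, d):
    """Return v -> (B @ C.T + diag(d)) @ v without building the n x n matrix.

    Hypothetical helper for illustration; not part of lineax or traceax.
    """

    def matvec(v):
        # Evaluate right-to-left: C.T @ v has length r, so the cost is O(n r).
        return B @ (C.T @ v) + d * v

    return matvec


key_b, key_c, key_d, key_v = jax.random.split(jax.random.PRNGKey(0), 4)
n, r = 200, 5
B = jax.random.normal(key_b, (n, r))
C = jax.random.normal(key_c, (n, r))
d = jax.random.normal(key_d, (n,))
omega = jax.random.normal(key_v, (n,))

lazy_result = make_lazy_matvec(B, C, d)(omega)
# The explicit matrix is built here only to check the lazy result.
dense_result = (B @ C.T + jnp.diag(d)) @ omega
```

Only the factors `B`, `C`, and `d` are stored, so memory scales with \(n r\) rather than \(n^2\).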
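There is a rich history of stochastic trace estimation for matrices, where one can estimate the trace of \(\mathbf{A}\) using multiple matrix-vector products followed by averaging. To see this, observe that
\[\text{trace}(\mathbf{A}) = \text{trace}(\mathbf{A}\mathbf{I}) = \text{trace}(\mathbf{A}\,\mathbb{E}[\omega \omega^T]) = \mathbb{E}[\text{trace}(\mathbf{A} \omega \omega^T)] = \mathbb{E}[\omega^T \mathbf{A} \omega],\]
where \(\mathbb{E}[\omega] = 0\) and \(\mathbb{E}[\omega \omega^T] = \mathbf{I}\).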
The above is known as the Girard-Hutchinson estimator. There have been multiple advancements in stochastic trace estimation. Here, traceax aims to provide an easy-to-use API for stochastic trace estimation that leverages the flexibility of lineax linear operators together with differentiable and performant JAX-based numerics.
traceax.AbstractTraceEstimator
Abstract base class for all trace estimators.
estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]
abstractmethod
Estimate the trace of operator.
Example
key = jax.random.PRNGKey(...)
operator = lx.MatrixLinearOperator(...)
hutch = tx.HutchinsonEstimator()
result = hutch.estimate(key, operator, k=10)
# or
result = hutch(key, operator, k=10)
Arguments:
key
: the PRNG key used as the random key for sampling.
operator
: the (square) linear operator for which the trace is to be estimated.
k
: the number of matrix-vector operations to perform for trace estimation.
Returns:
A two-tuple of:
- The trace estimate.
- A dictionary of any extra statistics beyond the trace estimate, e.g., the standard error.
__call__(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]
An alias for estimate.
traceax.HutchinsonEstimator(AbstractTraceEstimator)
Girard-Hutchinson Trace Estimator:
\(\mathbb{E}[\omega^T \mathbf{A} \omega] = \text{trace}(\mathbf{A})\), where \(\mathbb{E}[\omega] = 0\) and \(\mathbb{E}[\omega \omega^T] = \mathbf{I}\).
__init__(sampler: AbstractSampler = RademacherSampler())
Arguments:
sampler
: The sampling distribution for \(\omega\). Default is traceax.RademacherSampler.
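The Girard-Hutchinson identity can be sketched in a few lines of plain JAX. The function below is a minimal illustration, not the traceax implementation: the name `hutchinson_trace`, the `matvec` callable, and the `"std_err"` dictionary key are assumptions chosen to mirror the (estimate, statistics) return shape described above.

```python
import jax
import jax.numpy as jnp


def hutchinson_trace(key, matvec, n, k):
    """Girard-Hutchinson estimate of trace(A) from k matrix-vector products.

    Hypothetical helper for illustration; `matvec` maps a length-n vector v to A @ v.
    """
    # Rademacher test vectors: entries are +1 or -1 with equal probability,
    # so E[w] = 0 and E[w w^T] = I.
    omegas = jax.random.rademacher(key, (k, n), dtype=jnp.float32)
    # One quadratic form w^T A w per test vector.
    samples = jax.vmap(lambda w: w @ matvec(w))(omegas)
    estimate = jnp.mean(samples)
    # Standard error of the mean across the k samples.
    std_err = jnp.std(samples, ddof=1) / jnp.sqrt(k)
    return estimate, {"std_err": std_err}  # key name is illustrative only


key_a, key_w = jax.random.split(jax.random.PRNGKey(0))
n = 100
G = jax.random.normal(key_a, (n, n))
A = G @ G.T / n  # a dense symmetric test matrix, used only for this demo
estimate, stats = hutchinson_trace(key_w, lambda v: A @ v, n, k=50)
```

With Rademacher test vectors the diagonal of \(\mathbf{A}\) contributes no variance, since \(\omega_i^2 = 1\) for every entry; all of the estimator's noise comes from the off-diagonal entries.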
traceax.HutchPlusPlusEstimator(AbstractTraceEstimator)
Hutch++ Trace Estimator:
Let \(\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}\) be a low-rank approximation to \(\mathbf{A}\), where the columns of \(\mathbf{Q}\) form an orthonormal basis for the range of \(\mathbf{A} \Omega\), for \(\Omega = [\omega_1, \dotsc, \omega_k]\).
Hutch++ improves upon the Girard-Hutchinson estimator by computing the trace of the low-rank approximation directly and applying stochastic estimation only to the residual. Namely, Hutch++ estimates \(\text{trace}(\mathbf{A})\) as \(\text{trace}(\hat{\mathbf{A}}) + \text{trace}(\mathbf{A} - \hat{\mathbf{A}})\), where the residual term \(\text{trace}(\mathbf{A} - \hat{\mathbf{A}})\) is itself estimated stochastically.
As with the Girard-Hutchinson estimator, it requires \(\mathbb{E}[\omega] = 0\) and \(\mathbb{E}[\omega \omega^T] = \mathbf{I}\).
__init__(sampler: AbstractSampler = RademacherSampler())
Arguments:
sampler
: The sampling distribution for \(\omega\). Default is traceax.RademacherSampler.
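The decomposition \(\text{trace}(\mathbf{A}) = \text{trace}(\hat{\mathbf{A}}) + \text{trace}(\mathbf{A} - \hat{\mathbf{A}})\) can be sketched as follows. This is an illustrative version working on a dense matrix, not the traceax implementation. The function name `hutchpp_trace` and the even three-way split of the matrix-vector budget `k` (sketching, projecting, residual probing) are assumptions; traceax may allocate the budget differently.

```python
import jax
import jax.numpy as jnp


def hutchpp_trace(key, A, k):
    """Hutch++-style estimate of trace(A) using roughly k matrix-vector products.

    Hypothetical helper for illustration; A is a dense (n, n) array here.
    """
    n = A.shape[0]
    m = k // 3  # assumed split: m products each for sketch, projection, residual
    key_s, key_g = jax.random.split(key)
    S = jax.random.rademacher(key_s, (n, m), dtype=A.dtype)
    G = jax.random.rademacher(key_g, (n, m), dtype=A.dtype)

    # Orthonormal basis Q for the range of A @ S.
    Q, _ = jnp.linalg.qr(A @ S)
    # trace(Q Q^T A) = trace(Q^T A Q), computed directly.
    trace_lowrank = jnp.trace(Q.T @ (A @ Q))

    # Project the probes off the range of Q, then apply Girard-Hutchinson
    # to the residual (I - Q Q^T) A (I - Q Q^T).
    G_perp = G - Q @ (Q.T @ G)
    trace_residual = jnp.trace(G_perp.T @ (A @ G_perp)) / m
    return trace_lowrank + trace_residual


key_a, key_est = jax.random.split(jax.random.PRNGKey(0))
n = 60
U, _ = jnp.linalg.qr(jax.random.normal(key_a, (n, n)))
spectrum = 1.0 / jnp.arange(1, n + 1) ** 2  # quickly decaying eigenvalues
A = (U * spectrum) @ U.T
estimate = hutchpp_trace(key_est, A, k=30)
```

When the spectrum of \(\mathbf{A}\) decays quickly, most of the trace is captured by the low-rank term, leaving only a small residual for the stochastic part.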
traceax.XTraceEstimator(AbstractTraceEstimator)
XTrace Trace Estimator:
Let \(\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}\) be the low-rank approximation to \(\mathbf{A}\), where the columns of \(\mathbf{Q}\) form an orthonormal basis for the range of \(\mathbf{A} \Omega\), for \(\Omega = [\omega_1, \dotsc, \omega_k]\).
XTrace improves upon the Hutch++ estimator by enforcing exchangeability of the sampled test vectors, which yields a symmetric estimation function with lower variance.
Additionally, the improved XTrace algorithm (i.e., improved = True) ensures that test vectors are orthogonalized against the low-rank approximation \(\mathbf{Q}\mathbf{Q}^* \mathbf{A}\) and renormalized. This improved XTrace approach may provide better empirical results than the non-orthogonalized version.
As with the Girard-Hutchinson estimator, it requires \(\mathbb{E}[\omega] = 0\) and \(\mathbb{E}[\omega \omega^T] = \mathbf{I}\).
__init__(sampler: AbstractSampler = SphereSampler(), improved: bool = True)
Arguments:
sampler
: the sampling distribution for \(\omega\). Default is traceax.SphereSampler.
improved
: whether to use the improved XTrace estimator, which rescales predicted samples. Default is True (see Notes).
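One way to read the exchangeability idea is as a leave-one-out average: each test vector takes a turn probing the residual while the remaining vectors build the basis \(\mathbf{Q}\). The sketch below follows that reading in its most naive form, recomputing a QR factorization for every left-out vector. It is not the traceax implementation, which presumably avoids this repeated work. The function name `xtrace_naive` is hypothetical, Gaussian test vectors stand in for the default sphere sampler, and the improved (orthogonalize and renormalize) variant is not shown.

```python
import jax
import jax.numpy as jnp


def xtrace_naive(key, A, num_vectors):
    """Naive leave-one-out trace estimate in the spirit of XTrace.

    Hypothetical helper for illustration; A is a dense (n, n) array here.
    """
    n = A.shape[0]
    # Gaussian test vectors satisfy E[w] = 0 and E[w w^T] = I (an assumption
    # of this sketch; the documented default is a sphere sampler).
    Omega = jax.random.normal(key, (n, num_vectors), dtype=A.dtype)
    Y = A @ Omega

    estimates = []
    for i in range(num_vectors):
        # Basis built from every sketch column except the i-th.
        Y_rest = jnp.concatenate([Y[:, :i], Y[:, i + 1:]], axis=1)
        Q, _ = jnp.linalg.qr(Y_rest)
        trace_lowrank = jnp.trace(Q.T @ (A @ Q))
        # The left-out vector probes the residual (I - Q Q^T) A (I - Q Q^T).
        w = Omega[:, i]
        w_perp = w - Q @ (Q.T @ w)
        estimates.append(trace_lowrank + w_perp @ (A @ w_perp))

    # Averaging over all choices of i makes the estimate symmetric
    # in the test vectors.
    return jnp.mean(jnp.stack(estimates))


key_b, key_w = jax.random.split(jax.random.PRNGKey(0))
B = jax.random.normal(key_b, (30, 3))
A = B @ B.T  # rank-3 test matrix, used only for this demo
estimate = xtrace_naive(key_w, A, num_vectors=5)
```

Because each of the leave-one-out bases here has more columns than the rank of the demo matrix, the low-rank term already captures essentially all of the trace and the residual probes contribute almost nothing.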
traceax.XNysTraceEstimator(AbstractTraceEstimator)
XNysTrace Trace Estimator:
XNysTrace improves upon the XTrace estimator when \(\mathbf{A}\) is positive-semidefinite (or negative-semidefinite) by performing a Nyström approximation rather than a randomized SVD (i.e., random projection followed by QR decomposition).
Like traceax.XTraceEstimator, the improved XNysTrace algorithm (i.e., improved = True) ensures that test vectors are orthogonalized against the low-rank approximation and renormalized.
This improved XNysTrace approach may provide better empirical results than the non-orthogonalized version.
As with the Girard-Hutchinson estimator, it requires \(\mathbb{E}[\omega] = 0\) and \(\mathbb{E}[\omega \omega^T] = \mathbf{I}\).
__init__(sampler: AbstractSampler = SphereSampler(), improved: bool = True)
Arguments:
sampler
: the sampling distribution for \(\omega\). Default is traceax.SphereSampler.
improved
: whether to use the improved XNysTrace estimator, which rescales predicted samples. Default is True (see Notes).
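For a positive-semidefinite \(\mathbf{A}\), the Nyström approximation built from test vectors \(\Omega\) is \(\mathbf{A}\Omega (\Omega^T \mathbf{A} \Omega)^{+} (\mathbf{A}\Omega)^T\). The sketch below combines it with the same naive leave-one-out averaging as a rough illustration. It is not the traceax implementation: the function name `xnystrace_naive`, the Gaussian test vectors, the use of a pseudo-inverse, and the choice to enable double precision are all assumptions of this example, and the improved variant is not shown.

```python
import jax

# Double precision keeps the small pseudo-inverse well behaved in this demo.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def xnystrace_naive(key, A, num_vectors):
    """Naive leave-one-out Nystrom trace estimate for a PSD matrix A.

    Hypothetical helper for illustration; A is a dense (n, n) array here.
    """
    n = A.shape[0]
    Omega = jax.random.normal(key, (n, num_vectors), dtype=A.dtype)
    Y = A @ Omega  # the only products with A; reused for every left-out vector

    estimates = []
    for i in range(num_vectors):
        Omega_rest = jnp.concatenate([Omega[:, :i], Omega[:, i + 1:]], axis=1)
        Y_rest = jnp.concatenate([Y[:, :i], Y[:, i + 1:]], axis=1)
        # Nystrom approximation: A_hat = Y_rest @ core_pinv @ Y_rest.T
        core_pinv = jnp.linalg.pinv(Omega_rest.T @ Y_rest)
        trace_nystrom = jnp.trace(core_pinv @ (Y_rest.T @ Y_rest))
        # Residual quadratic form w^T (A - A_hat) w for the left-out vector.
        w = Omega[:, i]
        z = Y_rest.T @ w
        residual = w @ Y[:, i] - z @ core_pinv @ z
        estimates.append(trace_nystrom + residual)

    return jnp.mean(jnp.stack(estimates))


key_b, key_w = jax.random.split(jax.random.PRNGKey(0))
B = jax.random.normal(key_b, (30, 4))
A = B @ B.T  # rank-4 PSD test matrix, used only for this demo
estimate = xnystrace_naive(key_w, A, num_vectors=5)
```

Unlike the QR-based sketch, the Nyström form needs no second round of products with \(\mathbf{A}\): both the low-rank trace and the residual are computed from \(\Omega\) and \(\mathbf{A}\Omega\) alone.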