

Traceax

traceax is a Python library for stochastic trace estimation of linear operators. Given a square linear operator \(\mathbf{A}\), traceax provides flexible routines that estimate

\[\text{trace}(\mathbf{A}) = \sum_i \mathbf{A}_{ii},\]

using only matrix-vector products. traceax is heavily inspired by lineax as well as XTrace.
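To make "using only matrix-vector products" concrete, here is a minimal sketch of the classical Hutchinson estimator in plain JAX. It relies on the identity \(\mathbb{E}[\mathbf{z}^\top \mathbf{A} \mathbf{z}] = \text{trace}(\mathbf{A})\) for random probes \(\mathbf{z}\) with \(\mathbb{E}[\mathbf{z}\mathbf{z}^\top] = \mathbf{I}\). This is an illustration only, not traceax's implementation; the function name `hutchinson_trace` and its arguments are assumptions made for this example.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def hutchinson_trace(key, matvec, n, k):
    """Estimate trace(A) from k matrix-vector products v -> A @ v."""
    # Rademacher probes: i.i.d. entries in {-1, +1}, one probe per column
    Z = rdm.rademacher(key, (n, k), dtype=jnp.float32)
    # apply the operator to every probe; A itself is never inspected
    AZ = jax.vmap(matvec, in_axes=1, out_axes=1)(Z)
    # average the quadratic forms z^T A z over the k probes
    return jnp.mean(jnp.sum(Z * AZ, axis=0))


# small symmetric test matrix with decaying eigenvalues
n = 200
qkey, ekey = rdm.split(rdm.PRNGKey(0))
Q, _ = jnp.linalg.qr(rdm.normal(qkey, (n, n)))
A = (Q * jnp.power(0.7, jnp.arange(n))) @ Q.T

exact = jnp.trace(A)
estimate = hutchinson_trace(ekey, lambda v: A @ v, n, k=200)
print(exact, estimate)  # the two values should be close, not identical
```

Because `matvec` is the only access to the operator, the same function works when \(\mathbf{A}\) is available only implicitly, for example as a product of other operators.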

Installation | Example | Notes | Support | Other Software


Installation

Clone the latest version of the repository and install it with pip:

git clone https://github.com/mancusolab/traceax.git
cd traceax
pip install .

Get Started with an Example

import jax.numpy as jnp
import jax.random as rdm
import lineax as lx

import traceax as tx

# simulate simple symmetric matrix with exponential eigenvalue decay
seed = 0
N = 1000
key = rdm.PRNGKey(seed)
key, xkey = rdm.split(key)

X = rdm.normal(xkey, (N, N))
Q, R = jnp.linalg.qr(X)
U = jnp.power(0.7, jnp.arange(N))
A = (Q * U) @ Q.T

# should be numerically close
print(jnp.trace(A))  # 3.3333323
print(jnp.sum(U))  # 3.3333335

# setup linear operator
operator = lx.MatrixLinearOperator(A)

# number of matrix-vector products
k = 25

# split key for estimators
key, key1, key2, key3, key4 = rdm.split(key, 5)

# Hutchinson estimator; default samples Rademacher {-1,+1}
hutch = tx.HutchinsonEstimator()
print(hutch.estimate(key1, operator, k))  # (Array(3.6007538, dtype=float32), {})

# Hutch++ estimator; default samples Rademacher {-1,+1}
hpp = tx.HutchPlusPlusEstimator()
print(hpp.estimate(key2, operator, k))  # (Array(3.4094956, dtype=float32), {})

# XTrace estimator; default samples uniformly on n-Sphere
xt = tx.XTraceEstimator()
print(xt.estimate(key3, operator, k))  # (Array(3.3030486, dtype=float32), {'std.err': Array(0.01238528, dtype=float32)})

# XNysTrace estimator; improved performance for NSD/PSD trace estimates
operator = lx.TaggedLinearOperator(operator, lx.positive_semidefinite_tag)
nt = tx.XNysTraceEstimator()
print(nt.estimate(key4, operator, k))  # (Array(3.3314352, dtype=float32), {'std.err': Array(0.0006521, dtype=float32)})
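For intuition on why the estimators above differ in accuracy at the same budget, the following sketches the Hutch++ idea in plain JAX: compute the trace exactly on a low-dimensional sketch of the range of the matrix, then apply a Hutchinson-style estimate only to what remains. It is a simplified illustration under the assumption of a dense symmetric matrix, not traceax's implementation; `hutchpp_trace` is a hypothetical name.

```python
import jax.numpy as jnp
import jax.random as rdm


def hutchpp_trace(key, A, k):
    """Hutch++-style trace estimate of a dense matrix A using about k matvecs."""
    n = A.shape[0]
    m = k // 3  # k is split evenly across the three stages below
    skey, gkey = rdm.split(key)
    S = rdm.rademacher(skey, (n, m), dtype=A.dtype)
    G = rdm.rademacher(gkey, (n, m), dtype=A.dtype)
    # stage 1: orthonormal basis Q for a sketch of the range of A
    Q, _ = jnp.linalg.qr(A @ S)
    # stage 2: exact trace of A restricted to span(Q)
    captured = jnp.trace(Q.T @ (A @ Q))
    # stage 3: Hutchinson estimate on the part of A outside span(Q)
    G_perp = G - Q @ (Q.T @ G)
    residual = jnp.sum(G_perp * (A @ G_perp)) / m
    return captured + residual


# rank-3 PSD matrix: the range sketch captures it entirely
B = rdm.normal(rdm.PRNGKey(0), (50, 3))
A_low = B @ B.T
print(jnp.trace(A_low), hutchpp_trace(rdm.PRNGKey(1), A_low, k=30))
```

When the eigenvalues decay quickly, as in the example matrix above, most of the trace is captured in stage 2 and the random residual term contributes little variance.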

Notes

  • traceax uses JAX with just-in-time (JIT) compilation to achieve high-speed computation. However, JAX has known issues on Macs with the M1 chip. As a workaround, set up conda through miniforge, and then install traceax with pip in the desired environment.

Support

Please report any bugs or feature requests in the Issue Tracker. For questions or comments, please contact Linda Serafin (lserafin@usc.edu) or Nicholas Mancuso (nmancuso@usc.edu).

Other Software

Feel free to use other software developed by the Mancuso Lab:

  • SuShiE: a Bayesian fine-mapping framework for molecular QTL data across multiple ancestries.
  • MA-FOCUS: a Bayesian fine-mapping framework using TWAS statistics across multiple ancestries to identify the causal genes for complex traits.
  • SuSiE-PCA: a scalable Bayesian variable selection technique for sparse principal component analysis.
  • twas_sim: Python software for simulating TWAS statistics.
  • FactorGo: a scalable variational factor analysis model that learns pleiotropic factors from GWAS summary statistics.
  • HAMSTA: Python software for estimating heritability explained by local ancestry data from admixture mapping summary statistics.

traceax is distributed under the terms of the Apache-2.0 license.


This project has been set up using Hatch. For details and usage information on Hatch see https://github.com/pypa/hatch.