SuSiE-PCA¶
SuSiE PCA is a scalable Bayesian variable selection technique for sparse principal component analysis
SuSiE PCA applies the Sum of Single Effects (SuSiE) model [1] to principal component analysis. We developed SuSiE PCA for efficient variable selection in PCA on high-dimensional, sparse data, and to quantify the uncertainty of the features contributing to each latent component through posterior inclusion probabilities (PIPs). The model is implemented with JAX, the library developed by Google, which enables fast training on CPU, GPU, or TPU.
If you enjoy or use our software, please consider citing its publication in iScience (2023):
Yuan, D. and Mancuso, N., 2023. SuSiE PCA: A scalable Bayesian variable selection technique for principal component analysis. iScience, 26(11). DOI: https://doi.org/10.1016/j.isci.2023.108181
Documentation | Installation | Example | Notes | References | Support
Model Description¶
We extend the Sum of Single Effects model (i.e. SuSiE) [1] to principal component analysis. Assume $X_{N \times P}$ is the observed data, $Z_{N \times K}$ is the matrix of latent factors, and $W_{K \times P}$ is the factor loading matrix. The SuSiE PCA model is then given by:
$$X | Z,W \sim \mathcal{MN}_{N,P}(ZW, I_N, \sigma^2 I_P)$$
where $\mathcal{MN}_{N,P}$ is the matrix normal distribution with dimension $N \times P$, mean $ZW$, row-covariance $I_N$, and column-covariance $\sigma^2 I_P$. Each column vector of $Z$ follows a standard normal distribution. This setup is the same as in Probabilistic PCA [2]. The key difference is that we integrate the SuSiE prior into each row vector $\mathbf{w}_k$ of the factor loading matrix $W$, so that each $\mathbf{w}_k$ contains at most $L$ non-zero effects. That is,
$$\mathbf{w}_k = \sum_{l=1}^L \mathbf{w}_{kl}$$
$$\mathbf{w}_{kl} = w_{kl} \gamma_{kl}$$
$$w_{kl} \sim \mathcal{N}(0,\sigma^2_{0kl})$$
$$\gamma_{kl} | \pi \sim \text{Multi}(1,\pi)$$
Notice that each row vector $\mathbf{w}_k$ is a sum of single-effect vectors $\mathbf{w}_{kl}$, each of which is a length-$P$ vector with one non-zero effect $w_{kl}$ and zeros elsewhere. The coordinate of the non-zero effect is determined by $\gamma_{kl}$, which follows a multinomial distribution with parameter $\pi$. By construction, each factor inferred by SuSiE PCA has at most $L$ associated features from the original data. Moreover, we can quantify the evidence for each association through the posterior inclusion probabilities (PIPs). Suppose the posterior distribution of $\gamma_{kl}$ is $\text{Multi}(1,\mathbf{\alpha}_{kl})$; then the probability that feature $i$ contributes to the factor $\mathbf{w}_k$ is given by: $$\text{PIP}_{ki} = 1-\prod_{l=1}^L (1 - \alpha_{kli})$$ where $\alpha_{kli}$ is the $i$-th entry of $\mathbf{\alpha}_{kl}$.
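The PIP formula can be illustrated with a few lines of NumPy. This is a minimal sketch, not the package's implementation: the function name `posterior_inclusion_probability` is hypothetical, and the array layout `(K, L, P)` is an assumption chosen to match the index order of $\alpha_{kli}$.

```python
import numpy as np


def posterior_inclusion_probability(alpha: np.ndarray) -> np.ndarray:
    """Compute PIP_{ki} = 1 - prod_l (1 - alpha_{kli}).

    alpha is assumed to have shape (K, L, P): alpha[k, l, i] is the posterior
    probability that single effect l of factor k sits on feature i.
    """
    return 1.0 - np.prod(1.0 - alpha, axis=1)


# Toy example with K=1 factor, L=2 single effects, P=3 features.
alpha = np.array([[[0.9, 0.1, 0.0],
                   [0.5, 0.5, 0.0]]])
pip = posterior_inclusion_probability(alpha)
print(pip)  # shape (1, 3)
```

A feature that no single effect ever selects (the third column above) gets a PIP of exactly zero, while a feature favored by several single effects accumulates probability across them.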
Install SuSiE PCA¶
The source code for SuSiE PCA is written entirely in Python 3.8 with JAX (see the JAX installation guide to install JAX). To get started, clone this GitHub repository and install SuSiE PCA with the commands below (PyPI installation will be supported soon).
git clone https://github.com/mancusolab/susiepca.git
cd susiepca
pip install -e .
Get Started with Example¶
Create a Python environment in the cloned repository, then import SuSiE PCA:
import susiepca as sp
Generate a simulated data set as described in the Simulation section of our paper. $Z_{N \times K}$ is the simulated factor matrix, $W_{K \times P}$ is the simulated loading matrix, and $X_{N \times P}$ is the simulated data set with $N$ observations and $P$ features.
Z, W, X = sp.sim.generate_sim(seed=0, l_dim=40, n_dim=150, p_dim=200, z_dim=4, effect_size=1)
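For intuition, data can be drawn from the generative model in the Model Description with plain NumPy. This is only a sketch under stated assumptions and is not the code behind `sp.sim.generate_sim`: the function `simulate_from_model` and its arguments `effect_sd` and `noise_sd` are hypothetical, and $\pi$ is assumed to be uniform over features.

```python
import numpy as np


def simulate_from_model(n_dim, p_dim, z_dim, l_dim,
                        effect_sd=1.0, noise_sd=1.0, seed=0):
    """Draw (Z, W, X) from the sparse factor model X = Z W + noise."""
    rng = np.random.default_rng(seed)
    # Latent factors: every entry of Z is standard normal.
    Z = rng.standard_normal((n_dim, z_dim))
    W = np.zeros((z_dim, p_dim))
    for k in range(z_dim):
        for _ in range(l_dim):
            # gamma_kl ~ Multi(1, pi) with uniform pi (assumption).
            i = rng.integers(p_dim)
            # w_kl ~ N(0, effect_sd^2), added at the selected coordinate.
            W[k, i] += rng.normal(0.0, effect_sd)
    # Column-covariance sigma^2 I_P corresponds to i.i.d. Gaussian noise.
    X = Z @ W + noise_sd * rng.standard_normal((n_dim, p_dim))
    return Z, W, X


Z_demo, W_demo, X_demo = simulate_from_model(n_dim=150, p_dim=200, z_dim=4, l_dim=10)
```

Because two single effects may land on the same feature, each row of `W_demo` has at most `l_dim` non-zero entries rather than exactly `l_dim`.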
Fit SuSiE PCA to the simulated data set with $K=4$ components and $L=40$ single effects per component; you can also vary these two parameters to examine performance under model mis-specification. By default the data is neither centered nor scaled, and the maximum number of iterations is 200. The mean of $Z$ is initialized with the principal components from traditional PCA.
results = sp.infer.susie_pca(X, z_dim=4, l_dim=40, max_iter=200)
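Since the data is not centered or scaled by default, you may want to standardize the columns yourself before fitting. The helper below is a hypothetical NumPy sketch and not part of susiepca; the inference function may also expose its own options for this, so check the API documentation first.

```python
import numpy as np


def standardize_columns(X, eps=1e-12):
    """Center each column to mean 0 and scale it to unit standard deviation."""
    X = np.asarray(X, dtype=float)
    centered = X - X.mean(axis=0, keepdims=True)
    sd = centered.std(axis=0, keepdims=True)
    # Guard against division by zero for constant columns.
    return centered / np.maximum(sd, eps)


X_raw = np.array([[1.0, 10.0],
                  [2.0, 20.0],
                  [3.0, 30.0]])
X_std = standardize_columns(X_raw)
```

The standardized array could then be passed in place of `X` in the call above.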
The returned “results” object contains 5 fields:
params: a dictionary that stores all the updated parameters from SuSiE PCA.
elbo_res: the value of the evidence lower bound (ELBO) from the last iteration.
pve: a length $K$ ndarray that contains the percent of variance explained (PVE) by each component.
pip: a $K$ by $P$ ndarray that contains the posterior inclusion probabilities (PIPs) of each feature's contribution to each factor.
W: the posterior mean of the loading matrix, which is also a $K$ by $P$ ndarray.
To examine the model performance, one straightforward way is to draw and compare heatmaps of the true loading matrix and the estimated loading matrix using seaborn:
import seaborn as sns
# specify the palette for heatmap
div = sns.diverging_palette(250, 10, as_cmap=True)
# Heatmap of true loading matrix
sns.heatmap(W, cmap = div, fmt = ".2f",center = 0)
# Heatmap of estimated loading matrix
W_hat = results.W
sns.heatmap(W_hat, cmap = div, fmt = ".2f", center = 0)
# Heatmap of PIPs
pip = results.pip
sns.heatmap(pip, cmap = div, fmt = ".2f", center = 0)
To compute the Procrustes error of the estimated loading matrix, you need to install the procrustes package, which solves the rotation problem (see the procrustes installation guide). Once the loading matrix is rotated back to its original orientation, you can compute the Procrustes error and inspect the heatmap as follows:
import procrustes
import numpy as np
# perform procrustes transformation
proc_trans_susie = procrustes.orthogonal(np.asarray(W_hat.T), np.asarray(W.T), scale=True)
print(f"The Procrustes error for the loading matrix is {proc_trans_susie.error}")
# Heatmap of transformed loading matrix
W_trans = proc_trans_susie.t.T @ W_hat
sns.heatmap(W_trans, cmap = div, fmt = ".2f", center = 0)
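To see what the rotation step does, here is a minimal NumPy sketch of the classical orthogonal Procrustes solution via the SVD. It is an illustration only and not the procrustes package's code: the helper `orthogonal_procrustes_rotation` is hypothetical, and it omits the scaling applied by `scale=True` above.

```python
import numpy as np


def orthogonal_procrustes_rotation(A, B):
    """Return the orthogonal T that minimizes ||A @ T - B||_F."""
    U, _, Vt = np.linalg.svd(A.T @ B)
    return U @ Vt


rng = np.random.default_rng(0)
W_true = rng.standard_normal((3, 20))             # K x P loading matrix
Q, _ = np.linalg.qr(rng.standard_normal((3, 3)))  # random orthogonal matrix
W_rot = Q @ W_true                                # same loadings, rotated factors

# Align the rotated loadings (as P x K) to the true loadings, then undo the rotation.
T = orthogonal_procrustes_rotation(W_rot.T, W_true.T)
W_back = T.T @ W_rot
```

Factor models are identifiable only up to rotation, which is why the estimated loadings are aligned to the truth before an error is computed.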
You can also calculate the relative root mean square error (RRMSE) to assess the model's prediction performance:
from susiepca import metrics
# compute the predicted data
X_hat = results.params.mu_z @ W_hat
# compute the RRMSE
rrmse_susie = metrics.mse(X, X_hat)
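As a point of reference, one common definition of the relative RMSE is the Frobenius norm of the residual divided by the Frobenius norm of the data. The sketch below uses that definition as an assumption; the helper `relative_rmse` is hypothetical, and `metrics.mse` may normalize differently.

```python
import numpy as np


def relative_rmse(X, X_hat):
    """Return ||X - X_hat||_F / ||X||_F (an assumed definition of RRMSE)."""
    X = np.asarray(X, dtype=float)
    X_hat = np.asarray(X_hat, dtype=float)
    return float(np.sqrt(np.sum((X - X_hat) ** 2) / np.sum(X ** 2)))


X_small = np.array([[3.0, 4.0]])
X_small_hat = np.array([[3.0, 0.0]])
err = relative_rmse(X_small, X_small_hat)  # sqrt(16 / 25)
```

A value of 0 means perfect reconstruction, and under this definition a value of 1 matches the error of predicting all zeros.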
Finally, we provide a function to compute $\rho$-level credible sets (CS). The value returned by the function is composed of $L \times K$ credible sets, each of which contains a subset of variables that cumulatively explain at least $\rho$ of the posterior density.
cs = sp.metrics.get_credset(results.params.alpha, rho=0.9)
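The idea behind a $\rho$-level credible set can be sketched for a single posterior vector $\mathbf{\alpha}_{kl}$: sort the features by posterior probability and keep the smallest prefix whose cumulative mass reaches $\rho$. The helper `credible_set` below is hypothetical and is not the code of `get_credset`, which may differ in array layout and in any additional filtering.

```python
import numpy as np


def credible_set(alpha_kl, rho=0.9):
    """Return indices of the smallest feature set with cumulative mass >= rho."""
    alpha_kl = np.asarray(alpha_kl, dtype=float)
    # Sort features from most to least probable.
    order = np.argsort(-alpha_kl, kind="stable")
    cum_mass = np.cumsum(alpha_kl[order])
    # First position where the cumulative mass reaches rho.
    n_keep = int(np.searchsorted(cum_mass, rho, side="left")) + 1
    n_keep = min(n_keep, alpha_kl.size)
    return order[:n_keep]


alpha_kl = np.array([0.02, 0.70, 0.03, 0.25])
cs_demo = credible_set(alpha_kl, rho=0.9)
```

A small credible set indicates that the single effect is well localized; a large one indicates that many features are plausible carriers of that effect.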
Notes¶
JAX uses 32-bit precision by default. To enable 64-bit precision, add the following code before calling susiepca:
import jax
jax.config.update("jax_enable_x64", True)
Similarly, the default computation device for JAX is set by environment variables (see here). To change it programmatically, add the following code before calling susiepca:
import jax
platform = "gpu" # "gpu", "cpu", or "tpu"
jax.config.update("jax_platform_name", platform)
References¶
Support¶
Please report any bugs or feature requests in the Issue Tracker. If you have any questions or comments please contact dongyuan@usc.edu and/or nmancuso@usc.edu.
This project has been set up using PyScaffold 4.1.1. For details and usage information on PyScaffold see https://pyscaffold.org/.