Multi-Modal Surgery Pipeline with TOTALVI
[1]:
import os
os.chdir('../')
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
[2]:
import scanpy as sc
import anndata
import torch
import scarchest as sca
import matplotlib.pyplot as plt
import numpy as np
import scvi as scv
import pandas as pd
[3]:
# Set all figure options in one call; separate calls overwrite each other's settings with defaults.
sc.settings.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)
Data loading and preprocessing
For totalVI, we treat two CITE-seq PBMC datasets from 10x Genomics as the reference. These datasets were already filtered for outliers such as doublets, as described in the totalVI manuscript. The reference contains 14 proteins.
[4]:
adata_ref = scv.data.pbmcs_10x_cite_seq(run_setup_anndata=False)
INFO Downloading file at data/pbmc_10k_protein_v3.h5ad
Downloading...: 24938it [00:04, 5776.87it/s]
INFO Downloading file at data/pbmc_5k_protein_v3.h5ad
Downloading...: 100%|████████████████████████████████████████████████████████| 18295/18295.0 [00:03<00:00, 5772.53it/s]
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
[5]:
adata_query = scv.data.dataset_10x("pbmc_10k_v3")
adata_query.obs["batch"] = "PBMC 10k (RNA only)"
# Fill protein expression with a matrix of zeros (considered missing)
pro_exp = adata_ref.obsm["protein_expression"]
data = np.zeros((adata_query.n_obs, pro_exp.shape[1]))
adata_query.obsm["protein_expression"] = pd.DataFrame(
    columns=pro_exp.columns, index=adata_query.obs_names, data=data
)
INFO Downloading file at data/10X\pbmc_10k_v3\filtered_feature_bc_matrix.h5
Downloading...: 37492it [00:08, 4267.50it/s]
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
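The placeholder step above can be sketched in isolation. This is a minimal illustration using only NumPy and pandas; the helper name `make_missing_protein_table` and the protein and cell names are hypothetical and do not come from the tutorial's data.

```python
import numpy as np
import pandas as pd


def make_missing_protein_table(protein_names, cell_names):
    """Return an all-zero cells x proteins table marking proteins as unmeasured."""
    zeros = np.zeros((len(cell_names), len(protein_names)))
    return pd.DataFrame(zeros, index=cell_names, columns=protein_names)


# Hypothetical protein and cell names, for illustration only.
placeholder = make_missing_protein_table(
    ["protein_1", "protein_2"], ["cell_a", "cell_b", "cell_c"]
)
```

Reusing the reference's protein columns keeps the query table aligned with the reference, so both objects expose the same proteins in the same order.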
Next, we concatenate the two objects. By default, anndata.concat performs an inner join, so only the genes present in both datasets are kept.
[6]:
adata_full = anndata.concat([adata_ref, adata_query])
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
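The effect of an inner join on the gene axis can be illustrated with plain pandas tables. This is only an analogy for what the concatenation does to the variables; the gene names and counts below are made up.

```python
import pandas as pd

# Hypothetical toy count tables (cells x genes); gene names are illustrative.
ref = pd.DataFrame(
    [[1, 0, 3]], index=["ref_cell"], columns=["GENE_A", "GENE_B", "GENE_C"]
)
query = pd.DataFrame(
    [[2, 5, 7]], index=["query_cell"], columns=["GENE_B", "GENE_C", "GENE_D"]
)

# join="inner" keeps only the columns (genes) shared by both tables.
merged = pd.concat([ref, query], join="inner")
shared_genes = set(merged.columns)
```

Genes measured in only one of the two tables (here `GENE_A` and `GENE_D`) are dropped, which is why the objects come out aligned.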
We then split the combined object back into reference and query. The genes are now aligned between the two objects.
[7]:
adata_ref = adata_full[np.logical_or(adata_full.obs.batch == "PBMC5k", adata_full.obs.batch == "PBMC10k")].copy()
adata_query = adata_full[adata_full.obs.batch == "PBMC 10k (RNA only)"].copy()
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
We run gene selection on the reference only, because that is all that will be available to us at first.
[8]:
sc.pp.highly_variable_genes(
    adata_ref,
    n_top_genes=4000,
    flavor="seurat_v3",
    batch_key="batch",
    subset=True,
)
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
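To give a rough intuition for what gene selection does, here is a deliberately simplified sketch that ranks genes by their variance-to-mean ratio. This is not the seurat_v3 procedure used above, which is more involved; the function name `top_variable_genes` and the counts are hypothetical.

```python
import numpy as np


def top_variable_genes(counts, n_top):
    """Return indices of the n_top genes with the largest variance-to-mean ratio."""
    mean = counts.mean(axis=0)
    var = counts.var(axis=0)
    # Genes with zero mean get a ratio of 0 so they are never ranked first.
    ratio = np.divide(var, mean, out=np.zeros_like(mean, dtype=float), where=mean > 0)
    return np.argsort(ratio)[::-1][:n_top]


# Toy counts: 4 cells x 3 genes (values are made up).
counts = np.array([[1, 10, 0], [1, 0, 0], [1, 12, 0], [1, 0, 0]], dtype=float)
selected = top_variable_genes(counts, n_top=1)
```

In this toy matrix only the second gene varies between cells, so it is the one selected.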
Finally, we use these selected genes for the query dataset as well.
[9]:
adata_query = adata_query[:, adata_ref.var_names].copy()
Create TOTALVI model and train it on CITE-seq reference dataset
[10]:
sca.dataset.setup_anndata(
    adata_ref,
    batch_key="batch",
    protein_expression_obsm_key="protein_expression",
)
INFO Using batches from adata.obs["batch"]
INFO No label_key inputted, assuming all cells have same label
INFO Using data from adata.X
INFO Computing library size prior per batch
INFO Using protein expression from adata.obsm['protein_expression']
INFO Using protein names from columns of adata.obsm['protein_expression']
INFO Successfully registered anndata object containing 10849 cells, 4000 vars, 2
batches, 1 labels, and 14 proteins. Also registered 0 extra categorical covariates
and 0 extra continuous covariates.
INFO Please do not further modify adata until model is trained.
[11]:
arches_params = dict(
    use_layer_norm="both",
    use_batch_norm="none",
)
vae_ref = sca.models.TOTALVI(
    adata_ref,
    use_cuda=True,
    **arches_params
)
[12]:
vae_ref.train()
INFO Training for 400 epochs.
INFO KL warmup for 8136.75 iterations
Training...: 67%|████████████████████████████████████████████▋ | 267/400 [18:30<06:15, 2.83s/it]INFO Reducing LR on epoch 267.
Training...: 76%|██████████████████████████████████████████████████▉ | 304/400 [20:05<03:48, 2.38s/it]INFO Reducing LR on epoch 304.
Training...: 96%|███████████████████████████████████████████████████████████████▉ | 382/400 [23:37<00:48, 2.69s/it]INFO Reducing LR on epoch 382.
Training...: 100%|███████████████████████████████████████████████████████████████████| 400/400 [24:21<00:00, 3.65s/it]
INFO Training time: 1376 s. / 400 epochs
Save latent representation and visualize RNA data
[13]:
adata_ref.obsm["X_totalVI"] = vae_ref.get_latent_representation()
sc.pp.neighbors(adata_ref, use_rep="X_totalVI")
sc.tl.umap(adata_ref, min_dist=0.4)
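The neighbor step builds a graph from distances between cells in the latent space. A brute-force sketch of the underlying idea is shown here, assuming Euclidean distance; scanpy's actual implementation differs (it can use approximate search and also computes connectivities). The function name `knn_indices` and the coordinates are hypothetical.

```python
import numpy as np


def knn_indices(latent, k):
    """Brute-force k nearest neighbours (Euclidean) for each row of `latent`."""
    diff = latent[:, None, :] - latent[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)  # exclude each cell from its own neighbour list
    return np.argsort(dist, axis=1)[:, :k]


# Hypothetical 2-D latent coordinates for four cells.
latent = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.0, 5.2]])
neighbours = knn_indices(latent, k=1)
```

Cells that sit close together in the latent space end up as each other's neighbours, which is the structure UMAP then lays out in two dimensions.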
[14]:
sc.pl.umap(
    adata_ref,
    color=["batch"],
    frameon=False,
    ncols=1,
    title="Reference",
)
... storing 'batch' as categorical
Save trained reference model
[15]:
dir_path = "saved_model/"
vae_ref.save(dir_path, overwrite=True)
Perform surgery on reference model and train on query dataset without protein data
[16]:
vae_q = sca.models.TOTALVI.load_query_data(
    adata_query,
    dir_path,
    freeze_expression=True,
)
INFO .obs[_scvi_labels] not found in target, assuming every cell is same category
INFO Found batches with missing protein expression
INFO Using data from adata.X
INFO Computing library size prior per batch
INFO Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels',
'protein_expression']
INFO Successfully registered anndata object containing 11769 cells, 4000 vars, 3
batches, 1 labels, and 14 proteins. Also registered 0 extra categorical covariates
and 0 extra continuous covariates.
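The log above shows the model going from 2 to 3 batches. As a conceptual sketch of what such a surgery step could look like, the toy example below appends a weight column for the new batch to a first-layer weight matrix and updates only that column. This is an assumption-laden illustration in NumPy, not the scArches implementation; the layer layout, sizes, and gradient are all made up.

```python
import numpy as np

rng = np.random.default_rng(0)
n_hidden, n_genes, n_ref_batches = 4, 5, 2

# Hypothetical first-layer weights: columns are [genes | one-hot batch covariates].
w_ref = rng.normal(size=(n_hidden, n_genes + n_ref_batches))

# Surgery: append one zero-initialised column for the new query batch.
w_query = np.hstack([w_ref, np.zeros((n_hidden, 1))])

# Mask of trainable entries: only the new batch column is updated.
trainable = np.zeros_like(w_query, dtype=bool)
trainable[:, -1] = True

# One masked gradient step with a made-up gradient.
grad = np.ones_like(w_query)
w_updated = w_query - 0.1 * grad * trainable
```

Because the reference weights are left untouched, the reference cells keep their original representation while the query batch gets its own parameters.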
[17]:
vae_q.train(200, weight_decay=0.0)
INFO Training for 200 epochs.
INFO KL warmup for 8826.75 iterations
Training...: 68%|█████████████████████████████████████████████▌ | 136/200 [11:10<04:09, 3.90s/it]INFO Reducing LR on epoch 136.
Training...: 76%|██████████████████████████████████████████████████▌ | 151/200 [12:03<03:00, 3.69s/it]INFO
Stopping early: no improvement of more than 0 nats in 45 epochs
INFO If the early stopping criterion is too strong, please instantiate it with different
parameters in the train method.
Training...: 76%|██████████████████████████████████████████████████▌ | 151/200 [12:06<03:55, 4.81s/it]
INFO Training is still in warming up phase. If your applications rely on the posterior
quality, consider training for more epochs or reducing the kl warmup.
INFO Training time: 689 s. / 200 epochs
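The warmup warning can be checked with simple arithmetic. The logged warmup lengths match 0.75 times the number of registered cells (0.75 × 10849 = 8136.75 for the reference and 0.75 × 11769 = 8826.75 for the query). The sketch below assumes a linear warmup schedule, a training fraction of 0.9, and a minibatch size of 256; none of these three are confirmed by the output above, and both function names are hypothetical.

```python
import math


def kl_weight(iteration, n_warmup):
    """Linear warmup: the KL weight grows from 0 to 1 over n_warmup iterations."""
    return min(1.0, iteration / n_warmup)


def iterations_run(n_cells, n_epochs, train_fraction=0.9, batch_size=256):
    """Minibatch updates performed; train_fraction and batch_size are assumptions."""
    n_train = int(n_cells * train_fraction)
    return n_epochs * math.ceil(n_train / batch_size)


n_warmup = 0.75 * 11769  # matches the logged 8826.75 iterations
done = iterations_run(11769, n_epochs=151)  # training stopped early at epoch 151
final_weight = kl_weight(done, n_warmup)
```

Under these assumptions, early stopping at epoch 151 leaves the number of updates below the warmup length, which is consistent with the warning that training is still in the warmup phase.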
[18]:
adata_query.obsm["X_totalVI"] = vae_q.get_latent_representation()
sc.pp.neighbors(adata_query, use_rep="X_totalVI")
sc.tl.umap(adata_query, min_dist=0.4)
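Impute protein data for the query dataset and visualize
We impute the proteins that were observed in the reference by passing the reference batches to the transform_batch parameter, so that protein expression is predicted for the query cells as if they had been measured in those batches.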
[19]:
_, imputed_proteins = vae_q.get_normalized_expression(
    adata_query,
    n_samples=25,
    return_mean=True,
    transform_batch=["PBMC10k", "PBMC5k"],
)
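When a list of batches is given, our understanding is that the prediction is averaged over those batches. The toy example below illustrates that idea with a linear decoder; totalVI's real decoder is a neural network, and the function `decode`, the weights, and the batch indices are all hypothetical.

```python
import numpy as np


def decode(latent, batch_onehot, w_latent, w_batch):
    """Toy linear decoder: predicted protein level from latent state and batch."""
    return latent @ w_latent + batch_onehot @ w_batch


# Made-up weights: 2 latent dims, 3 batches (two reference, one query), 1 protein.
w_latent = np.array([[1.0], [2.0]])
w_batch = np.array([[0.5], [1.5], [0.0]])
latent = np.array([[1.0, 1.0]])  # one query cell

# Decode the same cell once per reference batch, then average the predictions.
reference_batches = [0, 1]
predictions = [
    decode(latent, np.eye(3)[[b]], w_latent, w_batch) for b in reference_batches
]
imputed = np.mean(predictions, axis=0)
```

Conditioning on the reference batches matters because the query batch never saw protein measurements, so its own batch effect carries no information about protein levels.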
[20]:
adata_query.obs = pd.concat([adata_query.obs, imputed_proteins], axis=1)
sc.pl.umap(
    adata_query,
    color=imputed_proteins.columns,
    frameon=False,
    ncols=3,
)
... storing 'batch' as categorical
Get latent representation of reference + query dataset and compute UMAP
[21]:
adata_full_new = adata_query.concatenate(adata_ref, batch_key="none")
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
[22]:
adata_full_new.obsm["X_totalVI"] = vae_q.get_latent_representation(adata_full_new)
sc.pp.neighbors(adata_full_new, use_rep="X_totalVI")
sc.tl.umap(adata_full_new, min_dist=0.3)
INFO Input adata not setup with scvi. attempting to transfer anndata setup
INFO Found batches with missing protein expression
INFO Using data from adata.X
INFO Computing library size prior per batch
INFO Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels',
'protein_expression']
INFO Successfully registered anndata object containing 22618 cells, 4000 vars, 3
batches, 1 labels, and 14 proteins. Also registered 0 extra categorical covariates
and 0 extra continuous covariates.
[23]:
_, imputed_proteins_all = vae_q.get_normalized_expression(
    adata_full_new,
    n_samples=25,
    return_mean=True,
    transform_batch=["PBMC10k", "PBMC5k"],
)
# Store each imputed protein as a column in .obs so it can be plotted
for p in imputed_proteins_all.columns:
    adata_full_new.obs[p] = imputed_proteins_all[p].to_numpy().copy()
[24]:
perm_inds = np.random.permutation(np.arange(adata_full_new.n_obs))
sc.pl.umap(
    adata_full_new[perm_inds],
    color=["batch"],
    frameon=False,
    ncols=1,
    title="Reference and query",
)
C:\Users\sergei.rybakov\Apps\Miniconda3\envs\work\lib\site-packages\anndata\_core\anndata.py:1213: ImplicitModificationWarning: Initializing view as actual.
"Initializing view as actual.", ImplicitModificationWarning
Trying to set attribute `.obs` of view, copying.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
... storing 'batch' as categorical
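The cell above plots the cells in a random order. A likely motivation is that the combined object stores the query cells first and the reference cells second, so plotting in stored order would draw one group on top of the other. A minimal sketch of the shuffling step, with made-up labels:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical batch labels stored in concatenation order (query first, then reference).
batch = np.array(["query"] * 3 + ["reference"] * 3)

# Shuffle the row order so the drawing order no longer follows the batch order.
perm = rng.permutation(batch.size)
shuffled = batch[perm]
```

The permutation only changes the order of the rows, not their content, so every cell is still plotted exactly once.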
[25]:
ax = sc.pl.umap(
    adata_full_new,
    color="batch",
    groups=["PBMC 10k (RNA only)"],
    frameon=False,
    ncols=1,
    title="Reference and query",
    alpha=0.4,
)
... storing 'batch' as categorical
[26]:
sc.pl.umap(
    adata_full_new,
    color=imputed_proteins_all.columns,
    frameon=False,
    ncols=3,
    vmax="p99",
)
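Setting vmax="p99" caps the color scale at the 99th percentile of the plotted values, so a few extreme cells do not wash out the rest of the plot. The effect can be sketched with NumPy on made-up values:

```python
import numpy as np

# Made-up protein values with a single extreme outlier.
values = np.concatenate([np.linspace(0.0, 1.0, 99), [50.0]])

# Cap the colour scale at the 99th percentile instead of the maximum.
vmax = np.percentile(values, 99)
clipped = np.clip(values, None, vmax)
```

With the cap in place, the bulk of the values spans most of the color range instead of being compressed near the bottom by the outlier.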