Multi-Modal Surgery Pipeline with TOTALVI

[1]:
import os
os.chdir('../')
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
[2]:
import scanpy as sc
import anndata
import torch
import scarchest as sca
import matplotlib.pyplot as plt
import numpy as np
import scvi as scv
import pandas as pd
[3]:
# Set all figure options in one call: each set_figure_params call resets
# any option it is not given (dpi, frameon, ...) back to its default.
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

Data loading and preprocessing

We use two CITE-seq PBMC datasets from 10x Genomics (the PBMC5k and PBMC10k batches) as the totalVI reference. Both datasets were already filtered for outliers such as doublets, as described in the totalVI manuscript. The reference measures 14 proteins.

[4]:
adata_ref = scv.data.pbmcs_10x_cite_seq(run_setup_anndata=False)
INFO      Downloading file at data/pbmc_10k_protein_v3.h5ad
Downloading...: 24938it [00:04, 5776.87it/s]
INFO      Downloading file at data/pbmc_5k_protein_v3.h5ad
Downloading...: 100%|████████████████████████████████████████████████████████| 18295/18295.0 [00:03<00:00, 5772.53it/s]
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
[5]:
adata_query = scv.data.dataset_10x("pbmc_10k_v3")
adata_query.obs["batch"] = "PBMC 10k (RNA only)"
# RNA-only query: fill protein expression with zeros (totalVI treats this as missing)
pro_exp = adata_ref.obsm["protein_expression"]
data = np.zeros((adata_query.n_obs, pro_exp.shape[1]))
adata_query.obsm["protein_expression"] = pd.DataFrame(
    data=data, columns=pro_exp.columns, index=adata_query.obs_names
)
INFO      Downloading file at data/10X\pbmc_10k_v3\filtered_feature_bc_matrix.h5
Downloading...: 37492it [00:08, 4267.50it/s]
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
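
The zero-filling step can be tried on toy data without downloading anything. The sketch below is a minimal illustration: the cell barcodes, protein names and counts are made up, and only the construction of the all-zero table mirrors the notebook cell.

```python
import numpy as np
import pandas as pd

# Toy stand-ins (hypothetical names and counts) for the reference protein table
# and for the barcodes of an RNA-only query.
ref_proteins = pd.DataFrame(
    [[5, 0, 2], [1, 7, 0]],
    index=["ref_cell_1", "ref_cell_2"],
    columns=["CD3_TotalSeqB", "CD4_TotalSeqB", "CD8a_TotalSeqB"],
)
query_cells = ["query_cell_1", "query_cell_2", "query_cell_3", "query_cell_4"]

# One row per query cell, one column per reference protein, all counts zero.
query_proteins = pd.DataFrame(
    np.zeros((len(query_cells), ref_proteins.shape[1])),
    index=query_cells,
    columns=ref_proteins.columns,
)
print(query_proteins)
```

Reusing the reference column names keeps the protein order identical between reference and query.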

Next, we concatenate the two objects. By default, anndata.concat performs an inner join on the variables, so only genes present in both datasets are kept.

[6]:
adata_full = anndata.concat([adata_ref, adata_query])
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
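
The effect of an inner join on the gene axis can be seen on a small example. The sketch below uses plain pandas DataFrames as an analogy for the two AnnData objects (cells as rows, genes as columns); the gene and cell names are made up.

```python
import pandas as pd

# Toy count tables: the two datasets only partly overlap in their genes.
ref = pd.DataFrame([[1, 2, 3]], index=["ref_cell"], columns=["GeneA", "GeneB", "GeneC"])
query = pd.DataFrame([[4, 5, 6]], index=["query_cell"], columns=["GeneB", "GeneC", "GeneD"])

# join="inner" stacks the rows and keeps only the columns found in both tables.
combined = pd.concat([ref, query], join="inner")
shared_genes = list(combined.columns)
print(combined)
```

GeneA and GeneD are dropped because each appears in only one table.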

Then we split the result back into reference and query. Both objects now contain the same genes in the same order.

[7]:
adata_ref = adata_full[adata_full.obs.batch.isin(["PBMC5k", "PBMC10k"])].copy()
adata_query = adata_full[adata_full.obs.batch == "PBMC 10k (RNA only)"].copy()
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.

We run gene selection on the reference only, because that is all that will be available to us at first.

[8]:
sc.pp.highly_variable_genes(
    adata_ref,
    n_top_genes=4000,
    flavor="seurat_v3",
    batch_key="batch",
    subset=True,
)
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.

Finally, we use these selected genes for the query dataset as well.

[9]:
adata_query = adata_query[:, adata_ref.var_names].copy()
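
The "select on the reference, then reuse for the query" pattern is sketched below on simulated counts. It ranks genes by plain variance, which is a deliberate simplification and not the seurat_v3 flavor used above; the gene names and Poisson rates are made up.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
genes = [f"gene_{i}" for i in range(6)]

# Simulated counts: gene_2 and gene_4 have much higher rates (and variance) than the rest.
ref_counts = pd.DataFrame(rng.poisson(lam=[1, 1, 20, 1, 30, 1], size=(50, 6)), columns=genes)
query_counts = pd.DataFrame(rng.poisson(lam=2.0, size=(10, 6)), columns=genes)

# Choose genes using the reference only ...
n_top = 2
selected = ref_counts.var(axis=0).nlargest(n_top).index

# ... then apply the same selection, in the same order, to both datasets.
ref_subset = ref_counts[selected]
query_subset = query_counts[selected]
print(list(selected))
```

The query never influences which genes are kept, which mirrors the situation where the reference model is trained before any query data exists.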

Create TOTALVI model and train it on CITE-seq reference dataset

[10]:
sca.dataset.setup_anndata(
    adata_ref,
    batch_key="batch",
    protein_expression_obsm_key="protein_expression"
)
INFO      Using batches from adata.obs["batch"]
INFO      No label_key inputted, assuming all cells have same label
INFO      Using data from adata.X
INFO      Computing library size prior per batch
INFO      Using protein expression from adata.obsm['protein_expression']
INFO      Using protein names from columns of adata.obsm['protein_expression']
INFO      Successfully registered anndata object containing 10849 cells, 4000 vars, 2
          batches, 1 labels, and 14 proteins. Also registered 0 extra categorical covariates
          and 0 extra continuous covariates.
INFO      Please do not further modify adata until model is trained.
[11]:
arches_params = dict(
    use_layer_norm="both",
    use_batch_norm="none",
)
vae_ref = sca.models.TOTALVI(
    adata_ref,
    use_cuda=True,
    **arches_params
)
[12]:
vae_ref.train()
INFO      Training for 400 epochs.
INFO      KL warmup for 8136.75 iterations
Training...:  67%|████████████████████████████████████████████▋                      | 267/400 [18:30<06:15,  2.83s/it]INFO      Reducing LR on epoch 267.
Training...:  76%|██████████████████████████████████████████████████▉                | 304/400 [20:05<03:48,  2.38s/it]INFO      Reducing LR on epoch 304.
Training...:  96%|███████████████████████████████████████████████████████████████▉   | 382/400 [23:37<00:48,  2.69s/it]INFO      Reducing LR on epoch 382.
Training...: 100%|███████████████████████████████████████████████████████████████████| 400/400 [24:21<00:00,  3.65s/it]
INFO      Training time:  1376 s. / 400 epochs

Save latent representation and visualize RNA data

[13]:
adata_ref.obsm["X_totalVI"] = vae_ref.get_latent_representation()
sc.pp.neighbors(adata_ref, use_rep="X_totalVI")
sc.tl.umap(adata_ref, min_dist=0.4)
[14]:
sc.pl.umap(
    adata_ref,
    color=["batch"],
    frameon=False,
    ncols=1,
    title="Reference"
)
... storing 'batch' as categorical
_images/totalvi_surgery_pipeline_22_1.png

Save trained reference model

[15]:
dir_path = "saved_model/"
vae_ref.save(dir_path, overwrite=True)

Perform surgery on reference model and train on query dataset without protein data

[16]:
vae_q = sca.models.TOTALVI.load_query_data(
    adata_query,
    dir_path,
    freeze_expression=True
)
INFO      .obs[_scvi_labels] not found in target, assuming every cell is same category
INFO      Found batches with missing protein expression
INFO      Using data from adata.X
INFO      Computing library size prior per batch
INFO      Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels',
          'protein_expression']
INFO      Successfully registered anndata object containing 11769 cells, 4000 vars, 3
          batches, 1 labels, and 14 proteins. Also registered 0 extra categorical covariates
          and 0 extra continuous covariates.
[17]:
vae_q.train(200, weight_decay=0.0)
INFO      Training for 200 epochs.
INFO      KL warmup for 8826.75 iterations
Training...:  68%|█████████████████████████████████████████████▌                     | 136/200 [11:10<04:09,  3.90s/it]INFO      Reducing LR on epoch 136.
Training...:  76%|██████████████████████████████████████████████████▌                | 151/200 [12:03<03:00,  3.69s/it]INFO
          Stopping early: no improvement of more than 0 nats in 45 epochs
INFO      If the early stopping criterion is too strong, please instantiate it with different
          parameters in the train method.
Training...:  76%|██████████████████████████████████████████████████▌                | 151/200 [12:06<03:55,  4.81s/it]
INFO      Training is still in warming up phase. If your applications rely on the posterior
          quality, consider training for more epochs or reducing the kl warmup.
INFO      Training time:  689 s. / 200 epochs
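
The two "KL warmup" figures in the training logs equal 0.75 times the number of registered cells (0.75 × 10849 = 8136.75 for the reference, 0.75 × 11769 = 8826.75 for the query). The sketch below reproduces that arithmetic and estimates why the query run was still warming up when early stopping triggered at epoch 151. The 0.75 factor is inferred from the logs; the minibatch size of 256, the 90% training split and the linear ramp are assumptions about the library defaults, not values shown in this notebook.

```python
import math

def kl_warmup_iterations(n_cells, factor=0.75):
    """Warmup length in iterations; the 0.75 factor is inferred from the logs."""
    return factor * n_cells

def kl_weight(iteration, n_warmup):
    """Assumed linear ramp of the KL weight from 0 to 1 over the warmup."""
    return min(1.0, iteration / n_warmup)

ref_warmup = kl_warmup_iterations(10849)    # reference cells registered above
query_warmup = kl_warmup_iterations(11769)  # cells registered for the query model

# Assumed defaults (not shown in the logs): 90% of cells used for training,
# minibatches of 256 cells.
train_fraction = 0.9
batch_size = 256
iters_per_epoch = math.ceil(train_fraction * 11769 / batch_size)

# Early stopping ended the query run after 151 epochs.
iters_done = 151 * iters_per_epoch
weight_at_stop = kl_weight(iters_done, query_warmup)
print(ref_warmup, query_warmup, iters_done, round(weight_at_stop, 2))
```

Under these assumptions the query run completes about 6,300 of the 8,826.75 warmup iterations, which is consistent with the "still in warming up phase" message above.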
[18]:
adata_query.obsm["X_totalVI"] = vae_q.get_latent_representation()
sc.pp.neighbors(adata_query, use_rep="X_totalVI")
sc.tl.umap(adata_query, min_dist=0.4)

Impute protein data for the query dataset and visualize

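Impute the proteins that were observed in the reference by passing the reference batches to the transform_batch parameter. When a list of batches is given, the prediction is averaged over those batches.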

[19]:
_, imputed_proteins = vae_q.get_normalized_expression(
    adata_query,
    n_samples=25,
    return_mean=True,
    transform_batch=["PBMC10k", "PBMC5k"],
)
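
The idea behind imputing with transform_batch can be sketched with a toy decoder: each query cell is decoded as if it came from each reference batch, and the results are averaged. The decoder, its batch offsets and the latent values below are hypothetical stand-ins, not the totalVI network.

```python
import numpy as np

def toy_decoder(latent, batch_index):
    """Hypothetical decoder: the output depends on the latent state plus a batch offset."""
    batch_offsets = np.array([0.0, 2.0, 10.0])  # made-up effects for three batches
    return latent.sum(axis=1, keepdims=True) + batch_offsets[batch_index]

def impute_over_batches(latent, batches):
    """Decode every cell as if it belonged to each listed batch, then average."""
    per_batch = [toy_decoder(latent, b) for b in batches]
    return np.mean(per_batch, axis=0)

# Two cells with a two-dimensional latent representation.
latent = np.array([[1.0, 2.0], [0.5, 0.5]])

# Batches 0 and 1 play the role of the two protein-measuring reference batches;
# batch 2 (the RNA-only query) is left out of the average.
imputed = impute_over_batches(latent, batches=[0, 1])
print(imputed)
```

Leaving the query batch out of the list matters here because that batch has no observed protein data to anchor its predictions.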
[20]:
adata_query.obs = pd.concat([adata_query.obs, imputed_proteins], axis=1)

sc.pl.umap(
    adata_query,
    color=imputed_proteins.columns,
    frameon=False,
    ncols=3,
)
... storing 'batch' as categorical
_images/totalvi_surgery_pipeline_32_1.png

Get latent representation of reference + query dataset and compute UMAP

[21]:
adata_full_new = adata_query.concatenate(adata_ref, batch_key="none")
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
[22]:
adata_full_new.obsm["X_totalVI"] = vae_q.get_latent_representation(adata_full_new)
sc.pp.neighbors(adata_full_new, use_rep="X_totalVI")
sc.tl.umap(adata_full_new, min_dist=0.3)
INFO      Input adata not setup with scvi. attempting to transfer anndata setup
INFO      Found batches with missing protein expression
INFO      Using data from adata.X
INFO      Computing library size prior per batch
INFO      Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels',
          'protein_expression']
INFO      Successfully registered anndata object containing 22618 cells, 4000 vars, 3
          batches, 1 labels, and 14 proteins. Also registered 0 extra categorical covariates
          and 0 extra continuous covariates.
[23]:
_, imputed_proteins_all = vae_q.get_normalized_expression(
    adata_full_new,
    n_samples=25,
    return_mean=True,
    transform_batch=["PBMC10k", "PBMC5k"],
)

for p in imputed_proteins_all.columns:
    adata_full_new.obs[p] = imputed_proteins_all[p].to_numpy().copy()
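
The loop above copies each protein column with .to_numpy(), which drops the pandas index so that values are assigned by row position rather than matched by label. A minimal sketch of the difference, using made-up cell names whose labels deliberately do not match:

```python
import pandas as pd

obs = pd.DataFrame({"batch": ["A", "A", "B"]}, index=["cell_1", "cell_2", "cell_3"])
imputed = pd.DataFrame({"CD3": [0.1, 0.2, 0.3]}, index=["x_1", "x_2", "x_3"])

# Assigning a Series aligns on index labels; no labels match, so every value is NaN.
aligned = obs.copy()
aligned["CD3"] = imputed["CD3"]

# Assigning a NumPy array ignores labels and fills rows in order.
positional = obs.copy()
positional["CD3"] = imputed["CD3"].to_numpy().copy()

print(aligned)
print(positional)
```

Positional assignment is only safe when both tables list the cells in the same order, as they do here because the imputed values were computed from the same object.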
[24]:
perm_inds = np.random.permutation(np.arange(adata_full_new.n_obs))
sc.pl.umap(
    adata_full_new[perm_inds],
    color=["batch"],
    frameon=False,
    ncols=1,
    title="Reference and query"
)
C:\Users\sergei.rybakov\Apps\Miniconda3\envs\work\lib\site-packages\anndata\_core\anndata.py:1213: ImplicitModificationWarning: Initializing view as actual.
  "Initializing view as actual.", ImplicitModificationWarning
Trying to set attribute `.obs` of view, copying.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
... storing 'batch' as categorical
_images/totalvi_surgery_pipeline_37_1.png
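
The plotting cell above indexes the data with a random permutation before drawing. After concatenation the rows list all query cells first and all reference cells second, so shuffling presumably keeps one batch from being drawn entirely on top of the other. A minimal sketch of the shuffle on a toy batch vector (the seed and labels are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# Rows as they would appear after concatenation: one block per dataset.
batch = np.array(["query"] * 3 + ["reference"] * 3)

# A permutation visits every row exactly once, in random order.
perm = rng.permutation(len(batch))
shuffled = batch[perm]
print(shuffled)
```

The shuffled view contains the same cells as before; only the drawing order changes.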
[25]:
ax = sc.pl.umap(
    adata_full_new,
    color="batch",
    groups=["PBMC 10k (RNA only)"],
    frameon=False,
    ncols=1,
    title="Reference and query",
    alpha=0.4
)
... storing 'batch' as categorical
_images/totalvi_surgery_pipeline_38_1.png
[26]:
sc.pl.umap(
    adata_full_new,
    color=imputed_proteins_all.columns,
    frameon=False,
    ncols=3,
    vmax="p99"
)
_images/totalvi_surgery_pipeline_39_0.png
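
The vmax="p99" argument caps each color scale at the 99th percentile of the plotted values, so a few extreme cells do not wash out the rest of the map. The sketch below shows the equivalent computation in NumPy on made-up values with a single large outlier.

```python
import numpy as np

# 999 ordinary values plus one extreme outlier.
values = np.arange(1000, dtype=float)
values[-1] = 1e6

# Upper color limit at the 99th percentile instead of the raw maximum.
vmax = np.percentile(values, 99)

# Values above the limit are shown with the top color, as if clipped.
clipped = np.minimum(values, vmax)
print(vmax, clipped.max())
```

With the raw maximum as the limit, almost every value would map to the bottom of the color scale.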