Semi-supervised surgery pipeline with SCANVI

[1]:
import os
os.chdir('../')
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
[2]:
import scanpy as sc
import torch
import scarchest as sca
from scarchest.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown
[3]:
# Set all figure parameters in one call: each set_figure_params call resets the
# arguments it does not receive (e.g. dpi, frameon) to their defaults.
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

Set the relevant anndata.obs keys and training lengths

Here we use the CelSeq2 and SS2 studies as query data and the other three studies as the reference atlas. We strongly suggest using early stopping to avoid overfitting. The best early-stopping criteria are ‘elbo’ for SCVI pretraining and for unlabelled surgery training, and ‘accuracy’ for semi-supervised SCANVI training.

[4]:
condition_key = 'study'
cell_type_key = 'cell_type'
target_conditions = ['Pancreas CelSeq2', 'Pancreas SS2']

vae_epochs = 500
scanvi_epochs = 200
surgery_epochs = 500

early_stopping_kwargs = {
    "early_stopping_metric": "elbo",
    "save_best_state_metric": "elbo",
    "patience": 10,
    "threshold": 0,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}
early_stopping_kwargs_scanvi = {
    "early_stopping_metric": "accuracy",
    "save_best_state_metric": "accuracy",
    "on": "full_dataset",
    "patience": 10,
    "threshold": 0.001,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}
early_stopping_kwargs_surgery = {
    "early_stopping_metric": "elbo",
    "save_best_state_metric": "elbo",
    "on": "full_dataset",
    "patience": 10,
    "threshold": 0.001,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}
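The `patience`, `threshold` and `lr_*` keys interact as in a standard patience-based early-stopping loop with learning-rate reduction on plateaus. The sketch below is a simplified, hypothetical re-implementation for intuition only; the function `simulate_early_stopping` and its exact counter semantics are assumptions, not scvi-tools code. Training stops once the monitored metric has failed to improve by more than `threshold` for `patience` epochs, and the learning rate is multiplied by `lr_factor` after `lr_patience` epochs without improvement.

```python
def simulate_early_stopping(history, patience=10, threshold=0.0, lr=1e-3,
                            lr_patience=8, lr_factor=0.1, mode="min"):
    """Replay a per-epoch metric history under patience-based early stopping.

    mode="min": lower is better (e.g. a loss); mode="max": higher is better
    (e.g. accuracy). Returns (stop_epoch or None if training ran to the end,
    final learning rate, list of epochs at which the learning rate was reduced).
    """
    best = None
    wait = 0     # epochs since the last improvement (early-stopping counter)
    wait_lr = 0  # epochs since the last improvement or LR drop (plateau counter)
    lr_drops = []
    for epoch, value in enumerate(history):
        if best is None:
            improved = True
        elif mode == "min":
            improved = best - value > threshold
        else:
            improved = value - best > threshold
        if improved:
            best, wait, wait_lr = value, 0, 0
        else:
            wait += 1
            wait_lr += 1
        if wait_lr >= lr_patience:
            lr *= lr_factor  # reduce the learning rate on a plateau
            lr_drops.append(epoch)
            wait_lr = 0
        if wait >= patience:
            return epoch, lr, lr_drops
    return None, lr, lr_drops


# A loss that stops improving after epoch 2: LR drop at epoch 10, stop at 12.
print(simulate_early_stopping([10, 9, 8] + [8] * 12))
```

With `lr_patience=8` and `patience=10`, a final learning-rate reduction two epochs before stopping is expected under this scheme, which is consistent with the pretraining log of cell [13] (LR reduced at epoch 125, early stop at epoch 127).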

Download the dataset and split it into reference and query datasets

[5]:
url = 'https://drive.google.com/uc?id=1ehxgfHTsMZXy6YzlFKGJOsBKQ5rrvMnd'
output = 'pancreas.h5ad'
gdown.download(url, output, quiet=False)
Downloading...
From: https://drive.google.com/uc?id=1ehxgfHTsMZXy6YzlFKGJOsBKQ5rrvMnd
To: C:\Users\sergei.rybakov\projects\notebooks\pancreas.h5ad
126MB [00:29, 4.31MB/s]
[5]:
'pancreas.h5ad'
[6]:
adata_all = sc.read('pancreas.h5ad')
[7]:
adata = adata_all.raw.to_adata()
adata = remove_sparsity(adata)
source_adata = adata[~adata.obs[condition_key].isin(target_conditions)].copy()
target_adata = adata[adata.obs[condition_key].isin(target_conditions)].copy()
[8]:
source_adata
[8]:
AnnData object with n_obs × n_vars = 10294 × 1000
    obs: 'batch', 'study', 'cell_type', 'size_factors'
[9]:
target_adata
[9]:
AnnData object with n_obs × n_vars = 5387 × 1000
    obs: 'batch', 'study', 'cell_type', 'size_factors'
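The complementary `isin` masks in cell [7] partition the cells: every cell lands in exactly one of the two objects, which is why 10294 + 5387 = 15681 matches the cell count registered later for the concatenated data. A plain-Python sketch of the same split follows; the helper `split_by_condition` and the study names "Study A" and "Study B" are hypothetical placeholders for the three reference studies.

```python
def split_by_condition(conditions, target_conditions):
    """Return (reference_idx, query_idx) for a list of per-cell condition labels."""
    targets = set(target_conditions)
    reference_idx = [i for i, c in enumerate(conditions) if c not in targets]
    query_idx = [i for i, c in enumerate(conditions) if c in targets]
    return reference_idx, query_idx


# Toy per-cell study labels; "Study A" and "Study B" are hypothetical.
studies = ["Study A", "Pancreas SS2", "Study B", "Pancreas CelSeq2", "Study A"]
ref_idx, query_idx = split_by_condition(studies, ["Pancreas CelSeq2", "Pancreas SS2"])
print(ref_idx, query_idx)  # [0, 2, 4] [1, 3]
```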

Create a SCANVI model and train it on the fully labelled reference dataset

[10]:
sca.dataset.setup_anndata(source_adata, batch_key=condition_key, labels_key=cell_type_key)
INFO      Using batches from adata.obs["study"]
INFO      Using labels from adata.obs["cell_type"]
INFO      Using data from adata.X
INFO      Computing library size prior per batch
INFO      Successfully registered anndata object containing 10294 cells, 1000 vars, 3
          batches, 8 labels, and 0 proteins. Also registered 0 extra categorical covariates
          and 0 extra continuous covariates.
INFO      Please do not further modify adata until model is trained.

The parameters chosen here proved to work best for surgery with SCANVI.

[11]:
vae = sca.models.SCANVI(
    source_adata,
    "Unknown",  # unlabelled category: cells with this label count as unlabelled
    n_layers=2,
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_layer_norm="both",
    use_batch_norm="none",
)
[12]:
print("Labelled Indices: ", len(vae._labeled_indices))
print("Unlabelled Indices: ", len(vae._unlabeled_indices))
Labelled Indices:  10294
Unlabelled Indices:  0
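SCANVI treats a cell as unlabelled when its label equals the unlabelled category passed to the constructor (here "Unknown"). No reference cell carries that label, so all 10294 cells are labelled. The plain-Python sketch below mirrors that bookkeeping; the helper `split_labeled` and the toy labels are hypothetical. It also explains why cell [18] overwrites both index attributes: the query data still carries its cell_type annotations, so without the override those cells would presumably be counted as labelled.

```python
def split_labeled(labels, unlabeled_category="Unknown"):
    """Return (labeled_idx, unlabeled_idx): a cell is unlabelled iff its label
    equals the unlabelled category."""
    labeled_idx = [i for i, lab in enumerate(labels) if lab != unlabeled_category]
    unlabeled_idx = [i for i, lab in enumerate(labels) if lab == unlabeled_category]
    return labeled_idx, unlabeled_idx


# Reference-like case: no cell carries "Unknown", so every cell is labelled.
print(split_labeled(["alpha", "beta", "delta"]))             # ([0, 1, 2], [])
# Mixed case: the third cell is unlabelled.
print(split_labeled(["alpha", "beta", "Unknown", "delta"]))  # ([0, 1, 3], [2])
```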
[13]:
vae.train(
    n_epochs_unsupervised=vae_epochs,
    n_epochs_semisupervised=scanvi_epochs,
    unsupervised_trainer_kwargs=dict(early_stopping_kwargs=early_stopping_kwargs),
    semisupervised_trainer_kwargs=dict(metrics_to_monitor=["elbo", "accuracy"],
                                       early_stopping_kwargs=early_stopping_kwargs_scanvi),
    frequency=1
)
INFO      Training Unsupervised Trainer for 500 epochs.
INFO      Training SemiSupervised Trainer for 200 epochs.
INFO      KL warmup for 400 epochs
Training...:  20%|█████████████▍                                                      | 99/500 [04:03<18:22,  2.75s/it]INFO      Reducing LR on epoch 99.
Training...:  25%|████████████████▊                                                  | 125/500 [05:15<17:12,  2.75s/it]INFO      Reducing LR on epoch 125.
Training...:  25%|█████████████████                                                  | 127/500 [05:20<17:06,  2.75s/it]INFO
          Stopping early: no improvement of more than 0 nats in 10 epochs
INFO      If the early stopping criterion is too strong, please instantiate it with different
          parameters in the train method.
Training...:  25%|█████████████████                                                  | 127/500 [05:23<15:50,  2.55s/it]
INFO      Training is still in warming up phase. If your applications rely on the posterior
          quality, consider training for more epochs or reducing the kl warmup.
INFO      Training time:  214 s. / 500 epochs
INFO      KL warmup phase exceeds overall training phaseIf your applications rely on the
          posterior quality, consider training for more epochs or reducing the kl warmup.
INFO      KL warmup for 400 epochs
Training...:  19%|████████████▉                                                       | 38/200 [05:51<25:02,  9.28s/it]INFO      Reducing LR on epoch 38.
Training...:  20%|█████████████▌                                                      | 40/200 [06:10<24:43,  9.27s/it]INFO
          Stopping early: no improvement of more than 0.001 nats in 10 epochs
INFO      If the early stopping criterion is too strong, please instantiate it with different
          parameters in the train method.
Training...:  20%|█████████████▌                                                      | 40/200 [06:19<25:18,  9.49s/it]
INFO      Training is still in warming up phase. If your applications rely on the posterior
          quality, consider training for more epochs or reducing the kl warmup.
INFO      Training time:  228 s. / 200 epochs
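The logs report a 400-epoch KL warm-up, yet early stopping ends SCVI pretraining at epoch 127 and SCANVI training at epoch 40, hence the "still in warming up phase" warnings. Assuming a linear warm-up schedule (an assumption made for illustration; the trainer's source defines the exact rule), the weight on the KL term when training stopped was well below 1. The helper `kl_weight` is hypothetical.

```python
def kl_weight(epoch, n_epochs_kl_warmup=400):
    """Weight on the KL term under an assumed linear warm-up from 0 to 1."""
    if not n_epochs_kl_warmup:
        return 1.0  # no warm-up: full KL weight from the start
    return min(1.0, epoch / n_epochs_kl_warmup)


# Epochs at which early stopping ended the runs in this notebook's logs.
for name, stop_epoch in [("SCVI pretraining", 127), ("SCANVI training", 40),
                         ("surgery training", 113)]:
    print(f"{name}: KL weight {kl_weight(stop_epoch):.3f}")
```

If posterior quality matters, the logs' own advice applies: train for more epochs or shorten the KL warm-up.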

Create an AnnData object from the latent representation and compute a UMAP

[14]:
reference_latent = sc.AnnData(vae.get_latent_representation())
reference_latent.obs["cell_type"] = source_adata.obs[cell_type_key].tolist()
reference_latent.obs["batch"] = source_adata.obs[condition_key].tolist()
[15]:
sc.pp.neighbors(reference_latent, n_neighbors=8)
sc.tl.leiden(reference_latent)
sc.tl.umap(reference_latent)
sc.pl.umap(reference_latent,
           color=['batch', 'cell_type'],
           frameon=False,
           wspace=0.6,
           )
... storing 'cell_type' as categorical
... storing 'batch' as categorical
_images/scanvi_surgery_pipeline_20_1.png
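`sc.pp.neighbors` builds a k-nearest-neighbour graph in the latent space (`n_neighbors=8` here; later calls use scanpy's default of 15), and both `sc.tl.leiden` and `sc.tl.umap` operate on that graph. Scanpy additionally converts distances into weighted connectivities and may use approximate search on larger data; the sketch below shows only the exact, unweighted core of the idea on toy 2-D points. The helper `knn_lists` is hypothetical and requires Python 3.8+ for `math.dist`.

```python
import math


def knn_lists(points, k):
    """Brute-force k-nearest-neighbour lists under Euclidean distance.

    Returns, for each point, the indices of its k closest other points,
    nearest first (ties broken by index).
    """
    neighbours = []
    for i, p in enumerate(points):
        dists = sorted((math.dist(p, q), j) for j, q in enumerate(points) if j != i)
        neighbours.append([j for _, j in dists[:k]])
    return neighbours


points = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (10.0, 10.0)]
print(knn_lists(points, k=2))  # [[1, 2], [0, 2], [0, 1], [2, 1]]
```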

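One can also compute the accuracy of the learned classifier. Since every reference cell was labelled and used for training, this measures accuracy on the training data.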

[16]:
reference_latent.obs['predictions'] = vae.predict()
print("Acc: {}".format(np.mean(reference_latent.obs.predictions == reference_latent.obs.cell_type)))
Acc: 0.9619195647950263
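The printed value is the mean of an elementwise comparison, i.e. the fraction of cells whose predicted label equals the annotation. A single overall score can hide poorly classified rare cell types, so a per-class breakdown is a useful complement. The plain-Python sketch below uses toy labels; the helper names are hypothetical.

```python
from collections import defaultdict


def accuracy(predicted, observed):
    """Fraction of positions where the predicted label equals the observed one."""
    if len(predicted) != len(observed):
        raise ValueError("predicted and observed must have the same length")
    return sum(p == o for p, o in zip(predicted, observed)) / len(observed)


def per_class_accuracy(predicted, observed):
    """Accuracy restricted to the cells of each observed class (i.e. recall)."""
    correct, total = defaultdict(int), defaultdict(int)
    for p, o in zip(predicted, observed):
        total[o] += 1
        correct[o] += p == o
    return {label: correct[label] / total[label] for label in total}


observed = ["alpha", "beta", "beta", "beta"]
predicted = ["alpha", "alpha", "beta", "beta"]
print(accuracy(predicted, observed))            # 0.75
print(per_class_accuracy(predicted, observed))  # {'alpha': 1.0, 'beta': 0.6666666666666666}
```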

After pretraining, the model can be saved for later use.

[17]:
ref_path = 'ref_model/'
vae.save(ref_path, overwrite=True)

Perform surgery on the reference model and train it on the query dataset without cell type labels

[18]:
model = sca.models.SCANVI.load_query_data(
    target_adata,
    ref_path,
    freeze_dropout=True,
)
# target_adata still carries its cell_type annotations; mark every query cell
# as unlabelled so that surgery training does not use them.
model._unlabeled_indices = np.arange(target_adata.n_obs)
model._labeled_indices = []
print("Labelled Indices: ", len(model._labeled_indices))
print("Unlabelled Indices: ", len(model._unlabeled_indices))
INFO      Using data from adata.X
INFO      Computing library size prior per batch
INFO      Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels']
INFO      Successfully registered anndata object containing 5387 cells, 1000 vars, 5 batches,
          8 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0
          extra continuous covariates.
WARNING   Make sure the registered X field in anndata contains unnormalized count data.
Labelled Indices:  0
Unlabelled Indices:  5387
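The log reports 5 batches after loading the query data, compared with 3 for the reference: the two query studies are added as new conditions. Conceptually, scArches-style architecture surgery adds weights for the new conditions and trains only those, keeping the reference weights frozen. The toy sketch below illustrates that idea on a plain list-of-lists first-layer weight matrix; the function names, the small Gaussian initialisation and the restriction to a single layer are simplifying assumptions, not the scarchest implementation.

```python
import random


def extend_condition_rows(weights, n_new_conditions, seed=0):
    """Append one weight row per new condition and return a trainable mask.

    `weights` holds one row per input feature (genes, then one-hot condition
    columns), each a list of n_hidden floats. Existing rows are copied and
    frozen; only the appended rows are marked trainable.
    """
    rng = random.Random(seed)
    n_hidden = len(weights[0])
    new_rows = [[rng.gauss(0.0, 0.01) for _ in range(n_hidden)]
                for _ in range(n_new_conditions)]
    extended = [row[:] for row in weights] + new_rows
    trainable = [False] * len(weights) + [True] * n_new_conditions
    return extended, trainable


def masked_sgd_step(weights, grads, trainable, lr=0.1):
    """One gradient-descent step that leaves frozen rows untouched."""
    return [
        [w - lr * g for w, g in zip(row, grad_row)] if is_trainable else row[:]
        for row, grad_row, is_trainable in zip(weights, grads, trainable)
    ]


# Toy first layer: 4 genes + 3 reference studies = 7 input rows, 2 hidden units.
reference = [[float(i), float(i)] for i in range(7)]
extended, trainable = extend_condition_rows(reference, n_new_conditions=2)
grads = [[1.0, 1.0] for _ in extended]
updated = masked_sgd_step(extended, grads, trainable)
print(len(updated), updated[:7] == reference)  # 9 True
```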
[19]:
model.train(
    n_epochs_semisupervised=surgery_epochs,
    train_base_model=False,
    semisupervised_trainer_kwargs=dict(metrics_to_monitor=["accuracy", "elbo"],
                                       weight_decay=0,
                                       early_stopping_kwargs=early_stopping_kwargs_surgery
                                      ),
    frequency=1
)
INFO      Training Unsupervised Trainer for 400 epochs.
INFO      Training SemiSupervised Trainer for 500 epochs.
INFO      KL warmup for 400 epochs
Training...:  22%|██████████████▊                                                    | 111/500 [07:08<25:03,  3.87s/it]INFO      Reducing LR on epoch 111.
Training...:  23%|███████████████▏                                                   | 113/500 [07:16<24:55,  3.86s/it]INFO
          Stopping early: no improvement of more than 0.001 nats in 10 epochs
INFO      If the early stopping criterion is too strong, please instantiate it with different
          parameters in the train method.
Training...:  23%|███████████████▏                                                   | 113/500 [07:20<25:08,  3.90s/it]
INFO      Training is still in warming up phase. If your applications rely on the posterior
          quality, consider training for more epochs or reducing the kl warmup.
INFO      Training time:  217 s. / 500 epochs
[20]:
query_latent = sc.AnnData(model.get_latent_representation())
query_latent.obs['cell_type'] = target_adata.obs[cell_type_key].tolist()
query_latent.obs['batch'] = target_adata.obs[condition_key].tolist()
WARNING   Make sure the registered X field in anndata contains unnormalized count data.
[21]:
sc.pp.neighbors(query_latent)
sc.tl.leiden(query_latent)
sc.tl.umap(query_latent)
plt.figure()
sc.pl.umap(
    query_latent,
    color=["batch", "cell_type"],
    frameon=False,
    wspace=0.6,
)
... storing 'cell_type' as categorical
... storing 'batch' as categorical
<Figure size 320x320 with 0 Axes>
_images/scanvi_surgery_pipeline_29_2.png
[22]:
surgery_path = 'surgery_model'
model.save(surgery_path, overwrite=True)

Compute the accuracy of the model's classifier on the query dataset and compare predicted and observed cell types

[23]:
query_latent.obs['predictions'] = model.predict()
print("Acc: {}".format(np.mean(query_latent.obs.predictions == query_latent.obs.cell_type)))
WARNING   Make sure the registered X field in anndata contains unnormalized count data.
Acc: 0.8791535177278633
[24]:
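# Rows: observed cell types; columns: predicted labels (cell counts).
df = query_latent.obs.groupby(["cell_type", "predictions"]).size().unstack(fill_value=0)
# Normalise each column, so every predicted label sums to 1 over observed types.
norm_df = df / df.sum(axis=0)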

plt.figure(figsize=(8, 8))
_ = plt.pcolor(norm_df)
_ = plt.xticks(np.arange(0.5, len(df.columns), 1), df.columns, rotation=90)
_ = plt.yticks(np.arange(0.5, len(df.index), 1), df.index)
plt.xlabel("Predicted")
plt.ylabel("Observed")
[24]:
Text(0, 0.5, 'Observed')
_images/scanvi_surgery_pipeline_33_1.png
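Dividing by `df.sum(axis=0)` normalises each column, so an entry reads as "of the cells predicted as this type, the share that carry this observed type". The plain-Python sketch below computes the same column-normalised table from toy label lists; the helper name is hypothetical.

```python
from collections import Counter


def column_normalised_confusion(observed, predicted):
    """Return {predicted: {observed: fraction}}; each predicted column sums to 1.

    Plays the role of df / df.sum(axis=0): an entry is the share of cells
    assigned a given predicted label that carry a given observed label.
    """
    pair_counts = Counter(zip(observed, predicted))
    column_totals = Counter(predicted)
    observed_labels = sorted(set(observed))
    return {
        p: {o: pair_counts[(o, p)] / column_totals[p] for o in observed_labels}
        for p in sorted(column_totals)
    }


observed = ["alpha", "alpha", "beta", "beta"]
predicted = ["alpha", "beta", "beta", "beta"]
print(column_normalised_confusion(observed, predicted))
```

Normalising rows instead, e.g. `df.div(df.sum(axis=1), axis=0)`, would answer the complementary question: how the cells of each observed type are spread across the predicted labels.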

Get the latent representation of the reference + query dataset and compute a UMAP

[25]:
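# concatenate() writes its own 'batch' column ('0' = reference, '1' = query),
# replacing the original one; the study labels in condition_key are kept.
adata_full = source_adata.concatenate(target_adata)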
full_latent = sc.AnnData(model.get_latent_representation(adata=adata_full))
full_latent.obs['cell_type'] = adata_full.obs[cell_type_key].tolist()
full_latent.obs['batch'] = adata_full.obs[condition_key].tolist()
INFO      Input adata not setup with scvi. attempting to transfer anndata setup
INFO      Using data from adata.X
INFO      Computing library size prior per batch
INFO      Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels']
INFO      Successfully registered anndata object containing 15681 cells, 1000 vars, 5
          batches, 8 labels, and 0 proteins. Also registered 0 extra categorical covariates
          and 0 extra continuous covariates.
[26]:
sc.pp.neighbors(full_latent)
sc.tl.leiden(full_latent)
sc.tl.umap(full_latent)
plt.figure()
sc.pl.umap(
    full_latent,
    color=["batch", "cell_type"],
    frameon=False,
    wspace=0.6,
)
... storing 'cell_type' as categorical
... storing 'batch' as categorical
<Figure size 320x320 with 0 Axes>
_images/scanvi_surgery_pipeline_36_2.png

Comparison of observed and predicted cell types for reference + query dataset

[27]:
full_latent.obs['predictions'] = model.predict(adata=adata_full)
print("Acc: {}".format(np.mean(full_latent.obs.predictions == full_latent.obs.cell_type)))
Acc: 0.933486384796888
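The three printed accuracies can be cross-checked with a little arithmetic. Converting each to a count of correctly classified cells, and assuming the query predictions in cells [23] and [27] are identical, the surgery model classifies 14638 - 4736 = 9902 reference cells correctly, the same number as the reference model in cell [16]. That is consistent with the reference weights having been kept frozen during surgery, although matching counts alone do not prove that the individual predictions are unchanged.

```python
# Accuracies printed in this notebook and the corresponding cell counts.
n_reference, n_query = 10294, 5387
acc_reference_model = 0.9619195647950263  # reference model on reference cells
acc_query = 0.8791535177278633            # surgery model on query cells
acc_full = 0.933486384796888              # surgery model on all cells

correct_reference = round(acc_reference_model * n_reference)
correct_query = round(acc_query * n_query)
correct_full = round(acc_full * (n_reference + n_query))

# The full-data accuracy is the cell-weighted combination of both subsets, so
# the surgery model's correct count on the reference cells is the remainder.
correct_reference_after_surgery = correct_full - correct_query
print(correct_reference, correct_query, correct_full, correct_reference_after_surgery)
# 9902 4736 14638 9902
```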
[28]:
# The neighbour graph and UMAP of full_latent were already computed above,
# so only re-plot, now colouring by the predicted labels.
plt.figure()
sc.pl.umap(
    full_latent,
    color=["predictions", "cell_type"],
    frameon=False,
    wspace=0.6,
)
... storing 'predictions' as categorical
<Figure size 320x320 with 0 Axes>
_images/scanvi_surgery_pipeline_39_2.png