Semi-supervised surgery pipeline with SCANVI¶
[1]:
import os
os.chdir('../')
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
[2]:
import scanpy as sc
import torch
import scarchest as sca
from scarchest.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown
[3]:
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)
Set relevant anndata.obs labels and training length¶
Here we use the CelSeq2 and SS2 studies as query data and the other three studies as the reference atlas. We strongly suggest using early stopping to avoid overfitting. The best early stopping criteria are ‘elbo’ for the SCVI pretraining and for the unlabelled surgery training, and ‘accuracy’ for the semi-supervised SCANVI training.
[4]:
condition_key = 'study'
cell_type_key = 'cell_type'
target_conditions = ['Pancreas CelSeq2', 'Pancreas SS2']
vae_epochs = 500
scanvi_epochs = 200
surgery_epochs = 500
early_stopping_kwargs = {
    "early_stopping_metric": "elbo",
    "save_best_state_metric": "elbo",
    "patience": 10,
    "threshold": 0,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}
early_stopping_kwargs_scanvi = {
    "early_stopping_metric": "accuracy",
    "save_best_state_metric": "accuracy",
    "on": "full_dataset",
    "patience": 10,
    "threshold": 0.001,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}
early_stopping_kwargs_surgery = {
    "early_stopping_metric": "elbo",
    "save_best_state_metric": "elbo",
    "on": "full_dataset",
    "patience": 10,
    "threshold": 0.001,
    "reduce_lr_on_plateau": True,
    "lr_patience": 8,
    "lr_factor": 0.1,
}
Download dataset and split into reference and query datasets¶
[5]:
url = 'https://drive.google.com/uc?id=1ehxgfHTsMZXy6YzlFKGJOsBKQ5rrvMnd'
output = 'pancreas.h5ad'
gdown.download(url, output, quiet=False)
Downloading...
From: https://drive.google.com/uc?id=1ehxgfHTsMZXy6YzlFKGJOsBKQ5rrvMnd
To: C:\Users\sergei.rybakov\projects\notebooks\pancreas.h5ad
126MB [00:29, 4.31MB/s]
[5]:
'pancreas.h5ad'
[6]:
adata_all = sc.read('pancreas.h5ad')
[7]:
adata = adata_all.raw.to_adata()  # use the raw count matrix stored in .raw
adata = remove_sparsity(adata)  # convert sparse .X to a dense array
source_adata = adata[~adata.obs[condition_key].isin(target_conditions)].copy()  # reference studies
target_adata = adata[adata.obs[condition_key].isin(target_conditions)].copy()  # query studies
[8]:
source_adata
[8]:
AnnData object with n_obs × n_vars = 10294 × 1000
obs: 'batch', 'study', 'cell_type', 'size_factors'
[9]:
target_adata
[9]:
AnnData object with n_obs × n_vars = 5387 × 1000
obs: 'batch', 'study', 'cell_type', 'size_factors'
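Before building the model, it is worth checking that every query cell type also occurs in the reference, since the classifier trained below can only predict labels it has seen. This check is not part of the original tutorial; it is a minimal sketch using plain pandas:

# Sketch: any query cell type absent from the reference would be
# unpredictable for the SCANVI classifier trained on the reference.
ref_types = set(source_adata.obs[cell_type_key])
query_types = set(target_adata.obs[cell_type_key])
print("Query cell types missing from reference:", query_types - ref_types)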
Create SCANVI model and train it on fully labelled reference dataset¶
[10]:
sca.dataset.setup_anndata(source_adata, batch_key=condition_key, labels_key=cell_type_key)
INFO Using batches from adata.obs["study"]
INFO Using labels from adata.obs["cell_type"]
INFO Using data from adata.X
INFO Computing library size prior per batch
INFO Successfully registered anndata object containing 10294 cells, 1000 vars, 3
batches, 8 labels, and 0 proteins. Also registered 0 extra categorical covariates
and 0 extra continuous covariates.
INFO Please do not further modify adata until model is trained.
The parameters chosen here proved to work best for surgery with SCANVI.
[11]:
vae = sca.models.SCANVI(
    source_adata,
    "Unknown",
    n_layers=2,
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_layer_norm="both",
    use_batch_norm="none",
)
[12]:
print("Labelled Indices: ", len(vae._labeled_indices))
print("Unlabelled Indices: ", len(vae._unlabeled_indices))
Labelled Indices: 10294
Unlabelled Indices: 0
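All 10294 reference cells are labelled because none of them carry the ‘Unknown’ category passed to the constructor. A quick sanity check (a sketch, not from the original notebook):

# Every reference cell has a real label, so SCANVI treats the whole
# reference as labelled data.
assert (source_adata.obs[cell_type_key] != "Unknown").all()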
[13]:
vae.train(
    n_epochs_unsupervised=vae_epochs,
    n_epochs_semisupervised=scanvi_epochs,
    unsupervised_trainer_kwargs=dict(early_stopping_kwargs=early_stopping_kwargs),
    semisupervised_trainer_kwargs=dict(
        metrics_to_monitor=["elbo", "accuracy"],
        early_stopping_kwargs=early_stopping_kwargs_scanvi,
    ),
    frequency=1,
)
INFO Training Unsupervised Trainer for 500 epochs.
INFO Training SemiSupervised Trainer for 200 epochs.
INFO KL warmup for 400 epochs
Training...: 20%|█████████████▍ | 99/500 [04:03<18:22, 2.75s/it]INFO Reducing LR on epoch 99.
Training...: 25%|████████████████▊ | 125/500 [05:15<17:12, 2.75s/it]INFO Reducing LR on epoch 125.
Training...: 25%|█████████████████ | 127/500 [05:20<17:06, 2.75s/it]INFO
Stopping early: no improvement of more than 0 nats in 10 epochs
INFO If the early stopping criterion is too strong, please instantiate it with different
parameters in the train method.
Training...: 25%|█████████████████ | 127/500 [05:23<15:50, 2.55s/it]
INFO Training is still in warming up phase. If your applications rely on the posterior
quality, consider training for more epochs or reducing the kl warmup.
INFO Training time: 214 s. / 500 epochs
INFO KL warmup phase exceeds overall training phaseIf your applications rely on the
posterior quality, consider training for more epochs or reducing the kl warmup.
INFO KL warmup for 400 epochs
Training...: 19%|████████████▉ | 38/200 [05:51<25:02, 9.28s/it]INFO Reducing LR on epoch 38.
Training...: 20%|█████████████▌ | 40/200 [06:10<24:43, 9.27s/it]INFO
Stopping early: no improvement of more than 0.001 nats in 10 epochs
INFO If the early stopping criterion is too strong, please instantiate it with different
parameters in the train method.
Training...: 20%|█████████████▌ | 40/200 [06:19<25:18, 9.49s/it]
INFO Training is still in warming up phase. If your applications rely on the posterior
quality, consider training for more epochs or reducing the kl warmup.
INFO Training time: 228 s. / 200 epochs
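To inspect the training curves, the scvi-style trainers of this era usually record the monitored metrics per epoch in a history dict. The attribute and key names below are assumptions and may differ between scArches versions; treat this as a sketch:

# Sketch, assuming the last trainer is exposed as `vae.trainer` and its
# `history` maps metric names to lists of per-epoch values (version-dependent).
for key, values in vae.trainer.history.items():
    plt.plot(values, label=key)
plt.xlabel("epoch")
plt.legend()
plt.show()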
Create AnnData object of the latent representation and compute UMAP¶
[14]:
reference_latent = sc.AnnData(vae.get_latent_representation())
reference_latent.obs["cell_type"] = source_adata.obs[cell_type_key].tolist()
reference_latent.obs["batch"] = source_adata.obs[condition_key].tolist()
[15]:
sc.pp.neighbors(reference_latent, n_neighbors=8)
sc.tl.leiden(reference_latent)
sc.tl.umap(reference_latent)
sc.pl.umap(
    reference_latent,
    color=['batch', 'cell_type'],
    frameon=False,
    wspace=0.6,
)
... storing 'cell_type' as categorical
... storing 'batch' as categorical
One can also compute the accuracy of the learned classifier:
[16]:
reference_latent.obs['predictions'] = vae.predict()
print("Acc: {}".format(np.mean(reference_latent.obs.predictions == reference_latent.obs.cell_type)))
Acc: 0.9619195647950263
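The aggregate accuracy can hide weak classes. A per-cell-type breakdown with plain pandas (a sketch, not part of the original notebook) makes this visible:

# Per-class accuracy: fraction of correctly predicted cells within each
# observed cell type of the reference.
correct = reference_latent.obs.predictions == reference_latent.obs.cell_type
print(correct.groupby(reference_latent.obs.cell_type).mean())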
After pretraining, the model can be saved for later use:
[17]:
ref_path = 'ref_model/'
vae.save(ref_path, overwrite=True)
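The saved model can later be restored from this directory. The exact load signature varies between scArches/scvi versions, so the call below is an assumed sketch rather than the tutorial's verified API:

# Sketch (assumed signature): reload the pretrained reference model
# together with the AnnData it was set up on.
vae = sca.models.SCANVI.load(ref_path, adata=source_adata)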
Perform surgery on reference model and train on query dataset without cell type labels¶
[18]:
model = sca.models.SCANVI.load_query_data(
    target_adata,
    ref_path,
    freeze_dropout=True,
)
# treat every query cell as unlabelled during surgery training
model._unlabeled_indices = np.arange(target_adata.n_obs)
model._labeled_indices = []
print("Labelled Indices: ", len(model._labeled_indices))
print("Unlabelled Indices: ", len(model._unlabeled_indices))
INFO Using data from adata.X
INFO Computing library size prior per batch
INFO Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels']
INFO Successfully registered anndata object containing 5387 cells, 1000 vars, 5 batches,
8 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0
extra continuous covariates.
WARNING Make sure the registered X field in anndata contains unnormalized count data.
Labelled Indices: 0
Unlabelled Indices: 5387
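Note that the two query studies are appended as new batch categories on top of the three reference batches, which is why the setup above reports 5 batches. A quick check (a sketch):

# The query contributes exactly the two held-out studies as new conditions.
print(target_adata.obs[condition_key].unique())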
[19]:
model.train(
    n_epochs_semisupervised=surgery_epochs,
    train_base_model=False,
    semisupervised_trainer_kwargs=dict(
        metrics_to_monitor=["accuracy", "elbo"],
        weight_decay=0,
        early_stopping_kwargs=early_stopping_kwargs_surgery,
    ),
    frequency=1,
)
INFO Training Unsupervised Trainer for 400 epochs.
INFO Training SemiSupervised Trainer for 500 epochs.
INFO KL warmup for 400 epochs
Training...: 22%|██████████████▊ | 111/500 [07:08<25:03, 3.87s/it]INFO Reducing LR on epoch 111.
Training...: 23%|███████████████▏ | 113/500 [07:16<24:55, 3.86s/it]INFO
Stopping early: no improvement of more than 0.001 nats in 10 epochs
INFO If the early stopping criterion is too strong, please instantiate it with different
parameters in the train method.
Training...: 23%|███████████████▏ | 113/500 [07:20<25:08, 3.90s/it]
INFO Training is still in warming up phase. If your applications rely on the posterior
quality, consider training for more epochs or reducing the kl warmup.
INFO Training time: 217 s. / 500 epochs
[20]:
query_latent = sc.AnnData(model.get_latent_representation())
query_latent.obs['cell_type'] = target_adata.obs[cell_type_key].tolist()
query_latent.obs['batch'] = target_adata.obs[condition_key].tolist()
WARNING Make sure the registered X field in anndata contains unnormalized count data.
[21]:
sc.pp.neighbors(query_latent)
sc.tl.leiden(query_latent)
sc.tl.umap(query_latent)
plt.figure()
sc.pl.umap(
    query_latent,
    color=["batch", "cell_type"],
    frameon=False,
    wspace=0.6,
)
... storing 'cell_type' as categorical
... storing 'batch' as categorical
<Figure size 320x320 with 0 Axes>
[22]:
surgery_path = 'surgery_model'
model.save(surgery_path, overwrite=True)
Compute accuracy of the model classifier on the query dataset and compare predicted and observed cell types¶
[23]:
query_latent.obs['predictions'] = model.predict()
print("Acc: {}".format(np.mean(query_latent.obs.predictions == query_latent.obs.cell_type)))
WARNING Make sure the registered X field in anndata contains unnormalized count data.
Acc: 0.8791535177278633
[24]:
df = query_latent.obs.groupby(["cell_type", "predictions"]).size().unstack(fill_value=0)
norm_df = df / df.sum(axis=0)  # normalise columns: composition of each predicted class
plt.figure(figsize=(8, 8))
_ = plt.pcolor(norm_df)
_ = plt.xticks(np.arange(0.5, len(df.columns), 1), df.columns, rotation=90)
_ = plt.yticks(np.arange(0.5, len(df.index), 1), df.index)
plt.xlabel("Predicted")
plt.ylabel("Observed")
[24]:
Text(0, 0.5, 'Observed')
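For a text summary of the same comparison, scikit-learn's classification report gives per-class precision and recall. sklearn is not imported in the original notebook, so this is an optional sketch:

# Sketch: per-class precision/recall/F1 for the query predictions.
from sklearn.metrics import classification_report
print(classification_report(query_latent.obs.cell_type, query_latent.obs.predictions))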
Get latent representation of reference + query dataset and compute UMAP¶
[25]:
adata_full = source_adata.concatenate(target_adata)
full_latent = sc.AnnData(model.get_latent_representation(adata=adata_full))
full_latent.obs['cell_type'] = adata_full.obs[cell_type_key].tolist()
full_latent.obs['batch'] = adata_full.obs[condition_key].tolist()
INFO Input adata not setup with scvi. attempting to transfer anndata setup
INFO Using data from adata.X
INFO Computing library size prior per batch
INFO Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels']
INFO Successfully registered anndata object containing 15681 cells, 1000 vars, 5
batches, 8 labels, and 0 proteins. Also registered 0 extra categorical covariates
and 0 extra continuous covariates.
[26]:
sc.pp.neighbors(full_latent)
sc.tl.leiden(full_latent)
sc.tl.umap(full_latent)
plt.figure()
sc.pl.umap(
    full_latent,
    color=["batch", "cell_type"],
    frameon=False,
    wspace=0.6,
)
... storing 'cell_type' as categorical
... storing 'batch' as categorical
<Figure size 320x320 with 0 Axes>
Comparison of observed and predicted cell types for the reference + query dataset¶
[27]:
full_latent.obs['predictions'] = model.predict(adata=adata_full)
print("Acc: {}".format(np.mean(full_latent.obs.predictions == full_latent.obs.cell_type)))
Acc: 0.933486384796888
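Because the combined object mixes reference and query cells, it is informative to split this accuracy by study; `batch` here holds the study labels copied above. A sketch with plain pandas:

# Per-study accuracy over the combined reference + query data.
correct = full_latent.obs.predictions == full_latent.obs.cell_type
print(correct.groupby(full_latent.obs.batch).mean())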
[28]:
sc.pp.neighbors(full_latent)
sc.tl.leiden(full_latent)
sc.tl.umap(full_latent)
plt.figure()
sc.pl.umap(
    full_latent,
    color=["predictions", "cell_type"],
    frameon=False,
    wspace=0.6,
)
... storing 'predictions' as categorical
<Figure size 320x320 with 0 Axes>
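Finally, the latent AnnData objects can be written to disk with standard anndata I/O for downstream analysis. The filename below is arbitrary:

# Persist the joint latent representation for later plotting or analysis.
full_latent.write('full_latent.h5ad')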