Tutorial 1: simulated datasets

  • Import packages

[1]:
import torch
import scipy.sparse as sp
import scanpy as sc
from spatialign import Spatialign
from warnings import filterwarnings
from anndata import AnnData

filterwarnings("ignore")
[2]:
datasets = [
    "lfdb_simulate1.h5ad",
    "lfdb_simulate2.h5ad",
    "lfdb_simulate3.h5ad"
]
[3]:
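def pca_lowrank(data, n_component=50, use_rep=None):
    """Run a randomized low-rank PCA and store the results on `data` in place."""
    data.uns["pca"] = {}
    if use_rep is None:
        # Densify a sparse expression matrix before converting it to a tensor.
        x_tensor = torch.tensor(data.X.toarray()) if sp.issparse(data.X) else torch.tensor(data.X)
    else:
        assert use_rep in data.obsm_keys()
        x_tensor = torch.tensor(data.obsm[use_rep])
    # torch.pca_lowrank centers the input by default and returns (U, S, V).
    u, s, v = torch.pca_lowrank(x_tensor, q=n_component)

    # Variance per component; the ratio is relative to the retained components only.
    explained_variance_ = s.pow(2) / (data.shape[0] - 1)
    total_var = explained_variance_.sum()
    explained_variance_ratio_ = explained_variance_ / total_var
    # Project the (uncentered) data onto the principal directions.
    data.obsm["X_pca"] = torch.matmul(x_tensor, v).numpy()
    data.uns["pca"]["variance"] = explained_variance_.numpy()
    data.uns["pca"]['variance_ratio'] = explained_variance_ratio_.numpy()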
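The variance bookkeeping in `pca_lowrank` can be checked by hand. The sketch below repeats the same arithmetic on plain Python lists; the helper name `explained_variance` and the toy singular values are illustrative assumptions, not part of Spatialign. Note that the ratio is normalized by the sum over the retained components, so it always sums to 1 across those components rather than across the full variance of the data.

```python
def explained_variance(singular_values, n_obs):
    """Turn singular values into per-component variance and variance ratio."""
    # variance_i = s_i^2 / (n_obs - 1), the same formula used in pca_lowrank
    variance = [s ** 2 / (n_obs - 1) for s in singular_values]
    total = sum(variance)
    # The ratio is relative to the retained components only.
    ratio = [v / total for v in variance]
    return variance, ratio


# Toy input: two singular values from a hypothetical 5-observation dataset.
variance, ratio = explained_variance([4.0, 2.0], n_obs=5)
print(variance)  # [4.0, 1.0]
print(ratio)     # [0.8, 0.2]
```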
  • Initialize

    Read the datasets and load the Spatialign model.

[4]:
model = Spatialign(
    *datasets,
    batch_key='batch',
    is_norm_log=True,
    is_scale=False,
    n_neigh=15,
    is_undirected=True,
    latent_dims=100,
    seed=42,
    gpu=0,
    save_path="./",
    is_verbose=False
)
2023-07-21 18:18:18 Found Dataset:
  cell nums: 4226 gene nums: 1999
  cell nums: 3661 gene nums: 1999
  cell nums: 3639 gene nums: 1999
2023-07-21 18:18:20: Spatial coordinates are used to calculate nearest neighbor graphs
2023-07-21 18:18:24 Method: 'init_bank' Running...
  Running time: 0.0 min 1.1580 s
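The log above reports that nearest-neighbor graphs are built from the spatial coordinates. As a rough illustration of the idea behind the `n_neigh` and `is_undirected` arguments, the following brute-force sketch links each point to its k nearest neighbors and optionally adds the reverse edges. It is an assumption about the general technique, not Spatialign's actual implementation; the function `knn_graph` and the toy coordinates are hypothetical.

```python
import math


def knn_graph(coords, k, undirected=True):
    """Return a set of (i, j) edges linking each point to its k nearest neighbors."""
    edges = set()
    for i, p in enumerate(coords):
        # Distances from point i to every other point, nearest first.
        others = sorted(
            (math.dist(p, q), j) for j, q in enumerate(coords) if j != i
        )
        for _, j in others[:k]:
            edges.add((i, j))
            if undirected:
                # Symmetrize: add the reverse edge as well.
                edges.add((j, i))
    return edges


# Three toy spots on a line; the third is far from the other two.
coords = [(0.0, 0.0), (1.0, 0.0), (5.0, 0.0)]
directed = knn_graph(coords, k=1, undirected=False)
undirected = knn_graph(coords, k=1, undirected=True)
print(sorted(directed))    # [(0, 1), (1, 0), (2, 1)]
print(sorted(undirected))  # [(0, 1), (1, 0), (1, 2), (2, 1)]
```

With symmetrization, spot 1 gains spot 2 as a neighbor even though spot 2 is not among its own nearest neighbors.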
  • Visualize the original datasets with UMAP

[5]:
raw_merge = AnnData.concatenate(*model.dataset.data_list)
[6]:
pca_lowrank(raw_merge, n_component=100)
sc.pp.neighbors(raw_merge, random_state=42)
sc.tl.umap(raw_merge, random_state=42)
[7]:
sc.pl.umap(raw_merge, color=["batch", "celltype"])
_images/Tutorial-1-simulate-datasets_10_0.png
  • Train the Spatialign model

    Spatialign aims to align the biological effects across datasets while preserving as much of the biological variance among spots/cells as possible in the latent embedding.

    It is possible to fine-tune the parameters (tau1, tau2, tau3) to optimize the model’s performance.
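One simple way to explore these parameters is a small grid search. The sketch below only enumerates candidate settings; the value lists are hypothetical (only 0.05, 0.01, and 0.1 are taken from this tutorial), and how each setting is scored is left open because the tutorial does not prescribe a metric.

```python
from itertools import product

# Hypothetical candidate values; the first entry of each list is the
# setting used in this tutorial.
tau1_values = [0.05, 0.1]
tau2_values = [0.01, 0.05]
tau3_values = [0.1, 0.5]

# One dict of keyword arguments per combination.
grid = [
    {"tau1": t1, "tau2": t2, "tau3": t3}
    for t1, t2, t3 in product(tau1_values, tau2_values, tau3_values)
]

for params in grid:
    # In a real search, one would build a fresh model here, call
    # model.train(**params), and score the embedding with a chosen metric.
    print(params)
```

Building a fresh model for each setting keeps the runs independent of one another.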

[8]:
model.train(tau1=0.05, tau2=0.01, tau3=0.1)
2023-07-21 18:22:57 Method: 'train' Running...
  2023-07-21 18:23:48 Epoch: 141 Loss: 33.6341 Loss min: 33.4821 EarlyStopping counter: 15 out of 15
  2023-07-21 18:23:48 Model Training Finished!
  2023-07-21 18:23:48 Trained checkpoint file has been saved to ./ckpt
  Running time: 0.0 min 51.2471 s
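The training log ends when the "EarlyStopping counter" reaches 15 out of 15. The sketch below shows how a patience-based early-stopping counter generally works: the counter resets whenever the loss improves on its minimum and training stops once it reaches the patience limit. This is a generic illustration under that assumption, not Spatialign's own code; the class name and the toy losses are hypothetical.

```python
class EarlyStopping:
    """Stop once the loss has failed to improve for `patience` consecutive epochs."""

    def __init__(self, patience=15):
        self.patience = patience
        self.best = float("inf")  # lowest loss seen so far
        self.counter = 0          # epochs since the last improvement

    def step(self, loss):
        """Record one epoch's loss; return True when training should stop."""
        if loss < self.best:
            self.best = loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


# Toy run with a short patience so the stop is easy to follow.
stopper = EarlyStopping(patience=3)
losses = [5.0, 4.0, 4.2, 4.1, 4.3, 3.0]
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)  # 5 4.0
```

In the toy run the final loss of 3.0 is never seen, which is the usual trade-off of a patience limit.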
  • Infer the aligned datasets

    After model training, the learned lower-dimensional representations are saved in adata.obsm['correct'], and the higher-dimensional representations are saved in adata.X.

[9]:
model.alignment()
2023-07-21 18:23:48 Method: 'alignment' Running...
2023-07-21 18:23:50 Batch Alignment Finished!
2023-07-21 18:23:50 Alignment data saved in: ./res
  Running time: 0.0 min 1.7747 s
  • Validate the inferred datasets

[10]:
correct1 = sc.read_h5ad("./res/correct_data0.h5ad")
correct2 = sc.read_h5ad("./res/correct_data1.h5ad")
correct3 = sc.read_h5ad("./res/correct_data2.h5ad")
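With more than a few datasets, writing one read call per file becomes tedious. The sketch below builds the list of expected output paths from the naming pattern seen above (`correct_data0.h5ad`, `correct_data1.h5ad`, ...). The helper `corrected_paths` is hypothetical, and the assumption that the pattern continues for any number of inputs is inferred from this tutorial rather than from the Spatialign documentation.

```python
from pathlib import Path


def corrected_paths(n_datasets, res_dir="./res"):
    """List the expected output files, following the naming seen in this tutorial."""
    return [Path(res_dir) / f"correct_data{i}.h5ad" for i in range(n_datasets)]


paths = corrected_paths(3)
for path in paths:
    # Each file could then be loaded with sc.read_h5ad(path).
    print(path.name)
```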
[11]:
merge_data = correct1.concatenate(correct2, correct3)
[12]:
sc.pp.neighbors(merge_data, use_rep="correct", random_state=42)
sc.tl.umap(merge_data, random_state=42)
[13]:
sc.pl.umap(merge_data, color=["batch", "celltype"])
_images/Tutorial-1-simulate-datasets_19_0.png