
Adding new method: DRVI #61

Open · wants to merge 5 commits into base: main · Changes from 3 commits
81 changes: 81 additions & 0 deletions src/methods/drvi/config.vsh.yaml
@@ -0,0 +1,81 @@
# The API specifies which type of component this is.
# It contains specifications for:
# - The input/output files
# - Common parameters
# - A unit test
__merge__: ../../api/comp_method.yaml

# A unique identifier for your component (required).
# Can contain only lowercase letters or underscores.
name: drvi
# A relatively short label, used when rendering visualisations (required)
label: DRVI
# A one sentence summary of how this method works (required). Used when
# rendering summary tables.
summary: "DRVI is an unsupervised generative model that learns non-linear, interpretable, disentangled latent representations from single-cell count data."
# A multi-line description of how this component works (required). Used
# when rendering reference documentation.
description: |
Disentangled Representation Variational Inference (DRVI) is an unsupervised deep generative model designed for integrating single-cell RNA sequencing (scRNA-seq) data across different batches.
It extends the variational autoencoder (VAE) framework by learning a latent representation that captures biological variation while disentangling and correcting for batch effects.
DRVI conditions both the encoder and decoder on batch covariates, allowing it to explicitly model and mitigate batch-specific variations during training.
By incorporating a KL-divergence regularization term, it balances data reconstruction with latent space structure, resulting in a unified embedding where similar cells cluster together regardless of batch.
references:
doi:
- 10.1101/2024.11.06.622266
# bibtex:
# - |
# @article{foo,
# title={Foo},
# author={Bar},
# journal={Baz},
# year={2024}
# }
links:
# URL to the documentation for this method (required).
documentation: https://drvi.readthedocs.io/latest/index.html
# URL to the code repository for this method (required).
repository: https://github.com/theislab/DRVI



# Metadata for your component
info:
# Which normalisation method this component prefers to use (required).
preferred_normalization: counts

# Component-specific parameters (optional)
# arguments:
# - name: "--n_neighbors"
# type: "integer"
# default: 5
# description: Number of neighbors to use.

# Resources required to run the component
resources:
# The script of your component (required)
- type: python_script
path: script.py
# Additional resources your script needs (optional)
# - type: file
# path: weights.pt

engines:
# Specifications for the Docker image for this component.
- type: docker
image: nvidia/cuda:12.3.2-runtime-ubuntu22.04
# Add custom dependencies here (optional). For more information, see
# https://viash.io/reference/config/engines/docker/#setup .
setup:
- type: python
pypi:
- drvi==0.0.10
packages: numpy<2

runners:
# This platform allows running the component natively
- type: executable
# Allows turning the component into a Nextflow module / pipeline.
- type: nextflow
directives:
label: [midtime,midmem,midcpu]
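
The setup section pins drvi==0.0.10 together with numpy<2, since this drvi release predates NumPy 2. Not part of the PR, but a minimal smoke-test sketch like the one below (assuming only that drvi is importable and that numpy exposes __version__) can confirm the built image resolves these pins before the full script is run:

# Minimal environment smoke test (sketch): run inside the built Docker image.
import numpy as np
import drvi  # same import name as used by script.py

print("numpy:", np.__version__)
assert int(np.__version__.split(".")[0]) < 2, "config pins numpy<2"
print("drvi imported from:", drvi.__file__)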
108 changes: 108 additions & 0 deletions src/methods/drvi/script.py
@@ -0,0 +1,108 @@
import anndata as ad
import scanpy as sc
from drvi.model import DRVI
from drvi.utils.misc import hvg_batch

## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
'input': 'resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad',
'output': 'output.h5ad'
}
meta = {
'name': 'drvi'
}
## VIASH END

print('Reading input files', flush=True)
adata = ad.read_h5ad(par['input'])
# Remove the Villani batch, whose expression values are not raw counts
adata = adata[adata.obs["batch"] != "Villani"].copy()

print('Preprocess data', flush=True)
adata.X = adata.layers["counts"].copy()
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

# Batch aware HVG selection (method is obtained from scIB metrics)
hvg_genes = hvg_batch(adata, batch_key="batch", target_genes=2000, adataOut=False)
adata = adata[:, hvg_genes].copy()

print('Train model with DRVI', flush=True)
# Setup data
DRVI.setup_anndata(
adata,
# DRVI accepts count data by default; change gene_likelihood if you
# provide non-count data.
layer="counts",
# Always provide a list. DRVI can accept multiple covariates.
categorical_covariate_keys=["batch"],
# The "counts" layer holds raw counts, so keep is_count_data=True.
# Set it to False only when providing log-normalized data together with a
# normal likelihood (MSE loss).
is_count_data=True,
)

# construct the model
model = DRVI(
adata,
# Provide categorical covariates keys once again. Refer to advanced usages for more options.
categorical_covariates=["batch"],
n_latent=128,
# For encoder and decoder dims, provide a list of integers.
encoder_dims=[128, 128],
decoder_dims=[128, 128],
)

n_epochs = 400

# train the model
model.train(
max_epochs=n_epochs,
early_stopping=False,
early_stopping_patience=20,
# mps
# accelerator="mps", devices=1,
# cpu
# accelerator="cpu", devices=1,
# gpu: no additional parameter
#
# No need to provide `plan_kwargs` if n_epochs >= 400.
plan_kwargs={
"n_epochs_kl_warmup": n_epochs,
},
)

print("Store outputs", flush=True)
output = ad.AnnData(
obs=adata.obs.copy(),
var=adata.var.copy(),
obsm={
"X_emb": model.get_latent_representation(),
},
uns={
"dataset_id": adata.uns.get("dataset_id", "unknown"),
"normalization_id": adata.uns.get("normalization_id", "unknown"),
"method_id": meta["name"],
},
)

print("Write output AnnData to file", flush=True)
output.write_h5ad(par['output'], compression='gzip')
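
Not part of the diff: as a quick local sanity check of the written file, the output can be re-read and its slots inspected. This is a minimal sketch using only standard anndata calls, assuming the default output path from the par dict above ('output.h5ad'); the expected key names are taken from the script itself:

import anndata as ad

out = ad.read_h5ad("output.h5ad")
assert "X_emb" in out.obsm, "latent embedding missing"
assert out.obsm["X_emb"].shape[0] == out.n_obs
for key in ("dataset_id", "normalization_id", "method_id"):
    assert key in out.uns, f"missing uns key: {key}"
print(out)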