From 893bab812515ca55bdb3fd286167697ca76628e7 Mon Sep 17 00:00:00 2001 From: seohyonkim Date: Mon, 2 Jun 2025 15:30:06 +0200 Subject: [PATCH 1/5] script for drvi --- src/methods/drvi/config.vsh.yaml | 79 ++++++++++++++++++++++ src/methods/drvi/script.py | 108 +++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 src/methods/drvi/config.vsh.yaml create mode 100644 src/methods/drvi/script.py diff --git a/src/methods/drvi/config.vsh.yaml b/src/methods/drvi/config.vsh.yaml new file mode 100644 index 00000000..e3cb78aa --- /dev/null +++ b/src/methods/drvi/config.vsh.yaml @@ -0,0 +1,79 @@ +# The API specifies which type of component this is. +# It contains specifications for: +# - The input/output files +# - Common parameters +# - A unit test +__merge__: ../../api/comp_method.yaml + +# A unique identifier for your component (required). +# Can contain only lowercase letters or underscores. +name: drvi +# A relatively short label, used when rendering visualisations (required) +label: DRVI +# A one sentence summary of how this method works (required). Used when +# rendering summary tables. +summary: "DRVI is an unsupervised generative model capable of learning non-linear interpretable disentangled latent representations from single-cell count data." +# A multi-line description of how this component works (required). Used +# when rendering reference documentation. +description: | + Disentangled Representation Variational Inference (DRVI) is an unsupervised deep generative model designed for integrating single-cell RNA sequencing (scRNA-seq) data across different batches. + It extends the variational autoencoder (VAE) framework by learning a latent representation that captures biological variation while disentangling and correcting for batch effects. + DRVI conditions both the encoder and decoder on batch covariates, allowing it to explicitly model and mitigate batch-specific variations during training. 
+ By incorporating a KL-divergence regularization term, it balances data reconstruction with latent space structure, resulting in a unified embedding where similar cells cluster together regardless of batch. +references: + doi: + - 10.1101/2024.11.06.622266 +# bibtex: +# - | +# @article{foo, +# title={Foo}, +# author={Bar}, +# journal={Baz}, +# year={2024} +# } +links: + # URL to the documentation for this method (required). + documentation: https://drvi.readthedocs.io/latest/index.html + # URL to the code repository for this method (required). + repository: https://github.com/theislab/DRVI?tab=readme-ov-file + + + +# Metadata for your component +info: + # Which normalisation method this component prefers to use (required). + preferred_normalization: counts + +# Component-specific parameters (optional) +# arguments: +# - name: "--n_neighbors" +# type: "integer" +# default: 5 +# description: Number of neighbors to use. + +# Resources required to run the component +resources: + # The script of your component (required) + - type: python_script + path: script.py + # Additional resources your script needs (optional) + # - type: file + # path: weights.pt + +engines: + # Specifications for the Docker image for this component. + - type: docker + image: openproblems/base_python:1.0.0 + # Add custom dependencies here (optional). For more information, see + # https://viash.io/reference/config/engines/docker/#setup . + # setup: + # - type: python + # packages: numpy<2 + +runners: + # This platform allows running the component natively + - type: executable + # Allows turning the component into a Nextflow module / pipeline. 
+ - type: nextflow + directives: + label: [midtime,midmem,midcpu] diff --git a/src/methods/drvi/script.py b/src/methods/drvi/script.py new file mode 100644 index 00000000..068c426e --- /dev/null +++ b/src/methods/drvi/script.py @@ -0,0 +1,108 @@ +import anndata as ad +import scanpy as sc +import drvi +from drvi.model import DRVI +from drvi.utils.misc import hvg_batch +import pandas as pd + +## VIASH START +# Note: this section is auto-generated by viash at runtime. To edit it, make changes +# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. +par = { + 'input': 'resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad', + 'output': 'output.h5ad' +} +meta = { + 'name': 'drvi' +} +## VIASH END + +print('Reading input files', flush=True) +adata = ad.read_h5ad(par['input']) +# Remove dataset with non-count values +adata = adata[adata.obs["batch"] != "Villani"].copy() + +print('Preprocess data', flush=True) +adata.X = adata.layers["counts"].copy() +sc.pp.normalize_total(adata) +sc.pp.log1p(adata) +adata + +sc.pp.pca(adata) +sc.pp.neighbors(adata) +sc.tl.umap(adata) +adata + +# Batch aware HVG selection (method is obtained from scIB metrics) +hvg_genes = hvg_batch(adata, batch_key="batch", target_genes=2000, adataOut=False) +adata = adata[:, hvg_genes].copy() +adata + +print('Train model with drVI', flush=True) +# Setup data +DRVI.setup_anndata( + adata, + # DRVI accepts count data by default. + # Do not forget to change gene_likelihood if you provide a non-count data. + layer="counts", + # Always provide a list. DRVI can accept multiple covariates. + categorical_covariate_keys=["batch"], + # DRVI accepts count data by default. + # Set to false if you provide log-normalized data and use normal distribution (mse loss). + is_count_data=False, +) + +# construct the model +model = DRVI( + adata, + # Provide categorical covariates keys once again. Refer to advanced usages for more options. 
+ categorical_covariates=["batch"], + n_latent=128, + # For encoder and decoder dims, provide a list of integers. + encoder_dims=[128, 128], + decoder_dims=[128, 128], +) +model + +n_epochs = 400 + +# train the model +model.train( + max_epochs=n_epochs, + early_stopping=False, + early_stopping_patience=20, + # mps + # accelerator="mps", devices=1, + # cpu + # accelerator="cpu", devices=1, + # gpu: no additional parameter + # + # No need to provide `plan_kwargs` if n_epochs >= 400. + plan_kwargs={ + "n_epochs_kl_warmup": n_epochs, + }, +) + +embed = ad.AnnData(model.get_latent_representation(), obs=adata.obs) +sc.pp.subsample(embed, fraction=1.0) # Shuffling for better visualization of overlapping colors + +sc.pp.neighbors(embed, n_neighbors=10, use_rep="X", n_pcs=embed.X.shape[1]) +sc.tl.umap(embed, spread=1.0, min_dist=0.5, random_state=123) +sc.pp.pca(embed) + +print("Store outputs", flush=True) +output = ad.AnnData( + obs=adata.obs.copy(), + var=adata.var.copy(), + obsm={ + "X_emb": model.get_latent_representation(), + }, + uns={ + "dataset_id": adata.uns.get("dataset_id", "unknown"), + "normalization_id": adata.uns.get("normalization_id", "unknown"), + "method_id": meta["name"], + }, +) + +print("Write output AnnData to file", flush=True) +output.write_h5ad(par['output'], compression='gzip') From 54c050de1b52a41ab9ba2256567382c5ced51dd9 Mon Sep 17 00:00:00 2001 From: seohyonkim Date: Mon, 2 Jun 2025 16:00:50 +0200 Subject: [PATCH 2/5] add drvi to dependencies --- src/methods/drvi/config.vsh.yaml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/methods/drvi/config.vsh.yaml b/src/methods/drvi/config.vsh.yaml index e3cb78aa..c6355adc 100644 --- a/src/methods/drvi/config.vsh.yaml +++ b/src/methods/drvi/config.vsh.yaml @@ -66,9 +66,11 @@ engines: image: openproblems/base_python:1.0.0 # Add custom dependencies here (optional). For more information, see # https://viash.io/reference/config/engines/docker/#setup . 
- # setup: - # - type: python - # packages: numpy<2 + setup: + - type: python + pypi: + - drvi==0.0.10 + packages: numpy<2 runners: # This platform allows running the component natively From 26bf966fe49e9a573082a3850a66bca5e8c13b90 Mon Sep 17 00:00:00 2001 From: seohyonkim Date: Mon, 2 Jun 2025 16:04:40 +0200 Subject: [PATCH 3/5] add nvidia image --- src/methods/drvi/config.vsh.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/methods/drvi/config.vsh.yaml b/src/methods/drvi/config.vsh.yaml index c6355adc..043d255f 100644 --- a/src/methods/drvi/config.vsh.yaml +++ b/src/methods/drvi/config.vsh.yaml @@ -63,7 +63,7 @@ resources: engines: # Specifications for the Docker image for this component. - type: docker - image: openproblems/base_python:1.0.0 + image: nvidia/cuda:12.3.2-runtime-ubuntu22.04 # Add custom dependencies here (optional). For more information, see # https://viash.io/reference/config/engines/docker/#setup . setup: From e62e6795795a00c205b7350a760a2d75a584fc9d Mon Sep 17 00:00:00 2001 From: seohyonkim Date: Mon, 2 Jun 2025 16:21:20 +0200 Subject: [PATCH 4/5] changes after feedback --- src/methods/drvi/config.vsh.yaml | 2 +- src/methods/drvi/script.py | 43 ++++++++++++++++---------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/methods/drvi/config.vsh.yaml b/src/methods/drvi/config.vsh.yaml index 043d255f..cb342adb 100644 --- a/src/methods/drvi/config.vsh.yaml +++ b/src/methods/drvi/config.vsh.yaml @@ -63,7 +63,7 @@ resources: engines: # Specifications for the Docker image for this component. - type: docker - image: nvidia/cuda:12.3.2-runtime-ubuntu22.04 + image: openproblems/base_pytorch_nvidia:1.0.0 # Add custom dependencies here (optional). For more information, see # https://viash.io/reference/config/engines/docker/#setup . 
setup: diff --git a/src/methods/drvi/script.py b/src/methods/drvi/script.py index 068c426e..3ca42f55 100644 --- a/src/methods/drvi/script.py +++ b/src/methods/drvi/script.py @@ -4,6 +4,8 @@ from drvi.model import DRVI from drvi.utils.misc import hvg_batch import pandas as pd +import numpy as np +import warnings ## VIASH START # Note: this section is auto-generated by viash at runtime. To edit it, make changes @@ -18,25 +20,27 @@ ## VIASH END print('Reading input files', flush=True) -adata = ad.read_h5ad(par['input']) -# Remove dataset with non-count values -adata = adata[adata.obs["batch"] != "Villani"].copy() - -print('Preprocess data', flush=True) -adata.X = adata.layers["counts"].copy() -sc.pp.normalize_total(adata) -sc.pp.log1p(adata) -adata -sc.pp.pca(adata) -sc.pp.neighbors(adata) -sc.tl.umap(adata) -adata +adata = read_anndata( + par['input'], + X='layers/counts', + obs='obs', + var='var', + uns='uns' +) +# Remove dataset with non-count values +counts = adata.layers["counts"] +if not np.allclose(counts, np.round(counts)): + warnings.warn( + "Non-integer values detected in 'counts' layer. " + "DRVI expects count data. Rounding to nearest integers as a workaround." 
+ ) + adata.layers["counts"] = np.round(counts).astype(int) -# Batch aware HVG selection (method is obtained from scIB metrics) -hvg_genes = hvg_batch(adata, batch_key="batch", target_genes=2000, adataOut=False) -adata = adata[:, hvg_genes].copy() -adata +if par["n_hvg"]: + print(f"Select top {par['n_hvg']} high variable genes", flush=True) + idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][:par["n_hvg"]] + adata = adata[:, idx].copy() print('Train model with drVI', flush=True) # Setup data @@ -84,11 +88,6 @@ ) embed = ad.AnnData(model.get_latent_representation(), obs=adata.obs) -sc.pp.subsample(embed, fraction=1.0) # Shuffling for better visualization of overlapping colors - -sc.pp.neighbors(embed, n_neighbors=10, use_rep="X", n_pcs=embed.X.shape[1]) -sc.tl.umap(embed, spread=1.0, min_dist=0.5, random_state=123) -sc.pp.pca(embed) print("Store outputs", flush=True) output = ad.AnnData( From eacf2adf56b3516b3e4bec05b08f99c1ae1f2bda Mon Sep 17 00:00:00 2001 From: seohyonkim Date: Mon, 2 Jun 2025 18:35:37 +0200 Subject: [PATCH 5/5] working DRVI method --- src/methods/drvi/config.vsh.yaml | 24 +++++++++++++------- src/methods/drvi/script.py | 38 ++++++++++++++++++++------------ 2 files changed, 40 insertions(+), 22 deletions(-) diff --git a/src/methods/drvi/config.vsh.yaml b/src/methods/drvi/config.vsh.yaml index cb342adb..4233224e 100644 --- a/src/methods/drvi/config.vsh.yaml +++ b/src/methods/drvi/config.vsh.yaml @@ -45,17 +45,23 @@ info: preferred_normalization: counts # Component-specific parameters (optional) -# arguments: -# - name: "--n_neighbors" -# type: "integer" -# default: 5 -# description: Number of neighbors to use. +arguments: + - name: --n_hvg + type: integer + default: 2000 + description: Number of highly variable genes to use. 
+ - name: --n_epochs + type: integer + default: 400 + description: Number of epochs # Resources required to run the component resources: # The script of your component (required) - type: python_script path: script.py + - path: /src/utils/read_anndata_partial.py + # Additional resources your script needs (optional) # - type: file # path: weights.pt @@ -69,8 +75,10 @@ engines: setup: - type: python pypi: - - drvi==0.0.10 - packages: numpy<2 + - drvi-py==0.1.7 + - torch==2.3.0 + - torchvision==0.18.0 + # packages: runners: # This platform allows running the component natively @@ -78,4 +86,4 @@ runners: # Allows turning the component into a Nextflow module / pipeline. - type: nextflow directives: - label: [midtime,midmem,midcpu] + label: [midtime,midmem,lowcpu,gpu] diff --git a/src/methods/drvi/script.py b/src/methods/drvi/script.py index 3ca42f55..f7f2f1ea 100644 --- a/src/methods/drvi/script.py +++ b/src/methods/drvi/script.py @@ -6,21 +6,27 @@ import pandas as pd import numpy as np import warnings +import sys +import scipy.sparse ## VIASH START # Note: this section is auto-generated by viash at runtime. To edit it, make changes # in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. par = { 'input': 'resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad', - 'output': 'output.h5ad' + 'output': 'output.h5ad', + 'n_hvg': 2000, + 'n_epochs': 400 } meta = { 'name': 'drvi' } ## VIASH END -print('Reading input files', flush=True) +sys.path.append(meta["resources_dir"]) +from read_anndata_partial import read_anndata +print('Reading input files', flush=True) adata = read_anndata( par['input'], X='layers/counts', @@ -29,20 +35,26 @@ uns='uns' ) # Remove dataset with non-count values -counts = adata.layers["counts"] -if not np.allclose(counts, np.round(counts)): - warnings.warn( - "Non-integer values detected in 'counts' layer. " - "DRVI expects count data. Rounding to nearest integers as a workaround." 
- ) - adata.layers["counts"] = np.round(counts).astype(int) +counts = adata.X +import scipy.sparse + +if scipy.sparse.issparse(counts): + counts_dense = counts.toarray() +else: + counts_dense = counts + +if not np.allclose(counts_dense, np.round(counts_dense)): + warnings.warn("Non-integer values detected. Rounding to nearest integer.") + adata.X = np.round(counts_dense).astype(int) + +adata.layers["counts"] = adata.X.copy() if par["n_hvg"]: print(f"Select top {par['n_hvg']} high variable genes", flush=True) idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][:par["n_hvg"]] adata = adata[:, idx].copy() -print('Train model with drVI', flush=True) +print('Train model with DRVI', flush=True) # Setup data DRVI.setup_anndata( adata, @@ -68,11 +80,9 @@ ) model -n_epochs = 400 - # train the model model.train( - max_epochs=n_epochs, + max_epochs=par["n_epochs"], early_stopping=False, early_stopping_patience=20, # mps @@ -83,7 +93,7 @@ # # No need to provide `plan_kwargs` if n_epochs >= 400. plan_kwargs={ - "n_epochs_kl_warmup": n_epochs, + "n_epochs_kl_warmup": par["n_epochs"], }, )