From a49d642a2dafd622b258eb806049a9fa2a61e43c Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Mon, 4 Nov 2024 09:17:52 +0100 Subject: [PATCH 01/14] Add scPRINT component files --- src/methods/scprint/config.vsh.yaml | 57 +++++++++++++++++++++++++++++ src/methods/scprint/script.py | 31 ++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 src/methods/scprint/config.vsh.yaml create mode 100644 src/methods/scprint/script.py diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml new file mode 100644 index 00000000..a0befd51 --- /dev/null +++ b/src/methods/scprint/config.vsh.yaml @@ -0,0 +1,57 @@ +__merge__: ../../api/comp_method.yaml + +name: scprint +label: scPRINT +summary: scPRINT is a large transformer model built for the inference of gene networks +description: | + scPRINT is a large transformer model built for the inference of gene networks + (connections between genes explaining the cell's expression profile) from + scRNAseq data. + + It uses novel encoding and decoding of the cell expression profile and new + pre-training methodologies to learn a cell model. + + scPRINT can be used to perform the following analyses: + + - expression denoising: increase the resolution of your scRNAseq data + - cell embedding: generate a low-dimensional representation of your dataset + - label prediction: predict the cell type, disease, sequencer, sex, and + ethnicity of your cells + - gene network inference: generate a gene network from any cell or cell + cluster in your scRNAseq dataset + +references: + doi: + - 10.1101/2024.07.29.605556 + +links: + documentation: https://cantinilab.github.io/scPRINT/ + repository: https://github.com/cantinilab/scPRINT + +info: + preferred_normalization: counts + +# Component-specific parameters (optional) +# arguments: +# - name: "--n_neighbors" +# type: "integer" +# default: 5 +# description: Number of neighbors to use. + +resources: + - type: python_script + path: script.py + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1.0.0 + setup: + - type: python + pypi: + - scPRINT + +runners: + - type: executable + - type: nextflow + directives: + label: [midtime, midmem, midcpu, gpu] diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py new file mode 100644 index 00000000..863ead90 --- /dev/null +++ b/src/methods/scprint/script.py @@ -0,0 +1,31 @@ +import anndata as ad + +## VIASH START +# Note: this section is auto-generated by viash at runtime. To edit it, make changes +# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. +par = { + 'input': 'resources_test/.../input.h5ad', + 'output': 'output.h5ad' +} +meta = { + 'name': 'scprint' +} +## VIASH END + +print('Reading input files', flush=True) +input = ad.read_h5ad(par['input']) + +print('Preprocess data', flush=True) +# ... preprocessing ... + +print('Train model', flush=True) +# ... train model ... + +print('Generate predictions', flush=True) +# ... generate predictions ... + +print("Write output AnnData to file", flush=True) +output = ad.AnnData( + +) +output.write_h5ad(par['output'], compression='gzip') From d5b31d6f561c0fe55f77b266c8d893a62de11eb3 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Mon, 4 Nov 2024 12:08:25 +0100 Subject: [PATCH 02/14] Load and preprocess data for scPRINT --- src/methods/scprint/config.vsh.yaml | 11 ++++++++ src/methods/scprint/script.py | 39 ++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index a0befd51..1f805fa7 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -41,6 +41,7 @@ info: resources: - type: python_script path: script.py + - path: /src/utils/read_anndata_partial.py engines: - type: docker @@ -48,7 +49,17 @@ engines: setup: - type: python pypi: + - supabase==2.2.1 + - "gotrue>=2.1.0,<2.9.0" - scPRINT + - type: docker + run: lamin init --storage ./main --name main --schema bionty + - type: python + script: import bionty as bt; bt.core.sync_all_sources_to_latest() + - type: docker + run: lamin load main + - type: python + script: from scdataloader.utils import populate_my_ontology; populate_my_ontology() runners: - type: executable diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 863ead90..20734acc 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -1,4 +1,6 @@ import anndata as ad +from scdataloader import Preprocessor +import sys ## VIASH START # Note: this section is auto-generated by viash at runtime. To edit it, make changes @@ -12,11 +14,40 @@ } ## VIASH END -print('Reading input files', flush=True) -input = ad.read_h5ad(par['input']) +sys.path.append(meta["resources_dir"]) +from read_anndata_partial import read_anndata -print('Preprocess data', flush=True) -# ... preprocessing ... +print(">>> Reading input data...", flush=True) +input = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns") + +print(">>> Setting ontology term IDs...", flush=True) +# For now, set all ontology term IDs to 'unknown' but these could be used for +# cellxgene datasets that have this information +print("NOTE: All ontology term IDs except organism are set to 'unknown'", flush=True) +if input.uns["dataset_organism"] == "homo_sapiens": + input.obs["organism_ontology_term_id"] = "NCBITaxon:9606" +elif input.uns["dataset_organism"] == "mus_musculus": + input.obs["organism_ontology_term_id"] = "NCBITaxon:10090" +else: + raise ValueError(f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'") +input.obs["self_reported_ethnicity_ontology_term_id"] = "unknown" +input.obs["disease_ontology_term_id"] = "unknown" +input.obs["cell_type_ontology_term_id"] = "unknown" +input.obs["development_stage_ontology_term_id"] = "unknown" +input.obs["tissue_ontology_term_id"] = "unknown" +input.obs["assay_ontology_term_id"] = "unknown" +input.obs["sex_ontology_term_id"] = "unknown" + +print('\n>>> Preprocessing data...', flush=True) +preprocessor = Preprocessor( + # Lower this threshold for test datasets + min_valid_genes_id = 1000 if input.n_vars < 2000 else 10000, + # Turn off cell filtering to return results for all cells + filter_cell_by_counts = False, + min_nnz_genes = False, + do_postp=False +) +processed = preprocessor(input) print('Train model', flush=True) # ... train model ... From c9e8a7494e57c7e76cde3dd8fc2b4c317879eed6 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Tue, 5 Nov 2024 07:52:30 +0100 Subject: [PATCH 03/14] Try running model... --- src/methods/scprint/config.vsh.yaml | 3 ++- src/methods/scprint/script.py | 40 ++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 1f805fa7..a2db90b8 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -48,9 +48,10 @@ engines: image: openproblems/base_pytorch_nvidia:1.0.0 setup: - type: python - pypi: + pip: - supabase==2.2.1 - "gotrue>=2.1.0,<2.9.0" + - huggingface_hub - scPRINT - type: docker run: lamin init --storage ./main --name main --schema bionty diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 20734acc..98f6c122 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -1,6 +1,9 @@ import anndata as ad from scdataloader import Preprocessor import sys +from huggingface_hub import hf_hub_download +from scprint.tasks import Embedder +from scprint import scPrint ## VIASH START # Note: this section is auto-generated by viash at runtime. To edit it, make changes @@ -20,7 +23,7 @@ print(">>> Reading input data...", flush=True) input = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns") -print(">>> Setting ontology term IDs...", flush=True) +print("\n>>> Setting ontology term IDs...", flush=True) # For now, set all ontology term IDs to 'unknown' but these could be used for # cellxgene datasets that have this information print("NOTE: All ontology term IDs except organism are set to 'unknown'", flush=True) @@ -49,6 +52,41 @@ ) processed = preprocessor(input) +print('\n>>> Downloading model...', flush=True) +model = "small" # TODO: Add other models +model_checkpoint_file = hf_hub_download( + repo_id="jkobject/scPRINT", + filename=f"{model}.ckpt" +) +print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) + + +model = scPrint.load_from_checkpoint( + model_checkpoint_file, + transformer = "normal", # TODO: Don't use this for GPUs with flashattention + precpt_gene_emb = None +) + +print('\n>>> Embedding data...', flush=True) +import torch +embedder = Embedder( + how="random expr", + max_len=4000, + add_zero_genes=0, + num_workers=8, # TODO: Detect and set number of workers + doclass = False, + doplot = False, + devices = None, + precision = "32", # TODO: Use float16 for GPUs + dtype = torch.float32, +) +print(embedder.precision) +print(embedder.dtype) + +embedded, metrics = embedder(model, processed, cache=False) +print(embedded) +print(metrics) + print('Train model', flush=True) # ... train model ... From feab72b2947599f42bbbdfacc396c30d31e2ab5d Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Fri, 22 Nov 2024 13:31:52 +0100 Subject: [PATCH 04/14] Adjust scPRINT installation --- src/methods/scprint/config.vsh.yaml | 6 ++---- src/methods/scprint/script.py | 5 ++++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index a2db90b8..8b295194 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -49,16 +49,14 @@ engines: setup: - type: python pip: - - supabase==2.2.1 - - "gotrue>=2.1.0,<2.9.0" - huggingface_hub - - scPRINT + - scprint - type: docker run: lamin init --storage ./main --name main --schema bionty - type: python script: import bionty as bt; bt.core.sync_all_sources_to_latest() - type: docker - run: lamin load main + run: lamin load anonymous/main - type: python script: from scdataloader.utils import populate_my_ontology; populate_my_ontology() diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 98f6c122..e7dc4ae0 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -4,6 +4,7 @@ from huggingface_hub import hf_hub_download from scprint.tasks import Embedder from scprint import scPrint +import scprint ## VIASH START # Note: this section is auto-generated by viash at runtime. To edit it, make changes @@ -20,7 +21,9 @@ sys.path.append(meta["resources_dir"]) from read_anndata_partial import read_anndata -print(">>> Reading input data...", flush=True) +print(f"====== scPRINT version {scprint.__version__} ======", flush=True) + +print("\n>>> Reading input data...", flush=True) input = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns") print("\n>>> Setting ontology term IDs...", flush=True) From 195a34ec72860795c567040f4c8ccc3c221054f6 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Mon, 25 Nov 2024 13:54:47 +0100 Subject: [PATCH 05/14] Embed and save scPRINT output --- src/methods/scprint/script.py | 59 +++++++++++++++-------------------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index e7dc4ae0..c68d4b67 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -5,6 +5,7 @@ from scprint.tasks import Embedder from scprint import scPrint import scprint +from torch import float32 as torch_float32 ## VIASH START # Note: this section is auto-generated by viash at runtime. To edit it, make changes @@ -25,24 +26,13 @@ print("\n>>> Reading input data...", flush=True) input = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns") - -print("\n>>> Setting ontology term IDs...", flush=True) -# For now, set all ontology term IDs to 'unknown' but these could be used for -# cellxgene datasets that have this information -print("NOTE: All ontology term IDs except organism are set to 'unknown'", flush=True) if input.uns["dataset_organism"] == "homo_sapiens": input.obs["organism_ontology_term_id"] = "NCBITaxon:9606" elif input.uns["dataset_organism"] == "mus_musculus": input.obs["organism_ontology_term_id"] = "NCBITaxon:10090" else: raise ValueError(f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'") -input.obs["self_reported_ethnicity_ontology_term_id"] = "unknown" -input.obs["disease_ontology_term_id"] = "unknown" -input.obs["cell_type_ontology_term_id"] = "unknown" -input.obs["development_stage_ontology_term_id"] = "unknown" -input.obs["tissue_ontology_term_id"] = "unknown" -input.obs["assay_ontology_term_id"] = "unknown" -input.obs["sex_ontology_term_id"] = "unknown" +adata = input.copy() print('\n>>> Preprocessing data...', flush=True) preprocessor = Preprocessor( @@ -51,9 +41,11 @@ # Turn off cell filtering to return results for all cells filter_cell_by_counts = False, min_nnz_genes = False, - do_postp=False + do_postp=False, + # Skip ontology checks + skip_validate=True ) -processed = preprocessor(input) +adata = preprocessor(adata) print('\n>>> Downloading model...', flush=True) model = "small" # TODO: Add other models @@ -62,8 +54,6 @@ filename=f"{model}.ckpt" ) print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) - - model = scPrint.load_from_checkpoint( model_checkpoint_file, transformer = "normal", # TODO: Don't use this for GPUs with flashattention @@ -71,7 +61,6 @@ ) print('\n>>> Embedding data...', flush=True) -import torch embedder = Embedder( how="random expr", max_len=4000, @@ -79,25 +68,27 @@ num_workers=8, # TODO: Detect and set number of workers doclass = False, doplot = False, - devices = None, precision = "32", # TODO: Use float16 for GPUs - dtype = torch.float32, + dtype = torch_float32, ) -print(embedder.precision) -print(embedder.dtype) - -embedded, metrics = embedder(model, processed, cache=False) -print(embedded) -print(metrics) - -print('Train model', flush=True) -# ... train model ... +embedded, _ = embedder(model, adata, cache=False) -print('Generate predictions', flush=True) -# ... generate predictions ... - -print("Write output AnnData to file", flush=True) +print("\n>>> Storing output...", flush=True) output = ad.AnnData( - + obs=input.obs[[]], + var=input.var[[]], + obsm={ + "X_emb": embedded.obsm["scprint"], + }, + uns={ + "dataset_id": input.uns["dataset_id"], + "normalization_id": input.uns["normalization_id"], + "method_id": meta["name"], + }, ) -output.write_h5ad(par['output'], compression='gzip') +print(output) + +print("\n>>> Writing output AnnData to file...", flush=True) +output.write_h5ad(par["output"], compression="gzip") + +print("\n>>> Done!", flush=True) From 8eaea116372e55087dbb6e48b0d34301d68c27d7 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Mon, 25 Nov 2024 14:03:44 +0100 Subject: [PATCH 06/14] Detect available cores --- src/methods/scprint/script.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index c68d4b67..e86f41cb 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -6,6 +6,7 @@ from scprint import scPrint import scprint from torch import float32 as torch_float32 +import os ## VIASH START # Note: this section is auto-generated by viash at runtime. To edit it, make changes @@ -61,11 +62,13 @@ ) print('\n>>> Embedding data...', flush=True) +n_cores_available = len(os.sched_getaffinity(0)) +print(f"Using {n_cores_available} cores") embedder = Embedder( how="random expr", max_len=4000, add_zero_genes=0, - num_workers=8, # TODO: Detect and set number of workers + num_workers=n_cores_available, doclass = False, doplot = False, precision = "32", # TODO: Use float16 for GPUs From d271396bc68b40d2272c7499f2109cfa5e5a32ce Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Mon, 25 Nov 2024 14:16:13 +0100 Subject: [PATCH 07/14] Adjust arguments if GPU available --- src/methods/scprint/script.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index e86f41cb..9f42c217 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -5,7 +5,7 @@ from scprint.tasks import Embedder from scprint import scPrint import scprint -from torch import float32 as torch_float32 +import torch import os ## VIASH START @@ -57,13 +57,21 @@ print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) model = scPrint.load_from_checkpoint( model_checkpoint_file, - transformer = "normal", # TODO: Don't use this for GPUs with flashattention + transformer = "normal", # Don't use this for GPUs with flashattention precpt_gene_emb = None ) print('\n>>> Embedding data...', flush=True) +if torch.cuda.is_available(): + print("CUDA is available, using GPU", flush=True) + precision = "16" + dtype = torch.float16 +else: + print("CUDA is not available, using CPU", flush=True) + precision = "32" + dtype = torch.float32 n_cores_available = len(os.sched_getaffinity(0)) -print(f"Using {n_cores_available} cores") +print(f"Using {n_cores_available} worker cores") embedder = Embedder( how="random expr", max_len=4000, @@ -71,8 +79,8 @@ num_workers=n_cores_available, doclass = False, doplot = False, - precision = "32", # TODO: Use float16 for GPUs - dtype = torch_float32, + precision = precision, + dtype = dtype, ) embedded, _ = embedder(model, adata, cache=False) From 37a3f502a0088e7bb034a35127e466bdf03a2fd7 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Mon, 25 Nov 2024 14:36:43 +0100 Subject: [PATCH 08/14] Add model argument to scPRINT --- src/methods/scprint/config.vsh.yaml | 20 ++++++++++++++------ src/methods/scprint/script.py | 8 ++++---- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 8b295194..0d2b7567 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -30,13 +30,21 @@ links: info: preferred_normalization: counts + method_types: [embedding] + variants: + scprint_large: + model: "large" + scprint_medium: + model: "medium" + scprint_small: + model: "small" -# Component-specific parameters (optional) -# arguments: -# - name: "--n_neighbors" -# type: "integer" -# default: 5 -# description: Number of neighbors to use. +arguments: + - name: "--model" + type: "string" + description: String representing the Geneformer model to use + choices: ["large", "medium", "small"] + default: "large" resources: - type: python_script diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 9f42c217..3750e574 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -13,7 +13,8 @@ # in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. par = { 'input': 'resources_test/.../input.h5ad', - 'output': 'output.h5ad' + 'output': 'output.h5ad', + "model": "large", } meta = { 'name': 'scprint' @@ -48,11 +49,10 @@ ) adata = preprocessor(adata) -print('\n>>> Downloading model...', flush=True) -model = "small" # TODO: Add other models +print(f"\n>>> Downloading '{par['model']}' model...", flush=True) model_checkpoint_file = hf_hub_download( repo_id="jkobject/scPRINT", - filename=f"{model}.ckpt" + filename=f"{par['model']}.ckpt" ) print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) model = scPrint.load_from_checkpoint( From 62704e464675c83e693b4fc8b0f02a021931799a Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Mon, 25 Nov 2024 15:31:46 +0100 Subject: [PATCH 09/14] Add scPRINT to benchmark workflow --- src/workflows/run_benchmark/config.vsh.yaml | 1 + src/workflows/run_benchmark/main.nf | 1 + src/workflows/run_benchmark/main.nf.test | 19 +++++++++++++++++++ 3 files changed, 21 insertions(+) create mode 100644 src/workflows/run_benchmark/main.nf.test diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml index 1047ea87..75c93003 100644 --- a/src/workflows/run_benchmark/config.vsh.yaml +++ b/src/workflows/run_benchmark/config.vsh.yaml @@ -96,6 +96,7 @@ dependencies: - name: methods/scanvi - name: methods/scgpt - name: methods/scimilarity + - name: methods/scprint - name: methods/scvi - name: methods/uce # metrics diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf index aaacb434..afcb968c 100644 --- a/src/workflows/run_benchmark/main.nf +++ b/src/workflows/run_benchmark/main.nf @@ -33,6 +33,7 @@ methods = [ scimilarity.run( args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")] ), + scprint, scvi, uce.run( args: [model: file("s3://openproblems-work/cache/uce-model-v5.zip")] diff --git a/src/workflows/run_benchmark/main.nf.test b/src/workflows/run_benchmark/main.nf.test new file mode 100644 index 00000000..3fa35887 --- /dev/null +++ b/src/workflows/run_benchmark/main.nf.test @@ -0,0 +1,19 @@ +workflow auto { + findStates(params, meta.config) + | view{"In: $it"} + | meta.workflow.run( + auto: [publish: "state"] + ) +} + +workflow run_wf { + take: + input_ch + + main: + output_ch = input_ch + | view{"Mid: $it"} + + emit: + output_ch +} From 4774c98d2c7c64d1e072a36ad4d5e72da85b6ece Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Mon, 25 Nov 2024 16:08:09 +0100 Subject: [PATCH 10/14] Make scPRINT inherit from base method Model is too large for CI tests --- src/methods/scprint/config.vsh.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 0d2b7567..4e7b7ca3 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -1,4 +1,4 @@ -__merge__: ../../api/comp_method.yaml +__merge__: /src/api/base_method.yaml name: scprint label: scPRINT From c3ce8df1f0a48171ab0e4d0ede202715c1ef69d5 Mon Sep 17 00:00:00 2001 From: Robrecht Cannoodt Date: Tue, 26 Nov 2024 13:16:49 +0100 Subject: [PATCH 11/14] style code --- src/methods/scprint/script.py | 43 ++++++++++++++++------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 3750e574..82198cfe 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -9,16 +9,12 @@ import os ## VIASH START -# Note: this section is auto-generated by viash at runtime. To edit it, make changes -# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. par = { - 'input': 'resources_test/.../input.h5ad', - 'output': 'output.h5ad', - "model": "large", -} -meta = { - 'name': 'scprint' + "input": "resources_test/.../input.h5ad", + "output": "output.h5ad", + "model": "large", } +meta = {"name": "scprint"} ## VIASH END sys.path.append(meta["resources_dir"]) @@ -33,35 +29,36 @@ elif input.uns["dataset_organism"] == "mus_musculus": input.obs["organism_ontology_term_id"] = "NCBITaxon:10090" else: - raise ValueError(f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'") + raise ValueError( + f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'" + ) adata = input.copy() -print('\n>>> Preprocessing data...', flush=True) +print("\n>>> Preprocessing data...", flush=True) preprocessor = Preprocessor( # Lower this threshold for test datasets - min_valid_genes_id = 1000 if input.n_vars < 2000 else 10000, + min_valid_genes_id=1000 if input.n_vars < 2000 else 10000, # Turn off cell filtering to return results for all cells - filter_cell_by_counts = False, - min_nnz_genes = False, + filter_cell_by_counts=False, + min_nnz_genes=False, do_postp=False, # Skip ontology checks - skip_validate=True + skip_validate=True, ) adata = preprocessor(adata) print(f"\n>>> Downloading '{par['model']}' model...", flush=True) model_checkpoint_file = hf_hub_download( - repo_id="jkobject/scPRINT", - filename=f"{par['model']}.ckpt" + repo_id="jkobject/scPRINT", filename=f"{par['model']}.ckpt" ) print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) model = scPrint.load_from_checkpoint( model_checkpoint_file, - transformer = "normal", # Don't use this for GPUs with flashattention - precpt_gene_emb = None + transformer="normal", # Don't use this for GPUs with flashattention + precpt_gene_emb=None, ) -print('\n>>> Embedding data...', flush=True) +print("\n>>> Embedding data...", flush=True) if torch.cuda.is_available(): print("CUDA is available, using GPU", flush=True) precision = "16" @@ -77,10 +74,10 @@ max_len=4000, add_zero_genes=0, num_workers=n_cores_available, - doclass = False, - doplot = False, - precision = precision, - dtype = dtype, + doclass=False, + doplot=False, + precision=precision, + dtype=dtype, ) embedded, _ = embedder(model, adata, cache=False) From 18b86004eda15a65b73569e5a3bda0d971b5acc9 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Tue, 26 Nov 2024 13:56:20 +0100 Subject: [PATCH 12/14] Apply suggestions from code review Co-authored-by: Robrecht Cannoodt --- src/methods/scprint/config.vsh.yaml | 14 +++++++++----- src/methods/scprint/script.py | 13 ++++++++----- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 4e7b7ca3..12100dfb 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -33,18 +33,22 @@ info: method_types: [embedding] variants: scprint_large: - model: "large" + model_name: "large" scprint_medium: - model: "medium" + model_name: "medium" scprint_small: - model: "small" + model_name: "small" arguments: - - name: "--model" + - name: "--model_name" type: "string" - description: String representing the Geneformer model to use + description: Which model to use. Not used if --model is provided. choices: ["large", "medium", "small"] default: "large" + - name: --model + type: file + description: Path to the scPRINT model. + required: false resources: - type: python_script diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 82198cfe..62a2b2df 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -12,7 +12,8 @@ par = { "input": "resources_test/.../input.h5ad", "output": "output.h5ad", - "model": "large", + "model_name": "large", + "model": None, } meta = {"name": "scprint"} ## VIASH END @@ -47,10 +48,12 @@ ) adata = preprocessor(adata) -print(f"\n>>> Downloading '{par['model']}' model...", flush=True) -model_checkpoint_file = hf_hub_download( - repo_id="jkobject/scPRINT", filename=f"{par['model']}.ckpt" -) +model_checkpoint_file = par["model"] +if model_checkpoint_file is None: + print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True) + model_checkpoint_file = hf_hub_download( + repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt" + ) print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) model = scPrint.load_from_checkpoint( model_checkpoint_file, From 5162d00510142f26d69e055abd296468d46a9cb4 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Tue, 26 Nov 2024 13:45:15 +0100 Subject: [PATCH 13/14] Remove test workflow file --- src/workflows/run_benchmark/main.nf.test | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 src/workflows/run_benchmark/main.nf.test diff --git a/src/workflows/run_benchmark/main.nf.test b/src/workflows/run_benchmark/main.nf.test deleted file mode 100644 index 3fa35887..00000000 --- a/src/workflows/run_benchmark/main.nf.test +++ /dev/null @@ -1,19 +0,0 @@ -workflow auto { - findStates(params, meta.config) - | view{"In: $it"} - | meta.workflow.run( - auto: [publish: "state"] - ) -} - -workflow run_wf { - take: - input_ch - - main: - output_ch = input_ch - | view{"Mid: $it"} - - emit: - output_ch -} From 7d4906a4f13e86d390ea9d07acc8bf26f72f1b45 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Tue, 26 Nov 2024 13:58:29 +0100 Subject: [PATCH 14/14] Fix test data path --- src/methods/scprint/script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 62a2b2df..6c1d6b96 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -10,7 +10,7 @@ ## VIASH START par = { - "input": "resources_test/.../input.h5ad", + "input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad", "output": "output.h5ad", "model_name": "large", "model": None,