Add scGPT #8

Merged · 7 commits · Nov 8, 2024
Changes from all commits
59 changes: 59 additions & 0 deletions src/methods/scgpt/config.vsh.yaml
@@ -0,0 +1,59 @@
__merge__: ../../api/comp_method.yaml

name: scgpt
label: scGPT
summary: "A foundation model for single-cell biology"
description: |
  scGPT is a foundation model for single-cell biology based on a generative
  pre-trained transformer and trained on a repository of over 33 million cells.
  Here, we use zero-shot output from a pre-trained model to get an integrated
  embedding for the batch integration task.
references:
  doi:
    - 10.1038/s41592-024-02201-0
links:
  documentation: https://scgpt.readthedocs.io/en/latest/
  repository: https://github.com/bowang-lab/scGPT

info:
  method_types: [embedding]
  preferred_normalization: counts
  variants:
    scgpt_default:
    scgpt_cp:
      model: "scGPT_CP"

arguments:
  - name: --model
    type: string
    description: String giving the scGPT model to use
    choices: ["scGPT_human", "scGPT_CP"]
    default: "scGPT_human"
  - name: --n_hvg
    type: integer
    default: 3000
    description: Number of highly variable genes to use.

resources:
  - type: python_script
    path: script.py
  - path: /src/utils/read_anndata_partial.py

engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    # TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
    setup:
      - type: python
        pypi:
          - gdown
          - scgpt # Install from PyPI to get dependencies
      - type: docker
        # Force re-installing from GitHub to get bug fixes
        run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git

runners:
  - type: executable
  - type: nextflow
    directives:
      label: [midtime, midmem, midcpu, gpu]
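
For context, a component configured like this can typically be run straight from its config with viash; a minimal, hypothetical invocation (the input path is the same placeholder used in the script below, and the flags mirror the arguments section):

viash run src/methods/scgpt/config.vsh.yaml -- \
  --input resources_test/.../input.h5ad \
  --output output.h5ad \
  --model scGPT_CP \
  --n_hvg 3000
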
92 changes: 92 additions & 0 deletions src/methods/scgpt/script.py
@@ -0,0 +1,92 @@
import sys
import tempfile

import anndata as ad
import gdown
import scgpt
import torch

## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
    "input": "resources_test/.../input.h5ad",
    "output": "output.h5ad",
    "model": "scGPT_human",
    "n_hvg": 3000,
}
meta = {"name": "scgpt"}
## VIASH END

print(f"====== scGPT version {scgpt.__version__} ======", flush=True)

sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata

print("\n>>> Reading input files...", flush=True)
print(f"Input H5AD file: '{par['input']}'", flush=True)
adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")

if adata.uns["dataset_organism"] != "homo_sapiens":
    raise ValueError(
        f"scGPT can only be used with human data "
        f"(dataset_organism == \"{adata.uns['dataset_organism']}\")"
    )

print(adata, flush=True)

print("\n>>> Preprocessing data...", flush=True)
if par["n_hvg"]:
    print(f"Selecting top {par['n_hvg']} highly variable genes", flush=True)
    idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][: par["n_hvg"]]
    adata = adata[:, idx].copy()

print(adata, flush=True)

print(f"\n>>> Downloading '{par['model']}' model...", flush=True)
model_drive_ids = {
    "scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
    "scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB",
}
drive_path = f"https://drive.google.com/drive/folders/{model_drive_ids[par['model']]}"
model_dir = tempfile.TemporaryDirectory()
print(f"Downloading from '{drive_path}'", flush=True)
gdown.download_folder(drive_path, output=model_dir.name, quiet=True)
print(f"Model directory: '{model_dir.name}'", flush=True)

print("\n>>> Embedding data...", flush=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: '{device}'", flush=True)
embedded = scgpt.tasks.embed_data(
    adata,
    model_dir.name,
    gene_col="feature_name",
    batch_size=64,
    use_fast_transformer=False,  # Disable fast-attn as not installed
    device=device,
    return_new_adata=True,
)

print("\n>>> Storing output...", flush=True)
output = ad.AnnData(
    obs=adata.obs[[]],
    var=adata.var[[]],
    obsm={
        "X_emb": embedded.X,
    },
    uns={
        "dataset_id": adata.uns["dataset_id"],
        "normalization_id": adata.uns["normalization_id"],
        "method_id": meta["name"],
    },
)
print(output)

print("\n>>> Writing output to file...", flush=True)
print(f"Output H5AD file: '{par['output']}'", flush=True)
output.write_h5ad(par["output"], compression="gzip")

print("\n>>> Cleaning up temporary directories...", flush=True)
model_dir.cleanup()

print("\n>>> Done!", flush=True)
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/config.vsh.yaml
@@ -95,6 +95,7 @@ dependencies:
  - name: methods/scanorama
  - name: methods/scanvi
  - name: methods/scimilarity
  - name: methods/scgpt
  - name: methods/scvi
  - name: methods/uce
  # metrics
3 changes: 2 additions & 1 deletion src/workflows/run_benchmark/main.nf
@@ -32,10 +32,11 @@ methods = [
    scimilarity.run(
        args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")]
    ),
    scgpt,
    scvi,
    uce.run(
        args: [model: file("s3://openproblems-work/cache/uce-model-v5.zip")]
    ),
]

// construct list of metrics