Skip to content

Add model path argument to scGPT #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/methods/scgpt/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__merge__: ../../api/comp_method.yaml
__merge__: ../../api/base_method.yaml

name: scgpt
label: scGPT
Expand All @@ -24,11 +24,18 @@ info:
model: "scGPT_CP"

arguments:
- name: --model
- name: --model_name
type: string
description: String giving the scGPT model to use
description: String giving the name of the scGPT model to use
choices: ["scGPT_human", "scGPT_CP"]
default: "scGPT_human"
- name: --model
type: file
description: |
Path to the directory containing the scGPT model specified by model_name,
or a .zip/.tar.gz archive to extract. If not given, the model will be
downloaded.
required: false
- name: --n_hvg
type: integer
default: 3000
Expand Down
62 changes: 49 additions & 13 deletions src/methods/scgpt/script.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import sys
import tarfile
import tempfile
import zipfile

import anndata as ad
import gdown
Expand All @@ -12,6 +15,7 @@
par = {
"input": "resources_test/.../input.h5ad",
"output": "output.h5ad",
"model_name": "scGPT_human",
"model": "scGPT_human",
"n_hvg": 3000,
}
Expand Down Expand Up @@ -43,23 +47,54 @@

print(adata, flush=True)

print(f"\n>>> Downloading '{par['model']}' model...", flush=True)
model_drive_ids = {
"scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
"scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB",
}
drive_path = f"https://drive.google.com/drive/folders/{model_drive_ids[par['model']]}"
model_dir = tempfile.TemporaryDirectory()
print(f"Downloading from '{drive_path}'", flush=True)
gdown.download_folder(drive_path, output=model_dir.name, quiet=True)
print(f"Model directory: '{model_dir.name}'", flush=True)
if par["model"] is None:
print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True)
model_drive_ids = {
"scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
"scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB",
}
drive_path = (
f"https://drive.google.com/drive/folders/{model_drive_ids[par['model_name']]}"
)
model_temp = tempfile.TemporaryDirectory()
model_dir = model_temp.name
print(f"Downloading from '{drive_path}'", flush=True)
gdown.download_folder(drive_path, output=model_dir, quiet=True)
else:
if os.path.isdir(par["model"]):
print(f"\n>>> Using model directory...", flush=True)
model_temp = None
model_dir = par["model"]
else:
model_temp = tempfile.TemporaryDirectory()
model_dir = model_temp.name

if zipfile.is_zipfile(par["model"]):
print(f"\n>>> Extracting model from .zip...", flush=True)
print(f".zip path: '{par['model']}'", flush=True)
with zipfile.ZipFile(par["model"], "r") as zip_file:
zip_file.extractall(model_dir)
elif tarfile.is_tarfile(par["model"]) and par["model"].endswith(
".tar.gz"
):
print(f"\n>>> Extracting model from .tar.gz...", flush=True)
print(f".tar.gz path: '{par['model']}'", flush=True)
with tarfile.open(par["model"], "r:gz") as tar_file:
tar_file.extractall(model_dir)
model_dir = os.path.join(model_dir, os.listdir(model_dir)[0])
else:
raise ValueError(
f"The 'model' argument should be a directory, a .zip file or a .tar.gz file"
)

print(f"Model directory: '{model_dir}'", flush=True)

print("\n>>> Embedding data...", flush=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: '{device}'", flush=True)
embedded = scgpt.tasks.embed_data(
adata,
model_dir.name,
model_dir,
gene_col="feature_name",
batch_size=64,
use_fast_transformer=False, # Disable fast-attn as not installed
Expand All @@ -86,7 +121,8 @@
print(f"Output H5AD file: '{par['output']}'", flush=True)
output.write_h5ad(par["output"], compression="gzip")

print("\n>>> Cleaning up temporary directories...", flush=True)
model_dir.cleanup()
if model_temp is not None:
print("\n>>> Cleaning up temporary directories...", flush=True)
model_temp.cleanup()

print("\n>>> Done!", flush=True)
4 changes: 3 additions & 1 deletion src/workflows/run_benchmark/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ methods = [
scalex,
scanorama,
scanvi,
scgpt,
scgpt.run(
args: [model: file("s3://openproblems-work/cache/scGPT_human.zip")]
),
scimilarity.run(
args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")]
),
Expand Down
Loading