Implementation of CellPLM #19

Open · wants to merge 17 commits into main
Changes from 7 commits
52 changes: 52 additions & 0 deletions src/methods/cellplm/config.vsh.yaml
@@ -0,0 +1,52 @@
__merge__: ../../api/base_method.yaml

name: cellplm
label: CellPLM
summary: "A foundation model pre-trained with cells as tokens."
description: |
  CellPLM is a pre-trained language model designed for single-cell analysis.
  It applies the principles of natural language processing (NLP) to
  single-cell gene expression data, modelling cells as tokens.
references:
  doi:
    - 10.1101/2023.10.03.560734
links:
  documentation: https://github.com/OmicsML/CellPLM/tree/main/tutorials
  repository: https://github.com/OmicsML/CellPLM

info:
  method_types: [embedding]
  preferred_normalization: counts

arguments:
  - name: --model
    type: string
    description: String giving the CellPLM model to use
    choices: ["20231027_85M"]
    default: "20231027_85M"

resources:
  - type: python_script
    path: script.py
  - path: /src/utils/read_anndata_partial.py

engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    setup:
      - type: python
        pypi:
          - gdown
          - cellplm

runners:
  - type: executable
  - type: nextflow
    directives:
      label: [midtime, midmem, midcpu, gpu]
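For reviewers who want to exercise the component locally, here is a rough sketch of invoking it with viash from Python. This is a hypothetical invocation, not part of the PR: it assumes viash is installed, and the input path (with its "...") is the same placeholder used in the script below.

import subprocess

# Hypothetical local run of the component; the "..." in the input path is a
# placeholder for the actual test resource location.
subprocess.run(
    [
        "viash", "run", "src/methods/cellplm/config.vsh.yaml", "--",
        "--input", "resources_test/.../input.h5ad",
        "--output", "output.h5ad",
        "--model", "20231027_85M",
    ],
    check=True,
)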
93 changes: 93 additions & 0 deletions src/methods/cellplm/script.py
@@ -0,0 +1,93 @@
import sys
import tempfile
import warnings

import anndata as ad
import gdown
import torch

warnings.filterwarnings("ignore")

from CellPLM.utils import set_seed
from CellPLM.pipeline.cell_embedding import CellEmbeddingPipeline

## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.

par = {
    'input': 'resources_test/.../input.h5ad',
    'output': 'output.h5ad',
    'model': '20231027_85M',
}
meta = {
'name': 'cellplm'
}
## VIASH END

sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata

set_seed(24)
PRETRAIN_VERSION = par['model']
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("\n>>> Reading input files...", flush=True)
print(f"Input H5AD file: '{par['input']}'", flush=True)
adata = read_anndata(
par['input'],
X='layers/normalized',
obs='obs',
var='var',
uns='uns'
)

if adata.uns["dataset_organism"] != "homo_sapiens":
    raise ValueError(
        f"CellPLM can only be used with human data "
        f"(dataset_organism is '{adata.uns['dataset_organism']}', expected 'homo_sapiens')"
    )
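# note: the available CellPLM checkpoints are pre-trained on human data, so
# other organisms are not supported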

print(adata, flush=True)

print("\n>>> Downloading pre-trained model...", flush=True)
drive_path = "https://drive.google.com/drive/folders/1C2fVNEKX3plHnagaTwpuPW5tpwv1up9G?usp=sharing"
model_dir = tempfile.TemporaryDirectory()
print(f"Downloading from '{drive_path}'", flush=True)
gdown.download_folder(drive_path, output=model_dir.name, quiet=True)
print(f"Model directory: '{model_dir.name}'", flush=True)

pipeline = CellEmbeddingPipeline(
    pretrain_prefix=PRETRAIN_VERSION,  # the pre-trained checkpoint to load
    pretrain_directory=model_dir.name,
)

print("\n>>> Generating embedding...", flush=True)
embedding = pipeline.predict(adata, device=DEVICE)

# pipeline.predict returns a torch.Tensor; move it to the host and convert to
# NumPy so it can be stored in .obsm
embedding = embedding.cpu().numpy()

print("\n>>> Creating output AnnData...", flush=True)
output = ad.AnnData(
obs=adata.obs[[]],
var=adata.var[[]],
obsm={
"X_emb": embedding,
},
uns={
"dataset_id": adata.uns["dataset_id"],
"normalization_id": adata.uns["normalization_id"],
"method_id": meta["name"],
},
)
print(output, flush=True)

print("\n>>> Writing output to file...", flush=True)
output.write_h5ad(par['output'], compression='gzip')

print("\n>>> Cleaning up temporary directories...", flush=True)
model_dir.cleanup()
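As a quick sanity check on the written file, reviewers could run something like the following. This is a hypothetical snippet, not part of the component; it only assumes the output path used above.

import anndata as ad

out = ad.read_h5ad("output.h5ad")
assert "X_emb" in out.obsm          # embedding matrix, one row per cell
print(out.obsm["X_emb"].shape)      # (n_cells, embedding_dim)
print(out.uns["method_id"])         # "cellplm"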