Skip to content

Add Geneformer #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f47eff9
Add cxg_immune_cell_atlas as a test resource
lazappi Oct 2, 2024
175a3f8
Add SCimilarity component
lazappi Oct 2, 2024
a9931e1
Add SCimilarity to benchmark workflow
lazappi Oct 2, 2024
84394c3
Update script to extract model
lazappi Oct 8, 2024
5e0038c
Add SCimilarity model path to benchmark workflow
lazappi Oct 8, 2024
c927cb1
Add base_method API to disable tests for SCimilarity
lazappi Oct 8, 2024
5c74f37
Replace cxg_mouse_pancreas_atlas with cxg_immune_cell_atlas
lazappi Oct 8, 2024
7b6fea3
Style SCimilarity script
lazappi Oct 8, 2024
cca5715
Remove test resources from SCimilarity config
lazappi Oct 8, 2024
30e8b14
Fix file names in test resources state.yaml
lazappi Oct 11, 2024
b2188b5
Add scimilarity as dependency to benchmark workflow
lazappi Oct 11, 2024
ca95e44
Update compute environment
lazappi Oct 16, 2024
ce57335
Update model file path
lazappi Oct 21, 2024
c60a7da
Create geneformer files
lazappi Oct 22, 2024
99a7078
Set SCimilarity name in Python script
lazappi Oct 22, 2024
f4b98e1
Adjust container settings
lazappi Oct 22, 2024
fbaf8b1
Download dictionary files in script
lazappi Oct 22, 2024
1e7dfa0
Merge remote-tracking branch 'origin/main' into feature/no-ref/add-Ge…
lazappi Oct 22, 2024
b1d1520
Prepare and tokenize data, attempt to embed
lazappi Oct 23, 2024
e3abd30
Store and output embedding
lazappi Oct 28, 2024
1e00a52
Add Geneformer to benchmark workflow
lazappi Oct 29, 2024
113a892
Add argument to select model version to use
lazappi Oct 29, 2024
f26eb24
Style Geneformer script
lazappi Oct 29, 2024
98789b1
Make Geneformer inherit from base_method for tests
lazappi Oct 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions src/methods/geneformer/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
__merge__: /src/api/base_method.yaml

name: geneformer
label: Geneformer
summary: Geneformer is a foundation transformer model pretrained on a large-scale corpus of single cell transcriptomes
description: |
Geneformer is a foundation transformer model pretrained on a large-scale
corpus of single cell transcriptomes to enable context-aware predictions in
network biology. For this task, Geneformer is used to create a batch-corrected
cell embedding.
references:
doi:
- 10.1038/s41586-023-06139-9
- 10.1101/2024.08.16.608180
links:
documentation: https://geneformer.readthedocs.io/en/latest/index.html
repository: https://huggingface.co/ctheodoris/Geneformer

info:
preferred_normalization: counts
method_types: [embedding]
variants:
geneformer_12L_95M_i4096:
model: "gf-12L-95M-i4096"
geneformer_6L_30M_i2048:
model: "gf-6L-30M-i2048"
geneformer_12L_30M_i2048:
model: "gf-12L-30M-i2048"
geneformer_20L_95M_i4096:
model: "gf-20L-95M-i4096"

arguments:
- name: "--model"
type: "string"
description: String representing the Geneformer model to use
choices: ["gf-6L-30M-i2048", "gf-12L-30M-i2048", "gf-12L-95M-i4096", "gf-20L-95M-i4096"]
default: "gf-12L-95M-i4096"

resources:
- type: python_script
path: script.py
- path: /src/utils/read_anndata_partial.py

engines:
- type: docker
image: openproblems/base_pytorch_nvidia:1.0.0
setup:
- type: python
pip:
- pyarrow<15.0.0a0,>=14.0.1
- huggingface_hub
- git+https://huggingface.co/ctheodoris/Geneformer.git

runners:
- type: executable
- type: nextflow
directives:
label: [midtime, midmem, midcpu, gpu]
154 changes: 154 additions & 0 deletions src/methods/geneformer/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import os
import sys
from tempfile import TemporaryDirectory

import anndata as ad
import numpy as np
import pandas as pd
from geneformer import EmbExtractor, TranscriptomeTokenizer
from huggingface_hub import hf_hub_download

## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
"input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad",
"output": "output.h5ad",
"model": "gf-12L-95M-i4096",
}
meta = {"name": "geneformer"}
## VIASH END

n_processors = os.cpu_count()

print(">>> Reading input...", flush=True)
sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata

adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")

if adata.uns["dataset_organism"] != "homo_sapiens":
raise ValueError(
f"Geneformer can only be used with human data "
f"(dataset_organism == '{adata.uns['dataset_organism']}')"
)

is_ensembl = all(var_name.startswith("ENSG") for var_name in adata.var_names)
if not is_ensembl:
raise ValueError(f"Geneformer requires adata.var_names to contain ENSEMBL gene ids")

print(f">>> Getting settings for model '{par['model']}'...", flush=True)
model_split = par["model"].split("-")
model_details = {
"layers": model_split[1],
"dataset": model_split[2],
"input_size": int(model_split[3][1:]),
}
print(model_details, flush=True)

print(">>> Getting model dictionary files...", flush=True)
if model_details["dataset"] == "95M":
dictionaries_subfolder = "geneformer"
elif model_details["dataset"] == "30M":
dictionaries_subfolder = "geneformer/gene_dictionaries_30m"
else:
raise ValueError(f"Invalid model dataset: {model_details['dataset']}")
print(f"Dictionaries subfolder: '{dictionaries_subfolder}'")

dictionary_files = {
"ensembl_mapping": hf_hub_download(
repo_id="ctheodoris/Geneformer",
subfolder=dictionaries_subfolder,
filename=f"ensembl_mapping_dict_gc{model_details['dataset']}.pkl",
),
"gene_median": hf_hub_download(
repo_id="ctheodoris/Geneformer",
subfolder=dictionaries_subfolder,
filename=f"gene_median_dictionary_gc{model_details['dataset']}.pkl",
),
"gene_name_id": hf_hub_download(
repo_id="ctheodoris/Geneformer",
subfolder=dictionaries_subfolder,
filename=f"gene_name_id_dict_gc{model_details['dataset']}.pkl",
),
"token": hf_hub_download(
repo_id="ctheodoris/Geneformer",
subfolder=dictionaries_subfolder,
filename=f"token_dictionary_gc{model_details['dataset']}.pkl",
),
}

print(">>> Creating working directory...", flush=True)
work_dir = TemporaryDirectory()
input_dir = os.path.join(work_dir.name, "input")
os.makedirs(input_dir)
tokenized_dir = os.path.join(work_dir.name, "tokenized")
os.makedirs(tokenized_dir)
embedding_dir = os.path.join(work_dir.name, "embedding")
os.makedirs(embedding_dir)
print(f"Working directory: '{work_dir.name}'", flush=True)

print(">>> Preparing data...", flush=True)
adata.var["ensembl_id"] = adata.var_names
adata.obs["n_counts"] = np.ravel(adata.X.sum(axis=1))
adata.write_h5ad(os.path.join(input_dir, "input.h5ad"))
print(adata)

print(">>> Tokenizing data...", flush=True)
special_token = model_details["dataset"] == "95M"
print(f"Input size: {model_details['input_size']}, Special token: {special_token}")
tokenizer = TranscriptomeTokenizer(
nproc=n_processors,
model_input_size=model_details["input_size"],
special_token=special_token,
gene_median_file=dictionary_files["gene_median"],
token_dictionary_file=dictionary_files["token"],
gene_mapping_file=dictionary_files["ensembl_mapping"],
)
tokenizer.tokenize_data(input_dir, tokenized_dir, "tokenized", file_format="h5ad")

print(f">>> Getting model files for model '{par['model']}'...", flush=True)
model_files = {
"model": hf_hub_download(
repo_id="ctheodoris/Geneformer",
subfolder=par["model"],
filename="model.safetensors",
),
"config": hf_hub_download(
repo_id="ctheodoris/Geneformer",
subfolder=par["model"],
filename="config.json",
),
}
model_dir = os.path.dirname(model_files["model"])

print(">>> Extracting embeddings...", flush=True)
embedder = EmbExtractor(
emb_mode="cell", max_ncells=None, token_dictionary_file=dictionary_files["token"]
)
embedder.extract_embs(
model_dir,
os.path.join(tokenized_dir, "tokenized.dataset"),
embedding_dir,
"embedding",
)
embedding = pd.read_csv(os.path.join(embedding_dir, "embedding.csv")).to_numpy()

print(">>> Storing outputs...", flush=True)
output = ad.AnnData(
obs=adata.obs[[]],
var=adata.var[[]],
obsm={
"X_emb": embedding,
},
uns={
"dataset_id": adata.uns["dataset_id"],
"normalization_id": adata.uns["normalization_id"],
"method_id": meta["name"],
},
)
print(output)

print(">>> Writing output AnnData to file...", flush=True)
output.write_h5ad(par["output"], compression="gzip")
print(">>> Done!")
2 changes: 1 addition & 1 deletion src/methods/scimilarity/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"model": "model_v1.1",
}
meta = {
"name": "scvi",
"name": "scimilarity",
}
## VIASH END

Expand Down
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ dependencies:
- name: methods/batchelor_mnn_correct
- name: methods/bbknn
- name: methods/combat
- name: methods/geneformer
- name: methods/harmony
- name: methods/harmonypy
- name: methods/liger
Expand Down
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ methods = [
batchelor_mnn_correct,
bbknn,
combat,
geneformer,
harmony,
harmonypy,
liger,
Expand Down