Skip to content

Add UCE method #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scripts/run_benchmark/run_full_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ input_states: resources/datasets/**/state.yaml
rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
output_state: "state.yaml"
publish_dir: "$publish_dir"
settings: '{"methods_exclude": ["uce"]}'
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I excluded UCE from the local benchmark scripts because it requires more memory than is allowed by the local resource-labels configuration.

HERE

# run the benchmark
Expand Down
1 change: 1 addition & 0 deletions scripts/run_benchmark/run_test_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ input_states: resources_test/task_batch_integration/**/state.yaml
rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
output_state: "state.yaml"
publish_dir: "$publish_dir"
settings: '{"methods_exclude": ["uce"]}'
HERE

nextflow run . \
Expand Down
45 changes: 45 additions & 0 deletions src/methods/uce/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Viash component configuration for the UCE method.
# Shared method arguments/outputs are inherited from the base method API.
__merge__: ../../api/base_method.yaml

name: uce
label: UCE
summary: UCE offers a unified biological latent space that can represent any cell
description: |
  Universal Cell Embedding (UCE) is a single-cell foundation model that offers a
  unified biological latent space that can represent any cell, regardless of
  tissue or species
references:
  doi:
    - 10.1101/2023.11.28.568918
links:
  documentation: https://github.com/snap-stanford/UCE/blob/main/README.md
  repository: https://github.com/snap-stanford/UCE

info:
  method_types: [embedding]
  # script.py reads raw counts (X="layers/counts"), hence "counts" here
  preferred_normalization: counts

arguments:
  # Method-specific arguments; common ones come from base_method.yaml
  - name: --model
    type: file
    description: Path to the directory containing UCE model files or a .zip/.tar.gz archive
    required: true

resources:
  - type: python_script
    path: script.py
  - path: /src/utils/read_anndata_partial.py

engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    setup:
      - type: python
        pypi:
          - accelerate==0.24.0
      # Clone the UCE source into the image; script.py imports modules from
      # this checkout (data_proc.data_utils, evaluate)
      - type: docker
        run: "git clone https://github.com/snap-stanford/UCE.git"
runners:
  - type: executable
  - type: nextflow
    directives:
      # NOTE(review): UCE was excluded from local benchmark runs for exceeding
      # local memory limits — confirm midmem is sufficient on the cloud runner
      label: [midtime, midmem, midcpu, gpu]
211 changes: 211 additions & 0 deletions src/methods/uce/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import os
import pickle
import sys
import tarfile
import tempfile
import zipfile
from argparse import Namespace

import anndata as ad
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator

# Code has hardcoded paths that only work correctly inside the UCE directory,
# so the script must execute with the UCE checkout as the working directory.
if os.path.isdir("UCE"):
    # For executable we can work inside the UCE directory
    # (the Docker image clones the UCE repository; see config.vsh.yaml)
    os.chdir("UCE")
else:
    # For Nextflow we need to copy files to the Nextflow working directory
    # instead of chdir-ing away, so Nextflow's file staging keeps working.
    # NOTE(review): assumes the image provides the checkout at /workspace/UCE —
    # confirm against the Docker image layout.
    print(">>> Copying UCE files to local directory...", flush=True)
    import shutil

    shutil.copytree("/workspace/UCE", ".", dirs_exist_ok=True)
Comment on lines +15 to +24
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite hacky but it was the only way I could get it to work. Happy to hear other suggestions.


# Append current directory to import UCE functions
# (we are now inside the UCE checkout, so its modules resolve from ".")
sys.path.append(".")
from data_proc.data_utils import (
    adata_path_to_prot_chrom_starts,
    get_spec_chrom_csv,
    get_species_to_pe,
    process_raw_anndata,
)
from evaluate import run_eval

## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
    "input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad",
    "output": "output.h5ad",
}
# NOTE(review): meta["resources_dir"] is read below but absent from this
# placeholder — viash injects it at runtime; par["model"] likewise.
meta = {"name": "uce"}
## VIASH END

print(">>> Reading input...", flush=True)
# Make the bundled helper (read_anndata_partial.py, listed in resources) importable.
sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata

# Load only the slots used downstream; X is taken from the raw counts layer.
adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")

# Translate the dataset organism into the species identifier UCE expects.
_ORGANISM_TO_SPECIES = {
    "homo_sapiens": "human",
    "mus_musculus": "mouse",
}
dataset_organism = adata.uns["dataset_organism"]
if dataset_organism not in _ORGANISM_TO_SPECIES:
    raise ValueError(f"Species '{adata.uns['dataset_organism']}' not yet implemented")
species = _ORGANISM_TO_SPECIES[dataset_organism]

print("\n>>> Creating working directory...", flush=True)
work_dir = tempfile.TemporaryDirectory()
print(f"Working directory: '{work_dir.name}'", flush=True)

print("\n>>> Getting model files...", flush=True)
if os.path.isdir(par["model"]):
    # Model files are already unpacked; use them in place, nothing to clean up.
    model_temp = None
    model_dir = par["model"]
else:
    # Extract the archive into a temporary directory (cleaned up at the end).
    model_temp = tempfile.TemporaryDirectory()
    model_dir = model_temp.name

    if zipfile.is_zipfile(par["model"]):
        print("Extracting UCE model from .zip...", flush=True)
        with zipfile.ZipFile(par["model"], "r") as zip_file:
            zip_file.extractall(model_dir)
    elif tarfile.is_tarfile(par["model"]) and par["model"].endswith(".tar.gz"):
        print("Extracting model from .tar.gz...", flush=True)
        # NOTE(review): extractall() without an extraction filter trusts the
        # archive; fine for our own cached model, flagging for awareness.
        with tarfile.open(par["model"], "r:gz") as tar_file:
            tar_file.extractall(model_dir)
        # A .tar.gz is expected to hold a single top-level directory containing
        # the model files — NOTE(review): confirm against the published archive.
        model_dir = os.path.join(model_dir, os.listdir(model_dir)[0])
    else:
        # Fixed message grammar and dropped the placeholder-less f-prefix
        raise ValueError(
            "The 'model' argument should be a directory, a .zip file or a .tar.gz file"
        )

print(f"Model directory: '{model_dir}'", flush=True)

print("Extracting protein embeddings...", flush=True)
# The protein embeddings ship as a nested archive inside the model directory.
with tarfile.open(
    os.path.join(model_dir, "protein_embeddings.tar.gz"), "r:gz"
) as tar_file:
    tar_file.extractall("./model_files")
protein_embeddings_dir = os.path.join("./model_files", "protein_embeddings")
print(f"Protein embeddings directory: '{protein_embeddings_dir}'", flush=True)

# The following sections implement methods in the UCE.evaluate.AnndataProcessor
# class due to the object not being compatible with the Open Problems setup
model_args = {
    "dir": work_dir.name + "/",
    "skip": True,
    "filter": False, # Turn this off to get embedding for all cells
    "name": "input",
    "offset_pkl_path": os.path.join(model_dir, "species_offsets.pkl"),
    "spec_chrom_csv_path": os.path.join(model_dir, "species_chrom.csv"),
    "pe_idx_path": os.path.join(work_dir.name, "input_pe_row_idxs.pt"),
    "chroms_path": os.path.join(work_dir.name, "input_chroms.pkl"),
    "starts_path": os.path.join(work_dir.name, "input_starts.pkl"),
}

# AnndataProcessor.preprocess_anndata()
print("\n>>> Preprocessing data...", flush=True)
# Set var names to gene symbols (the protein-embedding lookup below is keyed
# by gene name — see spec_pe_genes)
adata.var_names = adata.var["feature_name"]
adata.write_h5ad(os.path.join(model_args["dir"], "input.h5ad"))

# process_raw_anndata() expects a pandas Series describing the dataset. Build
# it from a dict so "path" etc. are real Series entries; the previous
# attribute-assignment approach (row.path = ...) created plain instance
# attributes instead of items and triggered a pandas UserWarning.
row = pd.Series(
    {
        "path": "input.h5ad",
        "covar_col": np.nan,
        "species": species,
    }
)

processed_adata, num_cells, num_genes = process_raw_anndata(
    row=row,
    h5_folder_path=model_args["dir"],
    npz_folder_path=model_args["dir"],
    scp="",
    skip=model_args["skip"],
    additional_filter=model_args["filter"],
    root=model_args["dir"],
)

# AnndataProcessor.generate_idxs()
print("\n>>> Generating indexes...", flush=True)
# Look up, for this species, each gene's protein-embedding row and its
# chromosome/position from the model's bundled metadata files.
species_to_pe = get_species_to_pe(protein_embeddings_dir)
with open(model_args["offset_pkl_path"], "rb") as f:
    species_to_offsets = pickle.load(f)
gene_to_chrom_pos = get_spec_chrom_csv(model_args["spec_chrom_csv_path"])
spec_pe_genes = list(species_to_pe[species].keys())
offset = species_to_offsets[species]
pe_row_idxs, dataset_chroms, dataset_pos = adata_path_to_prot_chrom_starts(
    processed_adata, species, spec_pe_genes, gene_to_chrom_pos, offset
)
# Persist the computed indexes at the paths run_eval() is given below.
torch.save({model_args["name"]: pe_row_idxs}, model_args["pe_idx_path"])
with open(model_args["chroms_path"], "wb+") as f:
    pickle.dump({model_args["name"]: dataset_chroms}, f)
with open(model_args["starts_path"], "wb+") as f:
    pickle.dump({model_args["name"]: dataset_pos}, f)

# AnndataProcessor.run_evaluation()
print("\n>>> Evaluating model...", flush=True)
# Architecture and tokenizer settings for the pretrained UCE checkpoint.
# NOTE(review): these constants presumably mirror the defaults in the UCE
# repository's eval scripts — confirm if a different checkpoint is used.
model_parameters = Namespace(
    token_dim=5120,
    d_hid=5120,
    nlayers=33, # Small model = 4, full model = 33
    output_dim=1280,
    multi_gpu=False,
    token_file=os.path.join(model_dir, "all_tokens.torch"),
    dir=model_args["dir"],
    pad_length=1536,
    sample_size=1024,
    cls_token_idx=3,
    CHROM_TOKEN_OFFSET=143574,
    chrom_token_right_idx=2,
    chrom_token_left_idx=1,
    pad_token_idx=0,
)

# Select the checkpoint file and batch size matching the chosen depth above.
if model_parameters.nlayers == 4:
    model_parameters.model_loc = os.path.join(model_dir, "4layer_model.torch")
    model_parameters.batch_size = 100
else:
    model_parameters.model_loc = os.path.join(model_dir, "33l_8ep_1024t_1280.torch")
    model_parameters.batch_size = 25

accelerator = Accelerator(project_dir=model_args["dir"])
accelerator.wait_for_everyone()
shapes_dict = {model_args["name"]: (num_cells, num_genes)}
# run_eval() writes the embedded AnnData into model_args["dir"]
# (read back below as "input_uce_adata.h5ad").
run_eval(
    adata=processed_adata,
    name=model_args["name"],
    pe_idx_path=model_args["pe_idx_path"],
    chroms_path=model_args["chroms_path"],
    starts_path=model_args["starts_path"],
    shapes_dict=shapes_dict,
    accelerator=accelerator,
    args=model_parameters,
)

print("\n>>> Storing output...", flush=True)
# UCE writes its embedded AnnData next to the preprocessed input in the work dir.
embedded_path = os.path.join(model_args["dir"], "input_uce_adata.h5ad")
embedded_adata = ad.read_h5ad(embedded_path)

# Assemble the task output: empty obs/var frames (index only), the UCE
# embedding, and the required identifiers.
output_uns = {
    "dataset_id": adata.uns["dataset_id"],
    "normalization_id": adata.uns["normalization_id"],
    "method_id": meta["name"],
}
output = ad.AnnData(
    obs=adata.obs[[]],
    var=adata.var[[]],
    obsm={"X_emb": embedded_adata.obsm["X_uce"]},
    uns=output_uns,
)
print(output)

print("\n>>> Writing output AnnData to file...", flush=True)
output.write_h5ad(par["output"], compression="gzip")

print("\n>>> Cleaning up temporary directories...", flush=True)
work_dir.cleanup()
if model_temp is not None:
    model_temp.cleanup()

print("\n>>> Done!", flush=True)
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ dependencies:
- name: methods/scanvi
- name: methods/scimilarity
- name: methods/scvi
- name: methods/uce
# metrics
- name: metrics/asw_batch
- name: metrics/asw_label
Expand Down
5 changes: 4 additions & 1 deletion src/workflows/run_benchmark/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ methods = [
scimilarity.run(
args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")]
),
scvi
scvi,
uce.run(
args: [model: file("s3://openproblems-work/cache/uce-model-v5.zip")]
),
]

// construct list of metrics
Expand Down