som-shahlab/hf_ehr

Training Long Context Models on EHR Data

This repo contains code and pretrained models for the Context Clues paper. It is designed to let you train any HuggingFace model on structured EHR data, and it comes with Hydra configs, Wandb logging, and PyTorch Lightning distributed training support.

It currently supports EHR data defined using the MEDS data standard or the FEMR package.

📖 Table of Contents

  1. 🤗 Pretrained HuggingFace Models
  2. 📀 Installation
  3. 🚀 Quick Start
  4. 🏋️‍♀️ Training
  5. 📊 Evaluation
  6. 💊 MEDS Demo
  7. Ⓜ️ Merative/Truven/MarketScan Demo
  8. 🔍 Profiling
  9. ℹ️ Other
  10. 🎓 Citation

🤗 Pretrained HuggingFace Models

Please see our HuggingFace Collection to download the following models, pretrained from scratch on 2 billion tokens of deidentified structured EHR data:

from transformers import AutoModelForCausalLM
from hf_ehr.data.tokenization import CLMBRTokenizer

model = AutoModelForCausalLM.from_pretrained("StanfordShahLab/gpt-base-512-clmbr")
tokenizer = CLMBRTokenizer.from_pretrained("StanfordShahLab/gpt-base-512-clmbr")

All models:

| Model | Context Lengths |
|---|---|
| gpt | 512, 1024, 2048, 4096 |
| llama | 512, 1024, 2048, 4096 |
| mamba | 1024, 4096, 8192, 16384 |
| hyena | 1024, 4096, 8192, 16384 |
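The Hub IDs follow a regular naming pattern. The sketch below enumerates them, assuming every checkpoint is named `StanfordShahLab/{architecture}-{size}-{context_length}-clmbr` with the sizes that appear in the profiling tables further down (gpt and llama at base, mamba at tiny, hyena at large). `list_model_ids` is a hypothetical helper, not part of hf_ehr.

```python
from typing import Dict, List, Tuple

# Architecture -> (size, supported context lengths), taken from the tables in this README.
# The size per architecture is an assumption based on the model names in the profiling tables.
MODELS: Dict[str, Tuple[str, List[int]]] = {
    "gpt": ("base", [512, 1024, 2048, 4096]),
    "llama": ("base", [512, 1024, 2048, 4096]),
    "mamba": ("tiny", [1024, 4096, 8192, 16384]),
    "hyena": ("large", [1024, 4096, 8192, 16384]),
}


def list_model_ids(org: str = "StanfordShahLab") -> List[str]:
    """Return the (assumed) Hub ID of every pretrained CLMBR-tokenized checkpoint."""
    return [
        f"{org}/{arch}-{size}-{ctx}-clmbr"
        for arch, (size, contexts) in MODELS.items()
        for ctx in contexts
    ]


if __name__ == "__main__":
    for model_id in list_model_ids():
        print(model_id)
```

Any returned ID can then be passed to `AutoModelForCausalLM.from_pretrained` and `CLMBRTokenizer.from_pretrained` exactly as in the snippet above.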

Here's a quick tutorial on how to use these models directly in your own code (i.e. outside of this repo's infra):

from transformers import AutoModelForCausalLM
from hf_ehr.data.tokenization import CLMBRTokenizer
from hf_ehr.config import Event
from typing import List, Dict
import torch

####################################
# 1. Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("StanfordShahLab/gpt-base-512-clmbr")
tokenizer = CLMBRTokenizer.from_pretrained("StanfordShahLab/gpt-base-512-clmbr")

####################################
# 2. Define patient as sequence of `Event` objects. Only `code` is required.
patient: List[Event] = [
    Event(code='SNOMED/3950001', value=None, unit=None, start=None, end=None, omop_table=None),
    Event(code='Gender/F', value=None, unit=None, start=None, end=None, omop_table=None),
    Event(code='Ethnicity/Hispanic', value=None, unit=None, start=None, end=None, omop_table=None),
    Event(code='SNOMED/609040007', value=None, unit=None, start=None, end=None, omop_table=None),
    Event(code='LOINC/2236-8', value=-3.0, unit=None, start=None, end=None, omop_table=None),
    Event(code='SNOMED/12199005', value=26.3, unit=None, start=None, end=None, omop_table=None),        
]

####################################
# 3. Tokenize patient
batch: Dict[str, torch.Tensor] = tokenizer([ patient ], add_special_tokens=True, return_tensors='pt')
# > batch = {
#     'input_ids': tensor([[ 5, 0, 7, 9, 27, 2049, 6557, 22433, 1]]), 
#     'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0]]), 
#     'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])
# }
textual_tokens: List[str] = tokenizer.convert_events_to_tokens(patient)
# > textual_tokens = ['SNOMED/3950001', 'Gender/F', 'Ethnicity/Hispanic', 'SNOMED/609040007', 'LOINC/2236-8 || None || -1.7976931348623157e+308 - 4.0', 'SNOMED/12199005 || None || 26.0 - 28.899999618530273']

####################################
# 4. Run model
outputs = model(**batch, output_hidden_states=True)

####################################
# 5. Get logits + probabilities for next token
logits = outputs.logits
# > logits.shape = torch.Size([1, 9, 39818])
next_token_preds = torch.nn.functional.softmax(logits[:, -1, :], dim=-1) # should sum to 1
# > next_token_preds.shape = torch.Size([1, 39818])

####################################
# 6. Get patient representation for finetuning (usually the hidden state of the LAST layer for the LAST token)
last_layer_hidden_state = outputs.hidden_states[-1]
# > last_layer_hidden_state.shape = torch.Size([1, 9, 768])
patient_rep = last_layer_hidden_state[:, -1, :]
# > patient_rep.shape = torch.Size([1, 768])
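The `textual_tokens` output suggests that a code carrying a numeric value is mapped to a range token of the form `{code} || {unit} || {lower} - {upper}`, with open-ended outer bins bounded by the largest representable float. The hypothetical helper below reproduces that string format from hand-picked bin edges. It is only a sketch under those assumptions: the real `CLMBRTokenizer` derives its bins from data and may treat bin boundaries differently.

```python
import bisect
import sys
from typing import List, Optional


def value_to_token(code: str, value: Optional[float], unit: Optional[str], bin_edges: List[float]) -> str:
    """Hypothetical sketch: format a (code, value) pair as a range token string.

    `bin_edges` are interior bin boundaries chosen by hand for illustration.
    Bins are half-open [lower, upper); the outermost bins extend to
    -/+ sys.float_info.max, matching the open-ended bin in the example output.
    """
    if value is None:
        # Codes without a numeric value map to the bare code.
        return code
    edges = [-sys.float_info.max] + sorted(bin_edges) + [sys.float_info.max]
    # bisect_right returns the index of the first edge strictly greater than `value`.
    idx = bisect.bisect_right(edges, value)
    idx = min(max(idx, 1), len(edges) - 1)  # clamp values at or beyond the outer edges
    lower, upper = edges[idx - 1], edges[idx]
    return f"{code} || {unit} || {lower} - {upper}"
```

For example, `value_to_token("LOINC/2236-8", -3.0, None, [4.0, 10.0])` produces the same string as the fifth entry of `textual_tokens` above.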

📀 Installation

Direct install:

pip install hf-ehr

For faster Mamba runs, install:

pip install mamba-ssm causal-conv1d

Development install:

conda create -n hf_env python=3.10 -y
conda activate hf_env
pip install -e . --no-cache-dir

# [Optional] If you haven't already created your **Tokenizers**, run the following. If you're on Carina, then skip this step.
python3 hf_ehr/tokenizers/create_clmbr.py # Takes ~5 seconds
python3 hf_ehr/tokenizers/create_desc.py # Takes ~30 min
python3 hf_ehr/tokenizers/create_cookbook.py # Takes many hours

🚀 Quick Start

Launch a GPT training run with the ability to configure common hyperparameters (using main.py):

cd hf_ehr/scripts/carina
python3 main.py --model gpt2 --size base --tokenizer clmbr --context_length 512 --dataloader approx --dataset v8 --trainer single_gpu --is_run_local --is_skip_base --extra "callbacks.model_checkpointing.save_most_recent_every_n_train_steps=10"

Launch a Llama run on a MEDS dataset with more customization over configs (using run.py):

cd hf_ehr/scripts/carina
python3 run.py \
    +data=meds_mimic4_demo \
    +trainer=single_gpu \
    +model=llama-base \
    +tokenizer=clmbr \
    data.dataloader.mode=approx \
    data.dataloader.approx_batch_sampler.max_tokens=16384

To launch 4 GPT-base runs in parallel on one SLURM node, and 4 Mamba runs in parallel on another SLURM node:

cd hf_ehr/scripts/carina

# GPT runs
sbatch parallel_gpt.sh

# Mamba runs
sbatch parallel_mamba.sh

🏋️‍♀️ Training

We use Hydra to manage our configurations and PyTorch Lightning for training.

You can either edit the config files in configs/ or pass CLI arguments to override the defaults.
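CLI overrides use Hydra's dotted-path syntax, e.g. `data.dataloader.mode=approx`. To illustrate what such an override does to a nested config, here is a simplified, stdlib-only approximation. `apply_override` is hypothetical and skips most of Hydra's grammar (typed values, config-group selections such as `+data=v8`, the `++` force prefix), so treat it as a mental model rather than Hydra's actual behavior.

```python
import copy
from typing import Any, Dict


def apply_override(config: Dict[str, Any], override: str) -> Dict[str, Any]:
    """Return a copy of `config` with a single `dotted.key=value` override applied.

    Simplified semantics (an approximation of Hydra, not Hydra itself):
      * `a.b=v` replaces an existing key and raises KeyError if the path is missing.
      * `+a.b=v` may create keys that do not exist yet.
      * Values are stored as raw strings; Hydra would parse them into typed values.
    """
    key, sep, value = override.partition("=")
    if not sep:
        raise ValueError(f"Override must look like key=value, got {override!r}")
    allow_new = key.startswith("+")
    parts = key.lstrip("+").split(".")
    result = copy.deepcopy(config)  # never mutate the caller's config
    node = result
    for part in parts[:-1]:
        if part not in node:
            if not allow_new:
                raise KeyError(f"Key '{part}' not found while applying {override!r}")
            node[part] = {}
        node = node[part]
    if parts[-1] not in node and not allow_new:
        raise KeyError(f"Key '{parts[-1]}' not found while applying {override!r}")
    node[parts[-1]] = value
    return result
```

This mirrors why the examples below write `+trainer.devices=[1]` (adding a key) but plain `logging.wandb.name=custom` (changing an existing one).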

There are 3 ways to launch a training run.

Easy Mode

Launch multiple runs in parallel on the same SLURM node (each job gets 1 GPU) using hf_ehr/scripts/carina/parallel_{model}.sh:

cd hf_ehr/scripts/carina

# Launch 4 gpt runs in parallel on the same node. See the file for the specific model versions run.
sbatch parallel_gpt.sh

# Launch 4 bert runs in parallel on the same node. See the file for the specific model versions run.
sbatch parallel_bert.sh

# Launch 4 hyena runs in parallel on the same node. See the file for the specific model versions run.
sbatch parallel_hyena.sh

# Launch 4 mamba runs in parallel on the same node. See the file for the specific model versions run.
sbatch parallel_mamba.sh

Medium Mode

Launch one run on a SLURM node using hf_ehr/scripts/carina/main.py:

cd hf_ehr/scripts/carina

# Launch GPT-2 base model on v8 dataset with CLMBRTokenizer, ApproxBatchSampler dataloader, and 2048 context length; force training from scratch instead of resuming a prior run (even if one exists)
python3 main.py --model gpt2 --size base --tokenizer clmbr --context_length 2048 --dataloader approx --dataset v8 --is_force_refresh

# Launch Mamba tiny model on v8 dataset with CookbookTokenizer, ApproxBatchSampler dataloader, and 16384 context length; resume the prior run if one exists
python3 main.py --model mamba --size tiny --tokenizer cookbook --context_length 16384 --dataloader approx --dataset v8

# Launch BERT-base model on v8 dataset with DescTokenizer, ApproxBatchSampler dataloader, and 4096 context length; resume the prior run if one exists; override the default device assignment to GPU 1; name the wandb run `custom`
python3 main.py --model bert --size base --tokenizer desc --context_length 4096 --dataloader approx --dataset v8 --extra "+trainer.devices=[1] logging.wandb.name=custom"

# Run a GPT-2 large model locally on v8 AllTokens dataset with CLMBRTokenizer, ApproxBatchSampler dataloader, and 2048 context length
python3 main.py --model gpt2 --size large --tokenizer clmbr --context_length 2048 --dataloader approx --dataset v8-alltokens --is_run_local

# Launch Mamba tiny model on v8 dataset with CookbookTokenizer, ApproxBatchSampler dataloader, and 16384 context length; resume the prior run if one exists; run on 8 H100s
python3 main.py --model mamba --size tiny --tokenizer cookbook --context_length 16384 --dataloader approx --dataset v8 --partitions nigam-h100 --extra "trainer=multi_gpu trainer.devices=[0,1,2,3,4,5,6,7]"

General usage:

python3 main.py --model <model> --size <size> --tokenizer <tokenizer> --context_length <context_length> --dataloader <dataloader> --dataset <dataset> [--extra <extra>] [--partitions <partitions>] [--is_force_refresh] [--is_skip_base] [--is_run_local]

where...

  • <model>: str -- Architecture to use. Choices include gpt2, llama, bert, hyena, mamba (as used in the examples above)
  • <size>: str -- Model size to use. Choices are tiny, small, base, medium, large, huge
  • <tokenizer>: str -- Tokenizer to use. Choices are clmbr, desc, cookbook
  • <context_length>: int -- Context length to use
  • <dataloader>: str -- Dataloader to use. Choices are approx, exact
  • <dataset>: str -- Dataset to use. Choices include v8, v8-alltokens, v9, v9-alltokens, or the name of a custom data config in hf_ehr/configs/data/ (e.g., meds_mimic4_demo_custom)
  • [--extra <extra>]: Optional[str] -- An optional string that will get appended to the end of the python ../run.py command verbatim
  • [--partitions <partitions>]: Optional[str] -- An optional string that specifies the SLURM partitions to use. Defaults to nigam-v100,gpu for gpt2 and bert, and nigam-h100,nigam-a100 for hyena and mamba
  • [--is_force_refresh]: Optional -- An optional flag that forces a refresh of the run (i.e., deletes the existing run and starts from scratch)
  • [--is_skip_base]: Optional -- An optional flag that skips running source base.sh. Useful when launching via parallel_{model}.sh, to avoid re-initializing the conda environment multiple times
  • [--is_run_local]: Optional -- An optional flag that runs the script locally as python run.py instead of as a SLURM sbatch command
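To make the interface above concrete, here is a hypothetical `argparse` mirror of it. It is not the actual parser in main.py: it encodes only the flags and choices listed above (plus `--trainer`, which the Quick Start examples pass), and it fills `--partitions` from the documented per-architecture defaults.

```python
import argparse
from typing import List

# Default --partitions per architecture, as documented in the list above.
# Architectures not listed there (e.g. llama) get no assumed default.
DEFAULT_PARTITIONS = {
    "gpt2": "nigam-v100,gpu",
    "bert": "nigam-v100,gpu",
    "hyena": "nigam-h100,nigam-a100",
    "mamba": "nigam-h100,nigam-a100",
}


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical mirror of the main.py interface described above."""
    parser = argparse.ArgumentParser(description="Sketch of the main.py launcher CLI")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--size", type=str, required=True,
                        choices=["tiny", "small", "base", "medium", "large", "huge"])
    parser.add_argument("--tokenizer", type=str, required=True,
                        choices=["clmbr", "desc", "cookbook"])
    parser.add_argument("--context_length", type=int, required=True)
    parser.add_argument("--dataloader", type=str, required=True)
    parser.add_argument("--dataset", type=str, required=True)
    # --trainer is not in the usage line, but the Quick Start examples pass it.
    parser.add_argument("--trainer", type=str, default=None)
    parser.add_argument("--extra", type=str, default=None)
    parser.add_argument("--partitions", type=str, default=None)
    parser.add_argument("--is_force_refresh", action="store_true")
    parser.add_argument("--is_skip_base", action="store_true")
    parser.add_argument("--is_run_local", action="store_true")
    return parser


def parse_launcher_args(argv: List[str]) -> argparse.Namespace:
    args = build_parser().parse_args(argv)
    if args.partitions is None:
        # Fill in the documented default; stays None for undocumented architectures.
        args.partitions = DEFAULT_PARTITIONS.get(args.model)
    return args
```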

Advanced Mode

Directly call run.py, which allows maximum flexibility for configs.

See the Config README for details on all config settings.

cd hf_ehr/scripts/carina

# Launch gpt with: size=base, dataset=v8, context_length=2048, tokenizer=CLMBRTokenizer, sampler=ApproxBatchSampler, max_tokens_per_batch=16384, use_cuda_devices=2,3, wandb_logging_name=gpt2-custom-run, force_restart_existing_run=True, save_to_path=/share/pi/nigam/mwornow/hf_ehr/cache/runs/bert-test/
python3 ../run.py \
    +data=v8 \
    +trainer=single_gpu \
    +model=gpt2-base \
    +tokenizer=clmbr \
    data.dataloader.mode=approx \
    data.dataloader.approx_batch_sampler.max_tokens=16384 \
    data.dataloader.max_length=2048 \
    model.config_kwargs.n_positions=2048 \
    trainer.devices=[2,3] \
    logging.wandb.name=gpt2-custom-run \
    main.is_force_restart=True \
    main.path_to_output_dir=/share/pi/nigam/mwornow/hf_ehr/cache/runs/bert-test/
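In the command above the context length appears twice (`data.dataloader.max_length` and `model.config_kwargs.n_positions`), and the two values presumably need to agree. A small hypothetical helper can generate both overrides from one number. `n_positions` is GPT-2's name for the maximum number of positions in Hugging Face configs; other architectures use different names (Llama uses `max_position_embeddings`), so whether those map onto `model.config_kwargs` the same way is an assumption.

```python
from typing import List


def context_length_overrides(context_length: int, position_key: str = "n_positions") -> List[str]:
    """Hypothetical helper: Hydra overrides that keep both context-length settings in sync."""
    if context_length <= 0:
        raise ValueError("context_length must be a positive integer")
    return [
        f"data.dataloader.max_length={context_length}",
        f"model.config_kwargs.{position_key}={context_length}",
    ]


def build_run_command(overrides: List[str], context_length: int) -> List[str]:
    """Assemble a `python3 ../run.py ...` argument list (meant to run from hf_ehr/scripts/carina)."""
    return ["python3", "../run.py", *overrides, *context_length_overrides(context_length)]


if __name__ == "__main__":
    cmd = build_run_command(
        ["+data=v8", "+trainer=single_gpu", "+model=gpt2-base", "+tokenizer=clmbr"], 2048
    )
    print(" ".join(cmd))
```

The resulting list could be passed to `subprocess.run(cmd, check=True)` from inside hf_ehr/scripts/carina.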

How to Configure Runs

See the Config README for details on all config settings (models, training, dataloaders, tokenizers, etc.).

📊 Evaluation

EHRSHOT

How to use this repo with EHRSHOT.

1. Generate Patient Representations

This all occurs within the hf_ehr repo.

  1. Identify the path (<path_to_ckpt>) to the model checkpoint you want to evaluate.

  2. Generate patient representations with your model. This will create a folder in /share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/models for this model checkpoint.

cd hf_ehr/scripts/eval/
sbatch ehrshot.sh <path_to_ckpt>

2. Generate EHRSHOT Results

This all occurs within the ehrshot-benchmark repo.

  1. Generate your model's AUROC/AUPRC results by running 7_eval.sh:
# cd to ehrshot-benchmark/ehrshot/bash_scripts/ directory
bash 7_eval.sh --is_use_slurm
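7_eval.sh reports AUROC and AUPRC for each task. For a quick sanity check outside that pipeline, AUROC can be computed from binary labels and model scores with the rank-based (Mann-Whitney U) formula below. This is an independent stdlib sketch, not the ehrshot-benchmark implementation.

```python
from typing import List, Sequence


def auroc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """AUROC via the Mann-Whitney U statistic; tied scores receive their average rank."""
    n = len(labels)
    if n != len(scores):
        raise ValueError("labels and scores must have the same length")
    order = sorted(range(n), key=lambda i: scores[i])
    ranks: List[float] = [0.0] * n
    i = 0
    while i < n:
        # Find the run of tied scores starting at sorted position i.
        j = i
        while j + 1 < n and scores[order[j + 1]] == scores[order[i]]:
            j += 1
        avg_rank = (i + j) / 2 + 1  # 1-based average rank of the tie group
        for k in range(i, j + 1):
            ranks[order[k]] = avg_rank
        i = j + 1
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("AUROC is undefined when only one class is present")
    rank_sum_pos = sum(r for r, y in zip(ranks, labels) if y == 1)
    u = rank_sum_pos - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)
```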

3. Generate EHRSHOT Plots

This all occurs within the ehrshot-benchmark repo.

  1. Generate plots by running 8_make_results_plots.sh. You may need to modify the --model_heads parameter in the file before running it to control what is included in your plots.
# cd to ehrshot-benchmark/ehrshot/bash_scripts/ directory
bash 8_make_results_plots.sh

💊 MEDS Demo

We support training and inference on MEDS-formatted datasets.

Here is a quick tutorial using the publicly available MIMIC-IV demo dataset (inspired by this tutorial).

  1. Download the MIMIC-IV demo dataset from PhysioNet.
export PATH_TO_DOWNLOAD=mimic4_demo
export PATH_TO_MEDS=meds_mimic4_demo
# Absolute path, because the Hydra config created in step 6 stores it and training is launched from hf_ehr/scripts/carina
export PATH_TO_MEDS_READER=$PWD/meds_mimic4_demo_reader

wget -q -r -N -c --no-host-directories --cut-dirs=1 -np -P $PATH_TO_DOWNLOAD https://physionet.org/files/mimic-iv-demo/2.2/
  2. Convert the MIMIC-IV demo dataset to MEDS format.
rm -rf $PATH_TO_MEDS 2>/dev/null
meds_etl_mimic $PATH_TO_DOWNLOAD $PATH_TO_MEDS
  3. Convert the MEDS dataset into a MEDS Reader Database (to enable faster data ingestion during training).
rm -rf $PATH_TO_MEDS_READER 2>/dev/null
meds_reader_convert $PATH_TO_MEDS $PATH_TO_MEDS_READER --num_threads 4
  4. Verify everything worked.
meds_reader_verify $PATH_TO_MEDS $PATH_TO_MEDS_READER
  5. Create train/val/test splits (80/10/10) by running the Python script below from the repo root:
python hf_ehr/scripts/datasets/split_meds_dataset.py --path_to_meds_reader $PATH_TO_MEDS_READER --train_split_size 0.8 --val_split_size 0.1
  6. Create a Hydra config for your dataset. (Double quotes are needed so that $PATH_TO_MEDS_READER is expanded by the shell.)
cp hf_ehr/configs/data/meds_mimic4_demo.yaml hf_ehr/configs/data/meds_mimic4_demo_custom.yaml
sed -i "s|/share/pi/nigam/mwornow/mimic-iv-demo-meds-reader|$PATH_TO_MEDS_READER|g" hf_ehr/configs/data/meds_mimic4_demo_custom.yaml
  7. Train a tokenizer on the dataset, limiting the vocabulary to the top-$k$ most frequently occurring codes.
cd hf_ehr/tokenizers
python create_cookbook.py --dataset meds_mimic4_demo --n_procs 5 --chunk_size 10000 --is_force_refresh
python create_cookbook_k.py --dataset meds_mimic4_demo --k 32 --stat count_occurrences
  8. Train a Llama model on the dataset.
  • Replace the output directory on line 315 of scripts/carina/main.py with your desired output directory.
  • By default, this run is tracked with WandB. Configure WandB beforehand by running wandb init, then update the entity and project in scripts/run.py (line 294, and possibly elsewhere).
cd hf_ehr/scripts/carina
python3 main.py --model llama --size base --tokenizer clmbr --context_length 1024 --dataloader approx --dataset meds_mimic4_demo_custom --is_run_local --is_force_refresh
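Splitting should happen at the patient level so that no patient's events appear in more than one split. One common way to get a deterministic 80/10/10 split is to hash each subject ID into [0, 1) and threshold the result, as sketched below. This is a hypothetical illustration, not necessarily how split_meds_dataset.py assigns splits; `assign_split`, `split_subjects`, and the `salt` value are made up for the example.

```python
import hashlib
from typing import Dict, Iterable, List


def assign_split(subject_id: int, train_size: float = 0.8, val_size: float = 0.1, salt: str = "example-salt") -> str:
    """Deterministically map a subject ID to 'train', 'val', or 'test'."""
    digest = hashlib.sha256(f"{salt}:{subject_id}".encode("utf-8")).hexdigest()
    u = int(digest[:16], 16) / 16**16  # first 64 bits of the hash, scaled to [0, 1)
    if u < train_size:
        return "train"
    if u < train_size + val_size:
        return "val"
    return "test"


def split_subjects(subject_ids: Iterable[int], **kwargs) -> Dict[str, List[int]]:
    """Group subject IDs by split; every subject lands in exactly one split."""
    splits: Dict[str, List[int]] = {"train": [], "val": [], "test": []}
    for subject_id in subject_ids:
        splits[assign_split(subject_id, **kwargs)].append(subject_id)
    return splits
```

Because the assignment depends only on the subject ID and the salt, rerunning the split (or adding new patients later) never moves an existing patient to a different split.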

Ⓜ️ Merative/Truven/MarketScan Demo

We support training and inference on the 2017 Merative MarketScan Commercial Claims and Encounters Database (OMOP CDMv5 formatted), aka "Truven" or "MarketScan".

  1. Download the Merative OMOP CDMv5 dataset. Note: This takes ~10 mins to download and takes up 347 GB of space.
export PATH_TO_DOWNLOAD=truven-omop
export PATH_TO_MEDS=truven-meds
export PATH_TO_MEDS_READER=truven-meds-reader
gsutil -m cp -r gs://truven_backup/TRUVEN_CDMv5 $PATH_TO_DOWNLOAD
  2. Convert the Truven OMOP CDMv5 dataset to MEDS format. Note: This takes ~4.25 hrs to run and takes up 698MB of space.
meds_etl_omop $PATH_TO_DOWNLOAD $PATH_TO_MEDS
  3. Convert the MEDS dataset into a MEDS Reader Database (to enable faster data ingestion during training). Note: This takes ~15 mins to run and takes up 26GB of space.
meds_reader_convert $PATH_TO_MEDS $PATH_TO_MEDS_READER --num_threads 10
meds_reader_verify $PATH_TO_MEDS $PATH_TO_MEDS_READER
  4. Create train/val/test splits (80/10/10) by running the Python script below. Note: This takes ~1 min to run.
python3 hf_ehr/scripts/datasets/split_meds_dataset.py --path_to_meds_reader $PATH_TO_MEDS_READER --train_split_size 0.8 --val_split_size 0.1
  5. Train a tokenizer on the dataset, limiting the vocabulary to the top-$k$ most frequently occurring codes. TODO
python3 hf_ehr/tokenizers/create_cookbook.py --path_to_dataset_config hf_ehr/configs/data/truven.yaml --path_to_tokenizer_config hf_ehr/configs/tokenizer/truven.yaml --n_procs 32 --chunk_size 10000 --is_force_refresh
python3 hf_ehr/tokenizers/create_cookbook_k.py --path_to_tokenizer_config hf_ehr/configs/tokenizer/truven.yaml --k 32 --stat count_occurrences
  6. Train a Llama model on the dataset using 2 GPUs. Note: This takes ~5 hrs per epoch with 2 H100s.
cd hf_ehr/scripts/carina
python3 main.py --model llama --size base --tokenizer clmbr --context_length 512 --dataloader batch --dataset truven --trainer multi_gpu_2 --is_run_local --is_force_refresh
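The create_cookbook_k.py step keeps only the most frequent codes. Assuming `--stat count_occurrences` means ranking codes by how many times each one occurs in the data, the core selection looks roughly like the hypothetical sketch below; the real script may compute and store its statistics differently, and the meaning of the `--k` value is not specified here.

```python
from collections import Counter
from typing import Iterable, List


def top_k_codes(codes: Iterable[str], k: int) -> List[str]:
    """Return the k most frequent codes, breaking count ties alphabetically for determinism."""
    counts = Counter(codes)
    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    return [code for code, _ in ranked[:k]]
```

Codes that fall outside the top k would then typically be mapped to an unknown token during tokenization.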

🔍 Profiling

Run python3 hf_ehr/scripts/huggingface/profile.py to measure each model's GPU memory usage and forward-pass time during inference.
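When reproducing latency numbers like these, it usually helps to warm up before timing and to report a robust statistic. In several of the tables the first model of a family is noticeably slower at the same sequence length (e.g. gpt-base-512 at 0.0477 s vs about 0.0067 s for the other GPT models), which may reflect one-time warm-up costs. The hypothetical harness below times any zero-argument callable; it is not profile.py. On a GPU, `fn` should end with `torch.cuda.synchronize()`, because CUDA kernels launch asynchronously.

```python
import statistics
import time
from typing import Callable, List


def time_forward(fn: Callable[[], object], n_warmup: int = 3, n_iters: int = 10) -> float:
    """Return the median wall-clock time (seconds) of fn() after n_warmup untimed calls."""
    for _ in range(n_warmup):
        fn()  # untimed: absorbs one-time costs such as kernel compilation or caching
    timings: List[float] = []
    for _ in range(n_iters):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)
```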

H100 (80GB)

Results when testing on one H100 (80GB) GPU.

For a fixed sequence length of 500...

| Model | GPU Mem (MB) | Sequence Length | Time per forward pass (s) |
|---|---|---|---|
| StanfordShahLab/gpt-base-512-clmbr | 1733.18 | 500 | 0.0477076 |
| StanfordShahLab/gpt-base-1024-clmbr | 1742.49 | 500 | 0.00666604 |
| StanfordShahLab/gpt-base-2048-clmbr | 1784.28 | 500 | 0.0067502 |
| StanfordShahLab/gpt-base-4096-clmbr | 1932.71 | 500 | 0.00669129 |
| StanfordShahLab/llama-base-512-clmbr | 1257.64 | 500 | 0.0132446 |
| StanfordShahLab/llama-base-1024-clmbr | 1257.64 | 500 | 0.00650663 |
| StanfordShahLab/llama-base-2048-clmbr | 1257.64 | 500 | 0.00649655 |
| StanfordShahLab/llama-base-4096-clmbr | 1257.64 | 500 | 0.00647738 |
| StanfordShahLab/mamba-tiny-1024-clmbr | 9225.39 | 500 | 0.929321 |
| StanfordShahLab/mamba-tiny-4096-clmbr | 9225.39 | 500 | 0.918106 |
| StanfordShahLab/mamba-tiny-8192-clmbr | 9225.39 | 500 | 0.913788 |
| StanfordShahLab/mamba-tiny-16384-clmbr | 75097.60 | 500 | 1.13286 |
| StanfordShahLab/hyena-large-1024-clmbr | 1873.65 | 500 | 0.157894 |
| StanfordShahLab/hyena-large-4096-clmbr | 1898.93 | 500 | 0.0143214 |
| StanfordShahLab/hyena-large-8192-clmbr | 1900.43 | 500 | 0.0147968 |
| StanfordShahLab/hyena-large-16384-clmbr | 1903.43 | 500 | 0.0146549 |

For sequence lengths near the maximum supported by each model...

| Model | GPU Mem (MB) | Sequence Length | Time per forward pass (s) |
|---|---|---|---|
| StanfordShahLab/gpt-base-512-clmbr | 1672.24 | 492 | 0.00668569 |
| StanfordShahLab/gpt-base-1024-clmbr | 2918.64 | 1004 | 0.010623 |
| StanfordShahLab/gpt-base-2048-clmbr | 5415.06 | 2028 | 0.0173378 |
| StanfordShahLab/gpt-base-4096-clmbr | 10467.40 | 4076 | 0.0358181 |
| StanfordShahLab/llama-base-512-clmbr | 1248.85 | 492 | 0.00650101 |
| StanfordShahLab/llama-base-1024-clmbr | 2003.70 | 1004 | 0.00906057 |
| StanfordShahLab/llama-base-2048-clmbr | 3477.72 | 2028 | 0.0141757 |
| StanfordShahLab/llama-base-4096-clmbr | 6495.00 | 4076 | 0.026729 |
| StanfordShahLab/mamba-tiny-1024-clmbr | 17701.50 | 1004 | 1.87138 |
| StanfordShahLab/mamba-tiny-4096-clmbr | 70163.30 | 4076 | 7.94088 |
| StanfordShahLab/mamba-tiny-8192-clmbr | OOM | 8172 | -- |
| StanfordShahLab/mamba-tiny-16384-clmbr | OOM | 16364 | -- |
| StanfordShahLab/hyena-large-1024-clmbr | 3242.75 | 1004 | 0.0167934 |
| StanfordShahLab/hyena-large-4096-clmbr | 11557.60 | 4076 | 0.0334628 |
| StanfordShahLab/hyena-large-8192-clmbr | 22618.40 | 8172 | 0.183798 |
| StanfordShahLab/hyena-large-16384-clmbr | 44731.10 | 16364 | 0.11162 |

When using the mamba-ssm and causal-conv1d packages (pip install mamba-ssm causal-conv1d) for accelerated Mamba:

| Model | GPU Mem (MB) | Sequence Length | Time per forward pass (s) |
|---|---|---|---|
| StanfordShahLab/mamba-tiny-1024-clmbr | 1952.03 | 500 | 0.0595783 |
| StanfordShahLab/mamba-tiny-4096-clmbr | 1952.03 | 500 | 0.0205079 |
| StanfordShahLab/mamba-tiny-8192-clmbr | 1952.03 | 500 | 0.0202962 |
| StanfordShahLab/mamba-tiny-16384-clmbr | 1952.03 | 500 | 0.0200981 |
| StanfordShahLab/mamba-tiny-1024-clmbr | 3271.28 | 1004 | 0.0236377 |
| StanfordShahLab/mamba-tiny-4096-clmbr | 11728.20 | 4076 | 0.0447765 |
| StanfordShahLab/mamba-tiny-8192-clmbr | 22923.10 | 8172 | 0.0768456 |
| StanfordShahLab/mamba-tiny-16384-clmbr | 45357.60 | 16364 | 0.137661 |
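Latency alone hides how much work each call does. Dividing sequence length by time per forward pass gives tokens per second, assuming one sequence per forward pass (the batch size is not stated above). Using numbers copied from the tables, the sketch below also computes the speedup from the mamba-ssm/causal-conv1d kernels for mamba-tiny-4096-clmbr at 4076 tokens (7.94088 s vs 0.0447765 s, roughly 177x).

```python
from typing import Dict, Tuple

# (sequence length, seconds per forward pass), copied from the H100 tables above.
MEASUREMENTS: Dict[str, Tuple[int, float]] = {
    "gpt-base-4096-clmbr": (4076, 0.0358181),
    "llama-base-4096-clmbr": (4076, 0.026729),
    "hyena-large-16384-clmbr": (16364, 0.11162),
    "mamba-tiny-16384-clmbr (mamba-ssm)": (16364, 0.137661),
}


def tokens_per_second(seq_len: int, seconds: float) -> float:
    """Throughput of one forward pass over a single sequence (batch size 1 assumed)."""
    return seq_len / seconds


def speedup(slow_seconds: float, fast_seconds: float) -> float:
    """How many times faster the second timing is than the first."""
    return slow_seconds / fast_seconds


if __name__ == "__main__":
    for name, (seq_len, seconds) in MEASUREMENTS.items():
        print(f"{name}: {tokens_per_second(seq_len, seconds):,.0f} tokens/s")
    # mamba-tiny-4096-clmbr at 4076 tokens: without vs with the accelerated kernels
    print(f"Mamba kernel speedup: {speedup(7.94088, 0.0447765):.0f}x")
```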

ℹ️ Other

Based

To run the Based model, install the following on an A100 (or newer) node:

pip install -v \
    --disable-pip-version-check \
    --no-cache-dir \
    --no-build-isolation \
    --config-settings "--build-option=--cpp_ext" \
    --config-settings "--build-option=--cuda_ext" \
    'git+https://github.com/NVIDIA/apex@b496d85'

pip install --no-cache-dir \
    torch==2.1.2 \
    torchvision==0.16.2 \
    torchaudio==2.1.2 \
    --index-url https://download.pytorch.org/whl/cu118

# Install FLA triton kernel
pip install -U git+https://github.com/sustcsonglin/flash-linear-attention

pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2' --no-build-isolation --no-cache-dir
pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2#subdirectory=csrc/fused_dense_lib'  --no-build-isolation --no-cache-dir
pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2#subdirectory=csrc/layer_norm' --no-build-isolation --no-cache-dir

git clone git@github.com:HazyResearch/based.git
cd based
pip install -e . --no-cache-dir

🤖 Creating a Model

Let's say we want to create a new model called {model} of size {size}.

  1. Create the Hydra config YAML for your model architecture in hf_ehr/configs/architecture/{model}.yaml. Copy the contents of hf_ehr/configs/architecture/bert.yaml and modify as needed.

  2. Create the Hydra config YAML for your model instantiation in hf_ehr/configs/models/{model}-{size}.yaml. Copy the contents of hf_ehr/configs/models/bert-base.yaml and modify as needed.

  3. Create the model itself in a new file, hf_ehr/models/{model}.py. Copy the contents of hf_ehr/models/bert.py and modify as needed.

  4. Add your model to hf_ehr/scripts/run.py above the line raise ValueError(f"Model {config.model.name} not supported.")
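Step 4 amounts to adding one more branch to the chain of model checks that ends in that `ValueError`. Since run.py's exact structure is not shown here, the sketch below only illustrates the pattern; `build_model`, `BertModelSketch`, `MyModelSketch`, and the `mymodel` name are hypothetical stand-ins for your real classes and config values.

```python
from types import SimpleNamespace
from typing import Any


class BertModelSketch:
    """Stand-in for the existing BERT model class (hypothetical name)."""

    def __init__(self, config: Any) -> None:
        self.config = config


class MyModelSketch:
    """Stand-in for the class you define in hf_ehr/models/{model}.py (hypothetical)."""

    def __init__(self, config: Any) -> None:
        self.config = config


def build_model(config: Any) -> object:
    """Dispatch on config.model.name, falling through to the 'not supported' error."""
    if config.model.name == "bert":
        return BertModelSketch(config)
    elif config.model.name == "mymodel":  # new branch, added above the raise
        return MyModelSketch(config)
    raise ValueError(f"Model {config.model.name} not supported.")


if __name__ == "__main__":
    cfg = SimpleNamespace(model=SimpleNamespace(name="mymodel"))
    print(type(build_model(cfg)).__name__)
```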

βœ‚οΈ Creating a Tokenizer

See the Tokenizer README for details on creating tokenizers and how they are stored on the file system.

🤗 Uploading a Model to Hugging Face

See the Hugging Face README for details on uploading models to Hugging Face.

To create a new release:

git add . && git commit -m "New version"
make release

To remove an existing tag and create a new one:

git tag -d v0.1.4
git push origin --delete v0.1.4
make release

MEDS-DEV

First, create a tokenizer from the MEDS extract. This takes 834 seconds.

cd hf_ehr/tokenizers
python create_cookbook.py --dataset meds_dev --n_procs 5 --chunk_size 10000 --is_force_refresh

🎓 Citation

If you found this work useful, please consider citing it:

@article{wornow2024contextclues,
      title={Context Clues: Evaluating Long Context Models for Clinical Prediction Tasks on EHRs}, 
      author={Michael Wornow and Suhana Bedi and Miguel Angel Fuentes Hernandez and Ethan Steinberg and Jason Alan Fries and Christopher Ré and Sanmi Koyejo and Nigam H. Shah},
      year={2024},
      eprint={2412.16178},
      url={https://arxiv.org/abs/2412.16178}, 
}