Code for the paper "Time-to-event pretraining for 3D medical imaging" (ICLR 2025). You can read the paper here.
We have provided the code for:
- 💿 Installation
- 📥 Dataset/Model download
- ✍ Tokenization
- 🔖 Labeling
- 🎛 Pretraining with parallel GPUs
- ⚖️ Evaluation with linear probe
- 📚 Tutorial for deriving TTE training loss
- 🔬 Unit Test
You should install the required packages first:

```bash
conda create -n TTE_ENV python=3.10 -y
conda activate TTE_ENV
pip install -e .
```
Additionally, for our data preprocessing pipeline we use FEMR (Framework for Electronic Medical Records), a Python package for building deep learning models with EHR data.
You must also have CUDA/cuDNN installed (we recommend CUDA 11.8 and cuDNN 8.7.0).
Note that this currently only works on Linux machines.
```bash
pip install --upgrade "jax[cuda11_pip]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
git clone https://github.com/som-shahlab/femr.git
cd femr
git switch -c femrv2_pub --track origin/femrv2_pub
pip install -e .
```
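Optionally, you can sanity-check the environment afterwards. This minimal snippet just assumes the JAX and FEMR installs above succeeded:

```python
# Optional sanity check: confirm JAX sees the GPU and femr imports cleanly.
import jax
import femr

print(jax.devices())   # expect at least one CUDA device on a working install
print(femr.__file__)   # path of the editable femr install
```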
You should go here to download the image modality data in NIfTI format (i.e. files with the `.nii.gz` extension).
- The path to this folder will be used as `nii_folder` in the commands below
You should go here to download the EHR modality data in MEDS format (filename `meds_omop_inspect.tar.gz`).
- The path to this folder will be used as `parquet_folder` in the commands below
You should go here to download the model weights from Hugging Face.
Tokenization organizes the EHR codes into a hierarchy based on their ontology, ranks them by entropy after additional processing (e.g. normalizing by the number of patients with each code), and finally saves the tokenizer.
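As a concrete illustration of the ranking step, here is a minimal sketch that scores each code by the Bernoulli entropy of its patient-level prevalence and keeps the highest-entropy codes as tasks. This is illustrative only; the actual tokenizer also applies ontology expansion and the other processing described above:

```python
# Illustrative sketch of entropy-based task selection (not the repo's exact tokenizer).
import math

def bernoulli_entropy(p: float) -> float:
    """Entropy (in bits) of a Bernoulli variable with success probability p."""
    if p in (0.0, 1.0):
        return 0.0
    return -(p * math.log2(p) + (1 - p) * math.log2(1 - p))

def rank_codes(patient_counts: dict[str, int], num_patients: int, num_tasks: int) -> list[str]:
    # Normalize counts to prevalences, then keep the highest-entropy codes.
    scored = {code: bernoulli_entropy(n / num_patients) for code, n in patient_counts.items()}
    return sorted(scored, key=scored.get, reverse=True)[:num_tasks]

# Codes near 50% prevalence carry the most information and rank first.
print(rank_codes({"SNOMED/A": 500, "SNOMED/B": 10, "SNOMED/C": 990}, 1000, 2))
```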
We select 8,192 pretraining tasks, and the vocabulary size (total unique EHR codes) is 65,535. You will need to download the ontology from Athena.
Athena is an OHDSI service for downloading ontologies. Simply visit https://athena.ohdsi.org, create an account, click download at the top, and put the ontology at the path referenced in the bash script.
Note: the downloaded ontology can be very large (a few hundred GB), so you may optionally prune it to our dataset to make runs substantially faster:
```bash
cd src/training
python 1a_prune_ontology.py \
    --input-dataset "inspect/timelines_smallfiles_meds/data/*parquet" \
    --input-ontology "inspect/ontology.pkl" \
    --output-ontology "inspect/inspect_ontology.pkl" \
    --num-processors 32
```
After that you can start training a tokenizer and save it:

```bash
./1a_tokenizer.sh
```
We also provide code examples for deriving the TTE labels of downstream tasks, i.e. tuples of the form (`time_to_event_of_interest` (in seconds), `is_censored`). In the paper we labeled 5 such tasks: ATX (Atelectasis), CMG (Cardiomegaly), CONS (Consolidation), EDM (Edema), and PEFF (Pleural Effusion). However, users can specify their own labeling criteria for TTE labeling.
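For intuition, deriving one such label reduces to finding the first occurrence of the event code after the index time, or censoring at the end of follow-up. A minimal sketch (illustrative, not the repo's exact labeler):

```python
# Illustrative sketch of deriving one TTE label from an index time (e.g. the CT
# procedure time) and the patient's recorded event times for a code.
from datetime import datetime

def tte_label(index_time: datetime, event_times: list[datetime],
              censor_time: datetime) -> tuple[float, bool]:
    """Return (time_to_event_in_seconds, is_censored) relative to index_time."""
    future = sorted(t for t in event_times if t > index_time)
    if future:  # event observed after the index time
        return ((future[0] - index_time).total_seconds(), False)
    # No event before the end of follow-up: censored at the last observation.
    return ((censor_time - index_time).total_seconds(), True)

# Event 60 days after the CT -> (5184000.0, False)
print(tte_label(datetime(2020, 1, 1), [datetime(2020, 3, 1)], datetime(2021, 1, 1)))
```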
Note that our EHR data follows the OHDSI common data model, so our codes are mainly in the SNOMED vocabulary. For example, these are the codes we used for labeling:
| Task | Code |
|---|---|
| Pulmonary Hypertension | SNOMED/70995007 |
| Pulmonary Embolism | SNOMED/59282003 |
| Atelectasis | SNOMED/46621007 |
| Cardiomegaly | SNOMED/8186001 |
| Consolidation | SNOMED/95436008 |
| Edema | SNOMED/267038008 |
| Pleural Effusion | SNOMED/60046008 |
You can then proceed to derive the TTE labels:

```bash
cd src/labeling
labeling_function='tte_mortality' # or 'tte_Pleural_Effusion' etc.
python generate_tte_labels.py \
    --index_time_csv_path 'metadata_20250303.csv' \
    --index_time_column 'procedure_DATETIME' \
    --path_to_database 'femr_extract' \
    --path_to_output_dir 'output' \
    --labeling_function $labeling_function \
    --is_skip_featurize \
    --num_threads 12
```
For pretraining we used 3 model architectures (SwinUNETR/ResNet/DenseNet):
- SwinUNETR's pretraining weights come from training on a dataset of 50k publicly available CT/MRI scans (the weights can be downloaded from here and loaded in torch)
- ResNet and DenseNet are initialized by inflating 2D weights pretrained on ImageNet. The inflation process follows these instructions (see the sketch after this list)
- The scripts that perform the inflation are `src/training/i3d/i3dense.py` and `src/training/i3d/i3res.py`
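The idea behind inflation is to replicate each 2D convolution kernel along the new depth axis and rescale so activations keep roughly the same magnitude. A minimal sketch of this standard I3D-style inflation (not the repo's exact implementation, which lives in the scripts above):

```python
# Illustrative sketch of inflating a 2D conv into a 3D conv (I3D-style).
import torch
import torch.nn as nn

def inflate_conv2d_to_3d(conv2d: nn.Conv2d, depth: int = 3) -> nn.Conv3d:
    """Build a 3D conv whose kernel is the 2D kernel replicated along depth."""
    conv3d = nn.Conv3d(
        conv2d.in_channels,
        conv2d.out_channels,
        kernel_size=(depth, *conv2d.kernel_size),
        stride=(1, *conv2d.stride),
        padding=(depth // 2, *conv2d.padding),
        bias=conv2d.bias is not None,
    )
    # Repeat the 2D kernel along the new depth axis and divide by depth so a
    # constant input produces the same activation scale as in the 2D network.
    weight_3d = conv2d.weight.data.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth
    conv3d.weight.data.copy_(weight_3d)
    if conv2d.bias is not None:
        conv3d.bias.data.copy_(conv2d.bias.data)
    return conv3d

# Example: inflate the first conv of a 2D ResNet-style stem.
stem2d = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
stem3d = inflate_conv2d_to_3d(stem2d, depth=7)
print(stem3d.weight.shape)  # torch.Size([64, 3, 7, 7, 7])
```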
You should specify the pretrained tokenizer from above, the dataset path (the `parquet` file folder), and the image data path (the `.nii.gz` file folder).
There are other training hyperparameters for the three architectures; refer to the hyperparameter table for the values to pass into the bash script.
```bash
cd src/training
./1_pretrain_TTE_run_ddp.sh
```
Each architecture requires a different amount of training wall-clock time (or GPU time); rough estimates:
| Architecture | Number of GPUs | Estimated wall-clock time | Estimated GPU hours |
|---|---|---|---|
| SwinUNETRbase/TTE | 4 H100 (80GB) | 15 days | 1,440 GPU hours |
| DenseNet-121base/TTE | 4 A100 (40GB) | 9 days | 864 GPU hours |
| ResNet-152base/TTE | 4 A100 (80GB) | 10 days | 960 GPU hours |
Note: optionally you can perform per-task fine-tuning, but this process is generally expensive since you must train to completion for every downstream task, i.e. `num_model * num_tasks` full-parameter updates, and it tends not to work well (per our fine-tuning table results). We nonetheless provide a script to do fine-tuning as an example:
```bash
cd src/training/
./2_finetune_run_ddp.sh
```
After pretraining is done we perform linear probing (logistic regression for binary classification tasks, and a Cox-PH head from DeepSurv for TTE tasks):
```bash
cd src/training
./3_inference_TTE_ddp.sh
```
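For intuition, the linear probe on a binary task amounts to fitting a logistic regression on frozen embeddings. A minimal sketch, assuming the embeddings and labels are already saved as NumPy arrays (the file names here are hypothetical):

```python
# Illustrative linear-probe sketch: the backbone stays frozen; only this head is fit.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

X_train, y_train = np.load("train_embeddings.npy"), np.load("train_labels.npy")
X_test, y_test = np.load("test_embeddings.npy"), np.load("test_labels.npy")

probe = LogisticRegression(max_iter=10_000)
probe.fit(X_train, y_train)
print("AUROC:", roc_auc_score(y_test, probe.predict_proba(X_test)[:, 1]))
```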
We also test on the RSPECT data for out-of-distribution evaluation (diagnosis tasks only):
```bash
cd src/training
./3_inference_TTE_RSNA.sh
```
We also provide a guide for deriving the TTE training loss with exemplar CTs and their corresponding future codes as TTE tasks.
Please refer to the notebook at `tutorial/pretrain_TTE_tutorial.ipynb`
Note:
- This notebook doesn't require a GPU, just a CPU, so it will run more slowly, but it only uses 1 CT as an example
- It still requires the `nii_folder`, `parquet_folder`, and `ontology_path` to derive the TTE loss
- We reduced `vocab_size` to 512 and `num_tasks` to 200 to speed up getting results
- The tutorial will prefit a bias term of the piecewise exponential model layer to avoid collapse without a good initial fit. This will take a few moments
- There's no gradient update or backpropagation, as we are only demonstrating how the loss term is derived
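For orientation before opening the notebook: in a piecewise-exponential model, observed events contribute the log hazard of the time bin they fall in, and every patient pays the cumulative hazard over their time at risk. A minimal sketch of that negative log-likelihood under assumed tensor shapes (names are illustrative, not the repo's exact API):

```python
# Illustrative piecewise-exponential TTE loss, assuming the model outputs one
# log hazard per (task, time-bin).
import torch

def piecewise_exponential_nll(log_hazard, exposure, event_bin, is_censored):
    """
    log_hazard:  (B, T, K) float, log hazard per task per time bin
    exposure:    (B, T, K) float, time at risk spent in each bin
    event_bin:   (B, T)    long, index of the bin containing the event
    is_censored: (B, T)    float 0/1, 1 if the event was never observed
    """
    hazard = log_hazard.exp()
    # Survival term: cumulative hazard accumulated over all time at risk.
    cum_hazard = (hazard * exposure).sum(dim=-1)
    # Event term: log hazard of the bin in which the event occurred.
    event_loglam = log_hazard.gather(-1, event_bin.unsqueeze(-1)).squeeze(-1)
    log_lik = (1 - is_censored) * event_loglam - cum_hazard
    return -log_lik.mean()

# Shape check with random inputs (B patients, T tasks, K time bins).
B, T, K = 2, 4, 8
loss = piecewise_exponential_nll(
    torch.randn(B, T, K), torch.rand(B, T, K),
    torch.randint(0, K, (B, T)), torch.randint(0, 2, (B, T)).float(),
)
print(loss)
```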
We also provide unit tests for model loading, feature derivation, etc. as preliminary guardrails.
Please refer to the folder `tests/`
Note:
- We mainly provide guardrails for out-of-the-box inference and adaptation
- It loads model weights (which you need to download from the Hugging Face repo above: https://huggingface.co/StanfordShahLab)
- The user then needs to supply labels so that the embeddings can eventually be mapped to them
- It trains a logistic regression on top of the frozen model and evaluates it
- It tests that the features/labels/embeddings match expectations, e.g. the TTE-pretrained DenseNet produces 1024-dimensional features for the downstream linear probe (see the sketch below)
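As an illustration of the embedding-dimension guardrail, here is a minimal sketch in the style of `tests/`; the `load_pretrained_densenet` helper and the input volume shape are hypothetical stand-ins for the repo's actual loaders:

```python
# Illustrative guardrail test (hypothetical loader; not the repo's exact test code).
import torch

def test_densenet_embedding_dim():
    model = load_pretrained_densenet()  # hypothetical helper wrapping the HF weights
    model.eval()
    volume = torch.randn(1, 1, 224, 224, 160)  # dummy single-channel 3D CT volume
    with torch.no_grad():
        emb = model(volume)
    # TTE-pretrained DenseNet should emit 1024-dim features for the linear probe.
    assert emb.shape[-1] == 1024
```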
If you found this work useful, please consider citing it:

```bibtex
@article{huo2024time,
  title={Time-to-Event Pretraining for 3D Medical Imaging},
  author={Huo, Zepeng and Fries, Jason Alan and Lozano, Alejandro and Valanarasu, Jeya Maria Jose and Steinberg, Ethan and Blankemeier, Louis and Chaudhari, Akshay S and Langlotz, Curtis and Shah, Nigam H},
  journal={arXiv preprint arXiv:2411.09361},
  year={2024},
  url={https://arxiv.org/abs/2411.09361},
}
```