
ponto-n/CSE517_Project_PromptEHR


CSE517 Course Project: PromptEHR

Installation

To keep this project's packages separate from the other Python environments on the system, create a new conda environment. All of our scripts were developed with Python 3.9:

conda create --name cse517 python=3.9
conda activate cse517

Install the base package requirements:

pip install -r requirements.txt

Install PyTorch with pip by following the instructions on the PyTorch website. The command should look something like this (the cu116 index URL selects wheels built for CUDA 11.6):

pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
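After installing, a quick sanity check can confirm that PyTorch imports and can see the GPU. The sketch below is not part of the repository, and torch_cuda_status is a hypothetical helper name. It uses only torch.__version__ and torch.cuda.is_available(), and it returns None instead of crashing when PyTorch is missing.

```python
import importlib.util


def torch_cuda_status():
    """Return (torch_version, cuda_available), or None if PyTorch is not installed."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch  # imported lazily so this check still runs without PyTorch

    return torch.__version__, torch.cuda.is_available()


if __name__ == "__main__":
    status = torch_cuda_status()
    if status is None:
        print("PyTorch is not installed")
    else:
        version, cuda = status
        print(f"PyTorch {version}, CUDA available: {cuda}")
```

If CUDA is reported as unavailable on a GPU machine, the installed wheel most likely does not match the local CUDA driver.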

The PyTrial package should already have been installed by requirements.txt; it contains the most up-to-date code for PromptEHR. If the package was not installed correctly, install it manually:

pip install pytrial
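To check whether PyTrial is available in the active environment without importing it, a minimal sketch using the standard library's importlib.util.find_spec could look like this. The helper name pytrial_installed is hypothetical and not part of the repository.

```python
import importlib.util
import sys


def pytrial_installed() -> bool:
    """Return True if the pytrial package can be found in the active environment."""
    return importlib.util.find_spec("pytrial") is not None


if __name__ == "__main__":
    if pytrial_installed():
        print("pytrial found")
    else:
        print("pytrial missing; run: pip install pytrial", file=sys.stderr)
```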

Full documentation of the group's setup process can be found here.

Preprocessing

More information about converting the CSV files from the MIMIC-III dataset into the files required for training can be found in this markdown file.

Training

To train the model using the hyperparameters from the PromptEHR paper, simply run the training script:

python train.py

Some parameters, such as the number of epochs, batch size, number of training samples, and evaluation frequency, can be changed by editing the constants defined in train.py.
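When adjusting these constants, it can help to estimate how many optimizer steps a run will take. The sketch below is illustrative only. The constant names and values are hypothetical placeholders, not the actual names or defaults in train.py, and it assumes evaluation frequency is counted in optimizer steps.

```python
import math

# Hypothetical names mirroring the kinds of constants train.py exposes;
# the values are placeholders, not the repository's defaults.
EPOCHS = 50
BATCH_SIZE = 16
NUM_TRAIN_SAMPLES = 30_000
EVAL_EVERY_N_STEPS = 500  # assumes evaluation is scheduled by step count


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Steps needed to see every sample once; the last batch may be partial."""
    return math.ceil(num_samples / batch_size)


def total_steps(epochs: int, num_samples: int, batch_size: int) -> int:
    """Total optimizer steps over the whole run."""
    return epochs * steps_per_epoch(num_samples, batch_size)


if __name__ == "__main__":
    total = total_steps(EPOCHS, NUM_TRAIN_SAMPLES, BATCH_SIZE)
    print(f"{total} steps, about {total // EVAL_EVERY_N_STEPS} evaluations")
```

With the placeholder values above, one epoch is 30,000 / 16 = 1,875 steps, so 50 epochs come to 93,750 steps.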

Evaluation

The code to evaluate the perplexity, privacy, and utility of the models is in the evaluate.ipynb notebook file. This file assumes there is a fully trained model in the folder ./model_50_epochs_30k_samples and a partially trained model in the folder ./model_20_epochs_15k_samples. These folders are ignored by git as they are too large to push to the repository.
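Because these folders are not in the repository, it may be worth confirming they exist before opening the notebook. The folder names below are taken from this README. The helper missing_model_dirs is a hypothetical name for a small standalone check, not something evaluate.ipynb provides.

```python
from pathlib import Path

# Folder names that evaluate.ipynb expects, as stated in this README.
EXPECTED_MODEL_DIRS = (
    "model_50_epochs_30k_samples",  # fully trained model
    "model_20_epochs_15k_samples",  # partially trained model
)


def missing_model_dirs(base="."):
    """Return the names of expected model folders that are not present under `base`."""
    root = Path(base)
    return [name for name in EXPECTED_MODEL_DIRS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_model_dirs()
    if missing:
        print("Missing model folders:", ", ".join(missing))
    else:
        print("All model folders present")
```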

Computational Requirements

We trained the model for 5 epochs on an NVIDIA V100 GPU on Google Cloud Platform. The original paper reports needing 251 GB of RAM to train on the whole dataset for 16 epochs, so plan for these memory requirements beforehand and choose a platform that offers suitable GPUs.
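Before attempting a full-dataset run, a rough comparison of the machine's physical RAM against the 251 GB figure from the paper may be useful. This is a minimal standard-library sketch and not part of the repository. It relies on the POSIX os.sysconf keys SC_PAGE_SIZE and SC_PHYS_PAGES, so it returns None on platforms that lack them, such as Windows. It reports GiB (2**30 bytes), which is close enough for a rough check.

```python
import os

PAPER_RAM_GB = 251  # RAM the original paper reports for full-dataset training


def physical_ram_gb():
    """Total physical RAM in GiB via POSIX sysconf, or None where unsupported."""
    try:
        page_size = os.sysconf("SC_PAGE_SIZE")
        num_pages = os.sysconf("SC_PHYS_PAGES")
    except (AttributeError, ValueError, OSError):
        return None  # os.sysconf missing (Windows) or key not supported
    if page_size <= 0 or num_pages <= 0:
        return None
    return page_size * num_pages / 1024**3


if __name__ == "__main__":
    ram = physical_ram_gb()
    if ram is None:
        print("Could not determine physical RAM on this platform")
    elif ram < PAPER_RAM_GB:
        print(f"{ram:.0f} GiB available; the paper reports {PAPER_RAM_GB} GB for full training")
    else:
        print(f"{ram:.0f} GiB available")
```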


Useful Links

CSE 517 Project Instructions
PromptEHR Paper
PromptEHR GitHub
Reproducibility Report

About

Verifying the results of the PromptEHR paper for the UW CSE517 Natural Language Processing course.
