Skip to content

v1xerunt/RHealth

Repository files navigation

RHealth

docs

RHealth is an open-source R package designed to bring a comprehensive deep learning toolkit to the R community for healthcare predictive modeling. It provides an accessible, integrated environment for R users to build, train, and evaluate complex models on EHR data. This package is the R counterpart to the popular Python library PyHealth.

RHealth is funded by the ISC grant from the R Consortium.

The detailed documentations are at RHealth Documentation


✍️ Citing RHealth

If you use RHealth in your research, please cite our work:

@misc{RHealth2025,
  author = {Ji Song, Zhixia Ren, Zhenbang Wu, John Wu, Chaoqi Yang, Jimeng Sun, Liantao Ma, Ewen M Harrison, and Junyi Gao},
  title = {RHealth: A Deep Learning Toolkit for Healthcare Predictive Modeling},
  year = {2025},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{[https://github.com/v1xerunt/RHealth](https://github.com/v1xerunt/RHealth)}}
}

📥 Installation

Install the development version from GitHub:

# install.packages("pak")
pak::pak("v1xerunt/RHealth")

# Or using devtools:
# install.packages("devtools")
devtools::install_github("v1xerunt/RHealth")

Once installed, load the package to access its functionalities:

library(RHealth)

🧱 Modules

RHealth is organized into several powerful, interconnected modules.

⚕️ 1. Medical Code Module

This standalone module helps map medical codes between different systems.

Code Lookup:

lookup_code(code = "428.0", system = "ICD9CM")

Find Ancestors/Descendants:

# Get all parent codes
get_ancestors(code = "428.22", system = "ICD9CM")

# Get all child codes
get_descendants(code = "428", system = "ICD9CM")

Cross-System Mapping:

# Map from ICD-9 to CCS
map_code(code = "428.0", from = "ICD9CM", to = "CCSCM")
# Map from ICD-9 to ICD-10
map_code(code = "589", from = "ICD9CM", to = "ICD10CM")

🛢️ 2. Dataset Module

The Dataset module is the foundation of RHealth. It transforms raw, multi-table Electronic Health Record (EHR) dumps into tidy, task-ready tensors that any downstream model can consume.

Key Features:

  • Data Harmonisation: Merges heterogeneous tables into a single, canonical event table.
  • Built-in Caching: Uses DuckDB for CSV → Parquet caching, enabling up to 10x faster reloads.
  • Dev Mode: Allows for lightning-fast iteration by using a small subset of patients.

You can download a sample dataset (MIMIC-IV Demo, version 2.2) directly from PhysioNet using the following link:

👉 https://physionet.org/content/mimic-iv-demo/2.2/#files-panel

Quick Start:

Define a dataset from your source files using a YAML configuration.

# The YAML config defines tables, patient IDs, timestamps, and attributes
# See the full documentation for details on the YAML structure.

# Load the dataset
data_dir <- "/Users/yourname/datasets/mimiciv/"

ds <- MIMIC4EHRDataset$new(
  root = data_dir,
  tables = c("patients", "admissions", "diagnoses_icd", "procedures_icd", "prescriptions"),
  dataset_name = "mimic4_ehr",
  dev = TRUE
)

ds$stats()
#> Dataset : mimic4_ehr
#> Dev mode : TRUE
#> Patients : 1 000
#> Events   : 2 187 540

🎯 3. Task Module

The Task module defines the prediction problem. It tells RHealth what to predict, which data to use, and how to generate (input, target) samples from a patient’s event timeline.

A task is defined by subclassing BaseTask and implementing the call() method.

Example Task Definition:

MyReadmissionTask <- R6::R6Class(
  "MyReadmissionTask",
  inherit = BaseTask,
  public = list(
    initialize = function() {
      super$initialize(
        task_name     = "MyReadmissionTask",
        input_schema  = list(diagnoses = "sequence", procedures = "sequence"),
        output_schema = list(outcome = "binary")
      )
    },
    call = function(patient) {
      # Your logic to generate samples for a single patient...
      # This should return a list or list-of-lists with named fields
      # matching the input/output schemas.
      
      # Example:
      # list(
      #   diagnoses = c("401.9", "250.00"),
      #   procedures = c("88.72"),
      #   outcome = 1
      # )
    }
  )
)

Generating Samples:

Once a task is defined, use it with your dataset to create a SampleDataset compatible with {torch}.

task    <- InHospitalMortalityMIMIC4$new() # A built-in task
samples <- ds$set_task(task)

🧠 4. Model Module

The Model module provides ready-to-use neural network architectures. All models inherit from a BaseModel, which automates dimension calculation, loss function selection, and device management (CPU/GPU).

Built-in Models:

RHealth includes reference implementations like RNN, which can be instantiated in one line:

model <- RNN(
  dataset       = samples, # The SampleDataset from set_task()
  embedding_dim = 128,
  hidden_dim    = 128
)

Custom Models:

You can easily write your own model by inheriting from BaseModel.

MyDenseNet <- torch::nn_module(
  "MyDenseNet",
  inherit = BaseModel,
  initialize = function(dataset, hidden_dim = 256) {
    super$initialize(dataset) # IMPORTANT: handles schema setup
    
    # Calculate input/output dimensions automatically
    in_dim  <- sum(purrr::map_int(dataset$input_processors, "size"))
    out_dim <- self$get_output_size()

    self$fc1 <- nn_linear(in_dim, hidden_dim)
    self$fc2 <- nn_linear(hidden_dim, out_dim)
  },
  forward = function(inputs) {
    # Flatten and concatenate all input features
    x <- torch::torch_cat(purrr::flatten(inputs), dim = 2)
    logits <- self$fc2(torch_relu(self$fc1(x)))

    # Return loss and probabilities
    list(
      loss   = self$get_loss_function()(logits, inputs[[self$label_keys]]),
      y_prob = self$prepare_y_prob(logits)
    )
  }
)

💪 5. Trainer Module

The Trainer module provides a high-level, configurable training loop that handles logging, checkpointing, evaluation, and progress bars.

Example Training Workflow:

# 1. Create data loaders
splits <- split_by_patient(samples, c(0.8, 0.1, 0.1), stratify = TRUE, stratify_by = 'mortality')
train_dl <- get_dataloader(splits[[1]], batch_size = 32, shuffle = TRUE)
val_dl <- get_dataloader(splits[[2]], batch_size = 32)
test_dl <- get_dataloader(splits[[3]], batch_size = 32)

# 2. Instantiate a model
model <- RNN(train_dl, embedding_dim = 128, hidden_dim = 128)

# 3. Set up the trainer
trainer <- Trainer$new(
    model,
    metrics     = c("auroc", "auprc"),
    output_path = "experiments",
    exp_name    = "mortality_rnn"
)

# 4. Start training
trainer$train(
  train_dataloader = train_dl,
  val_dataloader = val_dl,
  epochs = 10,
  optimizer_params = list(lr = 1e-3),
  monitor = "roc_auc"
)

Logs and model checkpoints (best.ckpt, last.ckpt) are saved automatically to experiments/mortality_rnn/.


🚀 Future Plans

RHealth is under active development. Our roadmap includes:

  • Healthcare DL Core Module: Adding more SoTA models like RETAIN, AdaCare, and Transformers.
  • Prediction Task Module: Adding built-in tasks for common clinical predictions (e.g., length of stay, readmission risk).
  • Multi-modal Data: Enhancing support for integrating imaging, genomics, and clinical notes.
  • LLM Integration: Augmenting package capabilities with Large Language Models.

We welcome feedback and contributions! Please submit an issue on GitHub or contact the maintainers.

Releases

No releases published

Packages

No packages published

Contributors 3

  •  
  •  
  •  

Languages