
Max margin interval trees

Alexandre Drouin edited this page Jul 10, 2017 · 21 revisions

Background

There are few R packages available for interval regression, a machine learning problem that is important in genomics and medicine. As in usual regression, the goal is to learn a function that maps a feature vector to a real-valued prediction. Unlike usual regression, each output in the training set is an interval of acceptable values rather than a single value. In the terminology of the survival analysis literature, this is regression with "left, right, and interval censored" output/response data.
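To make the setup concrete, here is a minimal Python sketch (with hypothetical toy values) of what interval-censored targets look like, using ±inf for the censored limits, together with the zero-one notion of a prediction being "acceptable":

```python
import math

# Toy interval-regression training set (hypothetical values for illustration).
# Each target is an interval (lower, upper) of acceptable outputs:
#   (-inf, u] -> left-censored, [l, inf) -> right-censored, [l, u] -> interval-censored.
X = [[0.5], [1.2], [2.0], [3.1]]
y = [(-math.inf, 1.0),   # left-censored
     (0.5, 2.0),         # interval-censored
     (1.5, math.inf),    # right-censored
     (2.5, 4.0)]         # interval-censored

def zero_one_error(pred, interval):
    """1 if the prediction falls outside the target interval, else 0."""
    lo, hi = interval
    return 0 if lo <= pred <= hi else 1

preds = [0.8, 1.0, 3.0, 5.0]
errors = [zero_one_error(p, t) for p, t in zip(preds, y)]
print(errors)  # -> [0, 0, 0, 1]
```

Only the last prediction (5.0) is an error here, because it lies above its target interval [2.5, 4.0].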

Max margin interval trees (MMIT) is a new nonlinear model for this problem (TODO: cite paper when published). A dynamic programming algorithm is used to find the optimal split point for each feature. This algorithm has been implemented in C++, and there are wrappers to the solver in R and Python (https://github.com/aldro61/mmit). The Python package includes a decision tree learner; however, there is not yet an implementation of the decision tree learner in the R package. The goal of this project is to write an R package that implements it.
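To illustrate the objective only, here is a naive Python sketch of searching for the best split point on one feature under a margin hinge loss. This is a brute-force O(n^2) illustration, not the package's algorithm: the actual C++ solver finds optimal costs by dynamic programming, and the exact loss and margin handling in the paper may differ from this sketch.

```python
import math

def hinge_cost(pred, intervals, margin=0.0):
    """Total linear hinge loss of one constant prediction against target
    intervals (a sketch of the MMIT objective, not the package's exact code)."""
    total = 0.0
    for lo, hi in intervals:
        total += max(0.0, lo - pred + margin)  # penalty for predicting below lo
        total += max(0.0, pred - hi + margin)  # penalty for predicting above hi
    return total

def best_constant(intervals, margin=0.0):
    """Brute-force minimizer over the finite interval limits; with margin=0 the
    piecewise-linear convex cost attains its minimum at one of these points."""
    candidates = [b for lo, hi in intervals for b in (lo, hi) if math.isfinite(b)]
    return min(candidates, key=lambda p: hinge_cost(p, intervals, margin))

def best_split(x, intervals, margin=0.0):
    """Naive search for the threshold on one feature minimizing the summed
    left/right hinge costs; returns (total cost, threshold)."""
    order = sorted(range(len(x)), key=lambda i: x[i])
    best = (math.inf, None)
    for k in range(1, len(order)):
        left = [intervals[i] for i in order[:k]]
        right = [intervals[i] for i in order[k:]]
        cost = (hinge_cost(best_constant(left, margin), left, margin)
                + hinge_cost(best_constant(right, margin), right, margin))
        threshold = (x[order[k - 1]] + x[order[k]]) / 2.0
        if cost < best[0]:
            best = (cost, threshold)
    return best

cost, threshold = best_split([1.0, 2.0, 3.0, 4.0],
                             [(0.0, 1.0), (0.0, 1.0), (3.0, 4.0), (3.0, 4.0)])
print(cost, threshold)  # -> 0.0 2.5
```

In this toy example the two groups of intervals are perfectly separated by a threshold at 2.5, so the best split has zero cost; the tree learner repeats such a search over all features at every node.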

Related work

The transformation forest model of Hothorn and Zeileis is a decision tree model that can be trained on censored outputs (https://arxiv.org/abs/1701.02110). The trtf package on R-Forge implements this nonlinear model.

TODO: list of packages that extend partykit. (will be useful for the student to study examples of how to extend partykit)

There are also several linear models that can be trained on censored outputs, for example survival::survreg and the penaltyLearning package.

Details of your coding project

Implement the Max Margin Interval Tree model in the framework of partykit: mmit R package with

  • mmit() to train a tree model (for a given set of hyper-parameters).
  • cv.mmit() to train a tree model using K-fold cross-validation to select hyper-parameters.
  • mmif() for random forest.
  • boosting?
  • documentation and tests for each function.
  • development on GitHub with code quality assurance (code coverage and Travis CI for automated testing).
  • vignette to explain typical package usage.

TODO: Add something about exporting trees to tikz or plots. TODO: Add something about pruning.

Python module

The Python module is organized as follows:

  1. mmit.core: This submodule implements an interface (mmit.core.compute_optimal_costs) to the dynamic programming algorithm (solver) used to find the optimal split point for each feature. The C++ code for the solver is located in this directory.
  2. mmit.learning: This submodule implements the MaxMarginIntervalTree class, which is used to learn decision trees and compute predictions. The class is compatible with the Scikit-Learn API: a tree is learned with the fit method and predictions are computed with the predict method. The tree learner makes multiple calls to the C++ solver to find the best rule for splitting each tree node.
  3. mmit.metrics: This submodule implements metrics used to measure the accuracy of predictions with respect to the target intervals. The supported metrics are the zero-one loss and the mean squared error.
  4. mmit.model: This submodule implements the inner workings of the tree models learned by the MaxMarginIntervalTree class. The RegressionTreeNode class implements a tree node and is used recursively to construct trees. The TreeExporter class exports tree models in various formats; the only format currently supported is TikZ/LaTeX.
  5. mmit.model_selection: The GridSearchCV class trains an MMIT using cross-validation to select the hyper-parameters (see below). Minimum cost-complexity pruning can also be used to choose the optimal size for the tree.
  6. mmit.pruning: This submodule implements minimum cost-complexity pruning (Breiman et al. 1984). Pruning is a regularization method that helps avoid overfitting.
  7. mmit.tests: Unit tests for the package are implemented in this submodule.
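As one concrete reading of the interval metrics mentioned above (the exact definitions live in mmit.metrics; this is only an illustrative sketch), the mean squared error can be taken as the squared distance from each prediction to its target interval, which is zero whenever the prediction falls inside:

```python
def interval_mse(preds, intervals):
    """Mean squared distance from each prediction to its target interval;
    one plausible reading of the module's interval MSE -- check mmit.metrics
    for the exact definition used by the package."""
    total = 0.0
    for p, (lo, hi) in zip(preds, intervals):
        if p < lo:
            total += (lo - p) ** 2   # below the interval: distance to lower limit
        elif p > hi:
            total += (p - hi) ** 2   # above the interval: distance to upper limit
    return total / len(preds)

print(interval_mse([0.0, 5.0], [(1.0, 2.0), (1.0, 2.0)]))  # -> 5.0
```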
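The pruning step in mmit.pruning follows minimum cost-complexity pruning (Breiman et al. 1984): each internal node t has a critical penalty alpha = (R(t) - R(T_t)) / (|T_t| - 1), where R(t) is the cost of collapsing t into a leaf, R(T_t) is the total cost of the leaves of its subtree, and |T_t| is the number of those leaves; the node with the smallest alpha is the "weakest link" and is pruned first. The sketch below computes the weakest-link alpha on a toy tree representation (hypothetical tuples, not the package's RegressionTreeNode class):

```python
import math

# Toy binary tree: ("leaf", cost) or ("node", cost_if_collapsed, left, right).
def leaves_and_cost(tree):
    """Return (number of leaves, total cost of the leaves) of a subtree."""
    if tree[0] == "leaf":
        return 1, tree[1]
    _, _, left, right = tree
    nl, cl = leaves_and_cost(left)
    nr, cr = leaves_and_cost(right)
    return nl + nr, cl + cr

def weakest_link_alpha(tree):
    """Smallest alpha = (R(t) - R(T_t)) / (|T_t| - 1) over internal nodes;
    collapsing the node achieving it is one step of cost-complexity pruning."""
    if tree[0] == "leaf":
        return math.inf
    _, collapsed_cost, left, right = tree
    n, c = leaves_and_cost(tree)
    alpha = (collapsed_cost - c) / (n - 1)
    return min(alpha, weakest_link_alpha(left), weakest_link_alpha(right))

tree = ("node", 10.0, ("leaf", 2.0), ("node", 5.0, ("leaf", 1.0), ("leaf", 1.0)))
print(weakest_link_alpha(tree))  # -> 3.0
```

Repeating this step yields a nested sequence of pruned trees, and cross-validation (as in GridSearchCV) picks the best tree from that sequence.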

TODO: describe the equivalent functions in the python module. TODO: add refs

Hyper-parameters

TODO: describe the hyper-parameters of the tree model. (depth, cost-complexity pruning penalty, etc)

Expected impact

This project will provide an R implementation of the max margin interval tree model for interval regression, which currently only has a Python implementation.

Mentors

Students, please contact the mentors listed here after completing at least one of the tests in the Tests section below.

  • Alexandre Drouin <alexandre.drouin.8@ulaval.ca> is a co-author of the Max Margin Interval Trees paper, and author of the Python mmit module and C++ code.
  • Torsten Hothorn <Torsten.Hothorn@r-project.org> is an expert at implementing decision tree algorithms in R; he is the author of the trtf and partykit packages.
  • Backup mentor: Toby Hocking <toby.hocking@r-project.org> is a co-author of the Max Margin Interval Trees paper, author of the R package penaltyLearning (which implements a linear interval regression algorithm), and mentor of the students who implemented the iregnet package (GSoC 2016-2017).

Tests

Students, please do one or more of the following tests before contacting the mentors above.

  • Easy: run some R code that shows you know how to train and test a decision tree model (rpart, partykit, etc.). Bonus points if you can get trtf running for an interval regression problem, for example data(neuroblastomaProcessed, package="penaltyLearning").
  • Medium: Read the partykit vignette to learn how to implement a new tree model using the partykit framework. Use it to re-implement a simple version of Breiman’s CART algorithm (rpart R package). Demonstrate the equivalence of your code and rpart on the data set in example(rpart).
  • Hard: Read the help page of the survival::survreg function, which can be used to fit a linear model for censored outputs. Use it as a sub-routine to implement a (slow) regression tree for interval censored output data. Search for the best possible split over all features: the best split is the one that maximizes the logLik of the survreg model. Demonstrate that your regression tree model works on a small subset of data(neuroblastomaProcessed, package="penaltyLearning").

Solutions of tests

Students, please post a link to your test results here.
