//! # Extra Trees Regressor
//! An Extra-Trees (Extremely Randomized Trees) regressor is an ensemble learning method that fits multiple randomized
//! decision trees on the dataset and averages their predictions to improve accuracy and control over-fitting.
//!
//! It is similar to a standard Random Forest, but introduces more randomness in the way splits are chosen, which can
//! reduce the variance of the model and often make the training process faster.
//!
//! The two key differences from a standard Random Forest are:
//! 1. It uses the whole original dataset to build each tree instead of bootstrap samples.
//! 2. When splitting a node, it chooses a random split point for each candidate feature, rather than searching for the optimal one.
//!
//! See [ensemble models](../index.html) for more details.
//!
//! A larger number of estimators generally improves the performance of the algorithm at the cost of increased training time.
//! The random sample of _m_ predictors is typically set to \\(\sqrt{p}\\), where _p_ is the total number of predictors.
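//! For example, with the \\(p = 6\\) predictors of the Longley data below, this heuristic gives roughly \\(\sqrt{6} \approx 2\\) randomly sampled candidate predictors; set `m` explicitly to override it.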
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::ensemble::extra_trees_regressor::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
//! let x = DenseMatrix::from_2d_array(&[
//!     &[234.289, 235.6, 159., 107.608, 1947., 60.323],
//!     &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
//!     &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
//!     &[284.599, 335.1, 165., 110.929, 1950., 61.187],
//!     &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
//!     &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
//!     &[365.385, 187., 354.7, 115.094, 1953., 64.989],
//!     &[363.112, 357.8, 335., 116.219, 1954., 63.761],
//!     &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
//!     &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
//!     &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
//!     &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
//!     &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
//!     &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//!     &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//!     &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap();
//! let y = vec![
//!     83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
//!     104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
//! ];
//!
//! let regressor = ExtraTreesRegressor::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
//! ```
//!
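//! Hyper-parameters are configured through the builder methods on
//! [`ExtraTreesRegressorParameters`]; the values below are purely illustrative, not recommended settings:
//!
//! ```
//! use smartcore::ensemble::extra_trees_regressor::ExtraTreesRegressorParameters;
//!
//! let params = ExtraTreesRegressorParameters::default()
//!     .with_n_trees(100)  // number of trees in the forest
//!     .with_max_depth(8)  // limit the depth of each tree
//!     .with_m(3)          // number of randomly sampled candidate predictors
//!     .with_seed(42);     // make training reproducible
//! ```
//!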
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>

use std::default::Default;
use std::fmt::Debug;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::api::{Predictor, SupervisedEstimator};
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Extra Trees Regressor.
/// Some of these parameters are passed directly to the base estimator.
pub struct ExtraTreesRegressorParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub max_depth: Option<u16>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_leaf: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_split: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of trees in the forest.
    pub n_trees: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of randomly sampled predictors to use as split candidates.
    pub m: Option<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether to keep the samples used for tree generation. This is required for OOB prediction.
    pub keep_samples: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Seed used for random feature and split point selection for each tree.
    pub seed: u64,
}

/// Extra Trees Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ExtraTreesRegressor<
    TX: Number + FloatNumber + PartialOrd,
    TY: Number,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
}

impl ExtraTreesRegressorParameters {
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_max_depth(mut self, max_depth: u16) -> Self {
        self.max_depth = Some(max_depth);
        self
    }
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }
    /// The number of trees in the forest.
    pub fn with_n_trees(mut self, n_trees: usize) -> Self {
        self.n_trees = n_trees;
        self
    }
    /// The number of randomly sampled predictors to use as split candidates.
    pub fn with_m(mut self, m: usize) -> Self {
        self.m = Some(m);
        self
    }

    /// Whether to keep the samples used for tree generation. This is required for OOB prediction.
    pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
        self.keep_samples = keep_samples;
        self
    }

    /// Seed used for random feature and split point selection for each tree.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }
}
impl Default for ExtraTreesRegressorParameters {
    fn default() -> Self {
        ExtraTreesRegressorParameters {
            max_depth: Option::None,
            min_samples_leaf: 1,
            min_samples_split: 2,
            n_trees: 10,
            m: Option::None,
            keep_samples: false,
            seed: 0,
        }
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimator<X, Y, ExtraTreesRegressorParameters> for ExtraTreesRegressor<TX, TY, X, Y>
{
    fn new() -> Self {
        Self {
            forest_regressor: Option::None,
        }
    }

    fn fit(x: &X, y: &Y, parameters: ExtraTreesRegressorParameters) -> Result<Self, Failed> {
        ExtraTreesRegressor::fit(x, y, parameters)
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    Predictor<X, Y> for ExtraTreesRegressor<TX, TY, X, Y>
{
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    ExtraTreesRegressor<TX, TY, X, Y>
{
    /// Build a forest of trees from the training set.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - the target values
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: ExtraTreesRegressorParameters,
    ) -> Result<ExtraTreesRegressor<TX, TY, X, Y>, Failed> {
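        // The extra-trees behaviour comes from two settings passed to the base forest:
        // every tree is grown on the full dataset (`bootstrap: false`) and split points
        // are chosen at random for each candidate feature (`Splitter::Random`).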
        let regressor_params = BaseForestRegressorParameters {
            max_depth: parameters.max_depth,
            min_samples_leaf: parameters.min_samples_leaf,
            min_samples_split: parameters.min_samples_split,
            n_trees: parameters.n_trees,
            m: parameters.m,
            keep_samples: parameters.keep_samples,
            seed: parameters.seed,
            bootstrap: false,
            splitter: Splitter::Random,
        };
        let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;

        Ok(ExtraTreesRegressor {
            forest_regressor: Some(forest_regressor),
        })
    }

    /// Predict target values for `x`.
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let forest_regressor = self.forest_regressor.as_ref().unwrap();
        forest_regressor.predict(x)
    }

    /// Predict out-of-bag (OOB) target values for `x`. `x` is expected to be the same dataset that was used for training.
    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
        let forest_regressor = self.forest_regressor.as_ref().unwrap();
        forest_regressor.predict_oob(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;
    use crate::metrics::mean_squared_error;

    #[test]
    fn test_extra_trees_regressor_fit_predict() {
        // Use a simpler, more predictable dataset for unit testing.
        let x = DenseMatrix::from_2d_array(&[
            &[1., 2.],
            &[3., 4.],
            &[5., 6.],
            &[7., 8.],
            &[9., 10.],
            &[11., 12.],
            &[13., 14.],
            &[15., 16.],
        ])
        .unwrap();
        let y = vec![1., 2., 3., 4., 5., 6., 7., 8.];

        let parameters = ExtraTreesRegressorParameters::default()
            .with_n_trees(100)
            .with_seed(42);

        let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
        let y_hat = regressor.predict(&x).unwrap();

        assert_eq!(y_hat.len(), y.len());
        // A basic check to ensure the model is learning something.
        // The error should be significantly less than the variance of y.
        let mse = mean_squared_error(&y, &y_hat);
        // With this simple dataset, the error should be very low.
        assert!(mse < 1.0);
    }

    #[test]
    fn test_fit_predict_higher_dims() {
        // Dataset with 10 features, but y is only dependent on the 3rd feature (index 2).
        let x = DenseMatrix::from_2d_array(&[
            // The 3rd column is the important one. The rest are noise.
            &[0., 0., 10., 5., 8., 1., 4., 9., 2., 7.],
            &[0., 0., 20., 1., 2., 3., 4., 5., 6., 7.],
            &[0., 0., 30., 7., 6., 5., 4., 3., 2., 1.],
            &[0., 0., 40., 9., 2., 4., 6., 8., 1., 3.],
            &[0., 0., 55., 3., 1., 8., 6., 4., 2., 9.],
            &[0., 0., 65., 2., 4., 7., 5., 3., 1., 8.],
        ])
        .unwrap();
        let y = vec![10., 20., 30., 40., 55., 65.];

        let parameters = ExtraTreesRegressorParameters::default()
            .with_n_trees(100)
            .with_seed(42);

        let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
        let y_hat = regressor.predict(&x).unwrap();

        assert_eq!(y_hat.len(), y.len());

        let mse = mean_squared_error(&y, &y_hat);

        // The model should be able to learn this simple relationship perfectly,
        // ignoring the noise features. The MSE should be very low.
        assert!(mse < 1.0);
    }

    #[test]
    fn test_reproducibility() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., 2.],
            &[3., 4.],
            &[5., 6.],
            &[7., 8.],
            &[9., 10.],
            &[11., 12.],
        ])
        .unwrap();
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let params = ExtraTreesRegressorParameters::default().with_seed(42);

        let regressor1 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
        let y_hat1 = regressor1.predict(&x).unwrap();

        let regressor2 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
        let y_hat2 = regressor2.predict(&x).unwrap();

        assert_eq!(y_hat1, y_hat2);
    }
}