Commit 4841791

implemented extra trees (#320)

* implemented extra trees

1 parent 9fef05e commit 4841791

2 files changed: +319 −0 lines changed
src/ensemble/extra_trees_regressor.rs

Lines changed: 318 additions & 0 deletions
@@ -0,0 +1,318 @@
//! # Extra Trees Regressor
//! An Extra-Trees (Extremely Randomized Trees) regressor is an ensemble learning method that fits multiple randomized
//! decision trees on the dataset and averages their predictions to improve accuracy and control over-fitting.
//!
//! It is similar to a standard Random Forest, but introduces more randomness in the way splits are chosen, which can
//! reduce the variance of the model and often make the training process faster.
//!
//! The two key differences from a standard Random Forest are:
//! 1. It uses the whole original dataset to build each tree instead of bootstrap samples.
//! 2. When splitting a node, it chooses a random split point for each feature, rather than the optimal one.
//!
//! See [ensemble models](../index.html) for more details.
//!
//! A larger number of estimators generally improves the performance of the algorithm at the cost of increased training time.
//! The random sample of _m_ predictors is typically set to \\(\sqrt{p}\\), where _p_ is the total number of predictors.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::ensemble::extra_trees_regressor::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
//! let x = DenseMatrix::from_2d_array(&[
//!     &[234.289, 235.6, 159., 107.608, 1947., 60.323],
//!     &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
//!     &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
//!     &[284.599, 335.1, 165., 110.929, 1950., 61.187],
//!     &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
//!     &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
//!     &[365.385, 187., 354.7, 115.094, 1953., 64.989],
//!     &[363.112, 357.8, 335., 116.219, 1954., 63.761],
//!     &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
//!     &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
//!     &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
//!     &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
//!     &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
//!     &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//!     &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//!     &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap();
//! let y = vec![
//!     83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
//!     104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
//! ];
//!
//! let regressor = ExtraTreesRegressor::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>

use std::default::Default;
use std::fmt::Debug;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::api::{Predictor, SupervisedEstimator};
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Extra Trees Regressor
/// Some parameters here are passed directly into the base estimator.
pub struct ExtraTreesRegressorParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub max_depth: Option<u16>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_leaf: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_split: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of trees in the forest.
    pub n_trees: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of randomly sampled predictors to use as split candidates.
    pub m: Option<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
    pub keep_samples: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Seed used for bootstrap sampling and feature selection for each tree.
    pub seed: u64,
}

/// Extra Trees Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ExtraTreesRegressor<
    TX: Number + FloatNumber + PartialOrd,
    TY: Number,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
}

impl ExtraTreesRegressorParameters {
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_max_depth(mut self, max_depth: u16) -> Self {
        self.max_depth = Some(max_depth);
        self
    }
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }
    /// The number of trees in the forest.
    pub fn with_n_trees(mut self, n_trees: usize) -> Self {
        self.n_trees = n_trees;
        self
    }
    /// The number of randomly sampled predictors to use as split candidates.
    pub fn with_m(mut self, m: usize) -> Self {
        self.m = Some(m);
        self
    }

    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
    pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
        self.keep_samples = keep_samples;
        self
    }

    /// Seed used for bootstrap sampling and feature selection for each tree.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }
}

impl Default for ExtraTreesRegressorParameters {
    fn default() -> Self {
        ExtraTreesRegressorParameters {
            max_depth: Option::None,
            min_samples_leaf: 1,
            min_samples_split: 2,
            n_trees: 10,
            m: Option::None,
            keep_samples: false,
            seed: 0,
        }
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimator<X, Y, ExtraTreesRegressorParameters> for ExtraTreesRegressor<TX, TY, X, Y>
{
    fn new() -> Self {
        Self {
            forest_regressor: Option::None,
        }
    }

    fn fit(x: &X, y: &Y, parameters: ExtraTreesRegressorParameters) -> Result<Self, Failed> {
        ExtraTreesRegressor::fit(x, y, parameters)
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    Predictor<X, Y> for ExtraTreesRegressor<TX, TY, X, Y>
{
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    ExtraTreesRegressor<TX, TY, X, Y>
{
    /// Build a forest of trees from the training set.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - the target values
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: ExtraTreesRegressorParameters,
    ) -> Result<ExtraTreesRegressor<TX, TY, X, Y>, Failed> {
        let regressor_params = BaseForestRegressorParameters {
            max_depth: parameters.max_depth,
            min_samples_leaf: parameters.min_samples_leaf,
            min_samples_split: parameters.min_samples_split,
            n_trees: parameters.n_trees,
            m: parameters.m,
            keep_samples: parameters.keep_samples,
            seed: parameters.seed,
            // Extra Trees specifics: train each tree on the full dataset (no bootstrap)
            // and pick split points at random.
            bootstrap: false,
            splitter: Splitter::Random,
        };
        let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;

        Ok(ExtraTreesRegressor {
            forest_regressor: Some(forest_regressor),
        })
    }

    /// Predict target values for `x`
    /// * `x` - _KxM_ data where _K_ is the number of observations and _M_ is the number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let forest_regressor = self.forest_regressor.as_ref().unwrap();
        forest_regressor.predict(x)
    }

    /// Predict OOB target values for `x`. `x` is expected to be equal to the dataset used in training.
    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
        let forest_regressor = self.forest_regressor.as_ref().unwrap();
        forest_regressor.predict_oob(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;
    use crate::metrics::mean_squared_error;

    #[test]
    fn test_extra_trees_regressor_fit_predict() {
        // Use a simpler, more predictable dataset for unit testing.
        let x = DenseMatrix::from_2d_array(&[
            &[1., 2.],
            &[3., 4.],
            &[5., 6.],
            &[7., 8.],
            &[9., 10.],
            &[11., 12.],
            &[13., 14.],
            &[15., 16.],
        ])
        .unwrap();
        let y = vec![1., 2., 3., 4., 5., 6., 7., 8.];

        let parameters = ExtraTreesRegressorParameters::default()
            .with_n_trees(100)
            .with_seed(42);

        let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
        let y_hat = regressor.predict(&x).unwrap();

        assert_eq!(y_hat.len(), y.len());
        // A basic check to ensure the model is learning something.
        // The error should be significantly less than the variance of y.
        let mse = mean_squared_error(&y, &y_hat);
        // With this simple dataset, the error should be very low.
        assert!(mse < 1.0);
    }

    #[test]
    fn test_fit_predict_higher_dims() {
        // Dataset with 10 features, but y is only dependent on the 3rd feature (index 2).
        let x = DenseMatrix::from_2d_array(&[
            // The 3rd column is the important one. The rest are noise.
            &[0., 0., 10., 5., 8., 1., 4., 9., 2., 7.],
            &[0., 0., 20., 1., 2., 3., 4., 5., 6., 7.],
            &[0., 0., 30., 7., 6., 5., 4., 3., 2., 1.],
            &[0., 0., 40., 9., 2., 4., 6., 8., 1., 3.],
            &[0., 0., 55., 3., 1., 8., 6., 4., 2., 9.],
            &[0., 0., 65., 2., 4., 7., 5., 3., 1., 8.],
        ])
        .unwrap();
        let y = vec![10., 20., 30., 40., 55., 65.];

        let parameters = ExtraTreesRegressorParameters::default()
            .with_n_trees(100)
            .with_seed(42);

        let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
        let y_hat = regressor.predict(&x).unwrap();

        assert_eq!(y_hat.len(), y.len());

        let mse = mean_squared_error(&y, &y_hat);

        // The model should be able to learn this simple relationship perfectly,
        // ignoring the noise features. The MSE should be very low.
        assert!(mse < 1.0);
    }

    #[test]
    fn test_reproducibility() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., 2.],
            &[3., 4.],
            &[5., 6.],
            &[7., 8.],
            &[9., 10.],
            &[11., 12.],
        ])
        .unwrap();
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let params = ExtraTreesRegressorParameters::default().with_seed(42);

        let regressor1 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
        let y_hat1 = regressor1.predict(&x).unwrap();

        let regressor2 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
        let y_hat2 = regressor2.predict(&x).unwrap();

        assert_eq!(y_hat1, y_hat2);
    }
}
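
For quick reference, here is a minimal usage sketch (not part of the commit) that combines the builder-style parameters defined above with `fit` and `predict`. The toy data is invented purely for illustration; `predict_oob` is also exposed but additionally requires `with_keep_samples(true)` and is left out of the sketch.

use smartcore::ensemble::extra_trees_regressor::{
    ExtraTreesRegressor, ExtraTreesRegressorParameters,
};
use smartcore::linalg::basic::matrix::DenseMatrix;

// Invented toy data: the target simply tracks the first feature.
let x = DenseMatrix::from_2d_array(&[
    &[1., 10.],
    &[2., 20.],
    &[3., 30.],
    &[4., 40.],
    &[5., 50.],
    &[6., 60.],
]).unwrap();
let y = vec![1., 2., 3., 4., 5., 6.];

// Builder-style configuration of the forest implemented in this file.
let params = ExtraTreesRegressorParameters::default()
    .with_n_trees(50)  // more trees: lower variance, longer training
    .with_max_depth(4) // cap the depth of each randomized tree
    .with_m(1)         // number of randomly sampled predictors per split
    .with_seed(42);    // fixed seed for a reproducible forest

let model = ExtraTreesRegressor::fit(&x, &y, params).unwrap();
let y_hat = model.predict(&x).unwrap(); // predictions on the training data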

src/ensemble/mod.rs

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@
 //! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)

 mod base_forest_regressor;
+pub mod extra_trees_regressor;
 /// Random forest classifier
 pub mod random_forest_classifier;
 /// Random forest regressor
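With the module now exported from `src/ensemble/mod.rs`, downstream code can reach the new regressor at the path already used in the doc example above:

use smartcore::ensemble::extra_trees_regressor::{ExtraTreesRegressor, ExtraTreesRegressorParameters};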