
Commit d28f13d

Volodymyr Orlov authored and committed

feat: adds train/test split function; fixes bug in random forest

1 parent 1920f9c commit d28f13d

File tree (9 files changed: +187 -10 lines)

- src/ensemble/random_forest_classifier.rs
- src/lib.rs
- src/linalg/mod.rs
- src/linalg/naive/dense_matrix.rs
- src/linalg/nalgebra_bindings.rs
- src/linalg/ndarray_bindings.rs
- src/model_selection/mod.rs
- src/tree/decision_tree_classifier.rs
- src/tree/decision_tree_regressor.rs

src/ensemble/random_forest_classifier.rs

Lines changed: 9 additions & 9 deletions

```diff
@@ -199,19 +199,19 @@ impl<T: RealNumber> RandomForestClassifier<T> {
         let nrows = y.len();
         let mut samples = vec![0; nrows];
         for l in 0..num_classes {
-            let mut nj = 0;
-            let mut cj: Vec<usize> = Vec::new();
+            let mut n_samples = 0;
+            let mut index: Vec<usize> = Vec::new();
             for i in 0..nrows {
                 if y[i] == l {
-                    cj.push(i);
-                    nj += 1;
+                    index.push(i);
+                    n_samples += 1;
                 }
             }

-            let size = ((nj as f64) / class_weight[l]) as usize;
+            let size = ((n_samples as f64) / class_weight[l]) as usize;
             for _ in 0..size {
-                let xi: usize = rng.gen_range(0, nj);
-                samples[cj[xi]] += 1;
+                let xi: usize = rng.gen_range(0, n_samples);
+                samples[index[xi]] += 1;
             }
         }
         samples
@@ -260,12 +260,12 @@ mod tests {
                 max_depth: None,
                 min_samples_leaf: 1,
                 min_samples_split: 2,
-                n_trees: 1000,
+                n_trees: 100,
                 m: Option::None,
             },
         );

-        assert!(accuracy(&y, &classifier.predict(&x)) > 0.9);
+        assert!(accuracy(&y, &classifier.predict(&x)) >= 0.95);
     }

     #[test]
```
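The renamed hunk above is a class-weighted bootstrap: for each class `l` it gathers that class's row indices, then draws `n_samples / class_weight[l]` of them with replacement, so `samples[i]` counts how many times row `i` appears in a tree's bootstrap sample. A minimal standalone sketch of the same idea (the free function and `main` are illustrative, not SmartCore's internals; it uses the two-argument `gen_range` from the rand 0.7 line the diff itself targets):

```rust
use rand::Rng;

/// Class-weighted bootstrap: returns how many times each row is sampled.
fn weighted_bootstrap(y: &[usize], num_classes: usize, class_weight: &[f64]) -> Vec<usize> {
    let mut rng = rand::thread_rng();
    let mut samples = vec![0usize; y.len()];
    for l in 0..num_classes {
        // indices of all rows belonging to class `l`
        let index: Vec<usize> = (0..y.len()).filter(|&i| y[i] == l).collect();
        // a larger class weight shrinks the number of draws for that class
        let size = (index.len() as f64 / class_weight[l]) as usize;
        for _ in 0..size {
            // draw with replacement from this class's rows
            let xi = rng.gen_range(0, index.len());
            samples[index[xi]] += 1;
        }
    }
    samples
}

fn main() {
    let y = vec![0, 0, 1, 1, 1];
    println!("{:?}", weighted_bootstrap(&y, 2, &[1.0, 1.0]));
}
```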

src/lib.rs

Lines changed: 1 addition & 0 deletions

```diff
@@ -83,6 +83,7 @@ pub mod linear;
 pub mod math;
 /// Functions for assessing prediction error.
 pub mod metrics;
+pub mod model_selection;
 /// Supervised neighbors-based learning methods
 pub mod neighbors;
 pub(crate) mod optimization;
```

src/linalg/mod.rs

Lines changed: 9 additions & 0 deletions

```diff
@@ -76,6 +76,15 @@ pub trait BaseVector<T: RealNumber>: Clone + Debug {

     /// Return a vector with the elements of the one-dimensional array.
     fn to_vec(&self) -> Vec<T>;
+
+    /// Create new vector with zeros of size `len`.
+    fn zeros(len: usize) -> Self;
+
+    /// Create new vector with ones of size `len`.
+    fn ones(len: usize) -> Self;
+
+    /// Create new vector of size `len` where each element is set to `value`.
+    fn fill(len: usize, value: T) -> Self;
 }

 /// Generic matrix type.
```
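These three constructors are implemented for each backend later in this commit; with the plain `Vec<T>` implementation from `dense_matrix.rs` they can be exercised directly. A small usage sketch (assuming the `linalg` module is exposed as `smartcore::linalg`, as the `crate::linalg::BaseVector` imports elsewhere in the commit suggest):

```rust
use smartcore::linalg::BaseVector;

fn main() {
    // the type annotation picks the `Vec<f64>` implementation of the trait
    let zeros: Vec<f64> = BaseVector::zeros(3);
    let ones: Vec<f64> = BaseVector::ones(3);
    let twos: Vec<f64> = BaseVector::fill(3, 2.0);

    assert_eq!(zeros, vec![0.0, 0.0, 0.0]);
    assert_eq!(ones, vec![1.0, 1.0, 1.0]);
    assert_eq!(twos, vec![2.0, 2.0, 2.0]);
}
```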

src/linalg/naive/dense_matrix.rs

Lines changed: 12 additions & 0 deletions

```diff
@@ -32,6 +32,18 @@ impl<T: RealNumber> BaseVector<T> for Vec<T> {
         let v = self.clone();
         v
     }
+
+    fn zeros(len: usize) -> Self {
+        vec![T::zero(); len]
+    }
+
+    fn ones(len: usize) -> Self {
+        vec![T::one(); len]
+    }
+
+    fn fill(len: usize, value: T) -> Self {
+        vec![value; len]
+    }
 }

 /// Column-major, dense matrix. See [Simple Dense Matrix](../index.html).
```

src/linalg/nalgebra_bindings.rs

Lines changed: 25 additions & 1 deletion

```diff
@@ -40,7 +40,7 @@
 use std::iter::Sum;
 use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};

-use nalgebra::{DMatrix, Dynamic, Matrix, MatrixMN, Scalar, VecStorage, U1};
+use nalgebra::{DMatrix, Dynamic, Matrix, MatrixMN, RowDVector, Scalar, VecStorage, U1};

 use crate::linalg::evd::EVDDecomposableMatrix;
 use crate::linalg::lu::LUDecomposableMatrix;
@@ -65,6 +65,20 @@ impl<T: RealNumber + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> {
     fn to_vec(&self) -> Vec<T> {
         self.row(0).iter().map(|v| *v).collect()
     }
+
+    fn zeros(len: usize) -> Self {
+        RowDVector::zeros(len)
+    }
+
+    fn ones(len: usize) -> Self {
+        BaseVector::fill(len, T::one())
+    }
+
+    fn fill(len: usize, value: T) -> Self {
+        let mut m = RowDVector::zeros(len);
+        m.fill(value);
+        m
+    }
 }

 impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
@@ -446,6 +460,16 @@ mod tests {
         assert_eq!(vec![1., 2., 3.], v.to_vec());
     }

+    #[test]
+    fn vec_init() {
+        let zeros: RowDVector<f32> = BaseVector::zeros(3);
+        let ones: RowDVector<f32> = BaseVector::ones(3);
+        let twos: RowDVector<f32> = BaseVector::fill(3, 2.);
+        assert_eq!(zeros, RowDVector::from_vec(vec![0., 0., 0.]));
+        assert_eq!(ones, RowDVector::from_vec(vec![1., 1., 1.]));
+        assert_eq!(twos, RowDVector::from_vec(vec![2., 2., 2.]));
+    }
+
     #[test]
     fn get_set_dynamic() {
         let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
```

src/linalg/ndarray_bindings.rs

Lines changed: 12 additions & 0 deletions

```diff
@@ -72,6 +72,18 @@ impl<T: RealNumber> BaseVector<T> for ArrayBase<OwnedRepr<T>, Ix1> {
     fn to_vec(&self) -> Vec<T> {
         self.to_owned().to_vec()
     }
+
+    fn zeros(len: usize) -> Self {
+        Array::zeros(len)
+    }
+
+    fn ones(len: usize) -> Self {
+        Array::ones(len)
+    }
+
+    fn fill(len: usize, value: T) -> Self {
+        Array::from_elem(len, value)
+    }
 }

 impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
```

src/model_selection/mod.rs

Lines changed: 109 additions & 0 deletions

```diff
@@ -0,0 +1,109 @@
+//! # Model Selection methods
+//!
+//! In statistics and machine learning we usually split our data into multiple subsets: training data and testing data (and sometimes a validation set),
+//! and fit our model on the training data in order to make predictions on the test data. We do that to avoid overfitting or underfitting the model.
+//! Overfitting is bad because the model fits the training data too well and cannot make useful inferences on new data.
+//! Underfitting is bad because the model is undertrained and does not fit even the training data well.
+//! Splitting data into multiple subsets helps us find the right combination of hyperparameters, estimate model performance and choose the right model
+//! for the data.
+//!
+//! In SmartCore you can split your data into training and test datasets using the `train_test_split` function.
+extern crate rand;
+
+use crate::linalg::BaseVector;
+use crate::linalg::Matrix;
+use crate::math::num::RealNumber;
+use rand::Rng;
+
+/// Splits data into 2 disjoint datasets.
+/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
+/// * `y` - target values, should be of size _N_.
+/// * `test_size`, (0, 1] - the proportion of the dataset to include in the test split.
+pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
+    x: &M,
+    y: &M::RowVector,
+    test_size: f32,
+) -> (M, M, M::RowVector, M::RowVector) {
+    if x.shape().0 != y.len() {
+        panic!(
+            "x and y should have the same number of samples. |x|: {}, |y|: {}",
+            x.shape().0,
+            y.len()
+        );
+    }
+
+    if test_size <= 0. || test_size > 1.0 {
+        panic!("test_size should be between 0 and 1");
+    }
+
+    let n = y.len();
+    let m = x.shape().1;
+
+    let mut rng = rand::thread_rng();
+    let mut n_test = 0;
+    let mut index = vec![false; n];
+
+    for i in 0..n {
+        let p_test: f32 = rng.gen();
+        if p_test <= test_size {
+            index[i] = true;
+            n_test += 1;
+        }
+    }
+
+    let n_train = n - n_test;
+
+    let mut x_train = M::zeros(n_train, m);
+    let mut x_test = M::zeros(n_test, m);
+    let mut y_train = M::RowVector::zeros(n_train);
+    let mut y_test = M::RowVector::zeros(n_test);
+
+    let mut r_train = 0;
+    let mut r_test = 0;
+
+    for r in 0..n {
+        if index[r] {
+            // sample belongs to the test set
+            for c in 0..m {
+                x_test.set(r_test, c, x.get(r, c));
+                y_test.set(r_test, y.get(r));
+            }
+            r_test += 1;
+        } else {
+            for c in 0..m {
+                x_train.set(r_train, c, x.get(r, c));
+                y_train.set(r_train, y.get(r));
+            }
+            r_train += 1;
+        }
+    }
+
+    (x_train, x_test, y_train, y_test)
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+    use crate::linalg::naive::dense_matrix::*;
+
+    #[test]
+    fn run_train_test_split() {
+        let n = 100;
+        let x: DenseMatrix<f64> = DenseMatrix::rand(100, 3);
+        let y = vec![0f64; 100];
+
+        let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2);
+
+        assert!(
+            x_train.shape().0 > (n as f64 * 0.65) as usize
+                && x_train.shape().0 < (n as f64 * 0.95) as usize
+        );
+        assert!(
+            x_test.shape().0 > (n as f64 * 0.05) as usize
+                && x_test.shape().0 < (n as f64 * 0.35) as usize
+        );
+        assert_eq!(x_train.shape().0, y_train.len());
+        assert_eq!(x_test.shape().0, y_test.len());
+    }
+}
```
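The test above doubles as usage documentation. An end-to-end sketch of calling the new function from user code (module paths assume the `pub mod model_selection` export added to `lib.rs` in this commit):

```rust
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::model_selection::train_test_split;

fn main() {
    // 100 samples, 3 features, as in the commit's own test
    let x: DenseMatrix<f64> = DenseMatrix::rand(100, 3);
    let y = vec![0f64; 100];

    // hold out roughly 20% of the rows for testing
    let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2);

    println!("train: {}, test: {}", x_train.shape().0, x_test.shape().0);
    assert_eq!(x_train.shape().0, y_train.len());
    assert_eq!(x_test.shape().0, y_test.len());
}
```

Because each row is assigned to the test set by an independent draw with probability `test_size`, the split sizes are only approximate; this is why the test asserts loose bounds rather than exact counts.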

src/tree/decision_tree_classifier.rs

Lines changed: 5 additions & 0 deletions

```diff
@@ -67,6 +67,7 @@ use std::default::Default;
 use std::fmt::Debug;
 use std::marker::PhantomData;

+use rand::seq::SliceRandom;
 use serde::{Deserialize, Serialize};

 use crate::algorithm::sort::quick_sort::QuickArgSort;
@@ -431,6 +432,10 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
             variables[i] = i;
         }

+        if mtry < n_attr {
+            variables.shuffle(&mut rand::thread_rng());
+        }
+
         for j in 0..mtry {
             self.find_best_split(
                 visitor,
```

src/tree/decision_tree_regressor.rs

Lines changed: 5 additions & 0 deletions

```diff
@@ -62,6 +62,7 @@ use std::collections::LinkedList;
 use std::default::Default;
 use std::fmt::Debug;

+use rand::seq::SliceRandom;
 use serde::{Deserialize, Serialize};

 use crate::algorithm::sort::quick_sort::QuickArgSort;
@@ -320,6 +321,10 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
             variables[i] = i;
         }

+        if mtry < n_attr {
+            variables.shuffle(&mut rand::thread_rng());
+        }
+
         let parent_gain =
             T::from(n).unwrap() * self.nodes[visitor.node].output * self.nodes[visitor.node].output;
```
