Skip to content

Commit a2588f6

Browse files
KFold cross-validation (#8)
* Add documentation and API * Add public keyword * Implement test_indices (debug version) * Return indices as Vec of Vec * Consume vector using drain() * Use shape() to return num of samples * Implement test_masks * Implement KFold.split() * Make trait public * Add test for split * Fix samples in shape() * Implement shuffle * Simplify return values * Use usize for n_splits Co-authored-by: VolodymyrOrlov <volodymyr.orlov@gmail.com>
1 parent bb96354 commit a2588f6

File tree

1 file changed

+232
-0
lines changed

1 file changed

+232
-0
lines changed

src/model_selection/mod.rs

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ extern crate rand;
1313
use crate::linalg::BaseVector;
1414
use crate::linalg::Matrix;
1515
use crate::math::num::RealNumber;
16+
use rand::seq::SliceRandom;
17+
use rand::thread_rng;
1618
use rand::Rng;
1719

1820
/// Splits data into 2 disjoint datasets.
@@ -81,6 +83,113 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
8183
(x_train, x_test, y_train, y_test)
8284
}
8385

86+
///
87+
/// KFold Cross-Validation
88+
///
89+
pub trait BaseKFold {
90+
/// Returns integer indices corresponding to test sets
91+
fn test_indices<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Vec<Vec<usize>>;
92+
93+
/// Returns masksk corresponding to test sets
94+
fn test_masks<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Vec<Vec<bool>>;
95+
96+
/// Return a tuple containing the the training set indices for that split and
97+
/// the testing set indices for that split.
98+
fn split<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Vec<(Vec<usize>, Vec<usize>)>;
99+
}
100+
101+
/// Configuration of a K-fold cross-validation splitter.
///
/// The fields are public so that code outside this module can choose the
/// number of folds and the shuffling behaviour, for example with
/// struct-update syntax on top of `KFold::default()`.
#[derive(Debug, Clone)]
pub struct KFold {
    /// Number of folds the samples are divided into.
    pub n_splits: usize,
    /// Whether the sample indices are shuffled before being assigned to folds.
    pub shuffle: bool,
    // TODO: to be implemented later
    // random_state: i32,
}
110+
111+
impl Default for KFold {
112+
fn default() -> KFold {
113+
KFold {
114+
n_splits: 3 as usize,
115+
shuffle: true,
116+
}
117+
}
118+
}
119+
120+
///
121+
/// Abstract class for all KFold functionalities
122+
///
123+
impl BaseKFold for KFold {
124+
fn test_indices<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Vec<Vec<usize>> {
125+
// number of samples (rows) in the matrix
126+
let n_samples: usize = x.shape().0;
127+
128+
// initialise indices
129+
let mut indices: Vec<usize> = (0..n_samples).collect();
130+
if self.shuffle == true {
131+
indices.shuffle(&mut thread_rng());
132+
}
133+
// return a new array of given shape n_split, filled with each element of n_samples divided by n_splits.
134+
let mut fold_sizes = vec![n_samples / self.n_splits; self.n_splits];
135+
136+
// increment by one if odd
137+
for i in 0..(n_samples % self.n_splits) {
138+
fold_sizes[i] = fold_sizes[i] + 1;
139+
}
140+
141+
// generate the right array of arrays for test indices
142+
let mut return_values: Vec<Vec<usize>> = Vec::with_capacity(self.n_splits);
143+
let mut current: usize = 0;
144+
for fold_size in fold_sizes.drain(..) {
145+
let stop = current + fold_size;
146+
return_values.push(indices[current..stop].to_vec());
147+
current = stop
148+
}
149+
150+
return_values
151+
}
152+
153+
fn test_masks<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Vec<Vec<bool>> {
154+
let mut return_values: Vec<Vec<bool>> = Vec::with_capacity(self.n_splits);
155+
for test_index in self.test_indices(x).drain(..) {
156+
// init mask
157+
let mut test_mask = vec![false; x.shape().0];
158+
// set mask's indices to true according to test indices
159+
for i in test_index {
160+
test_mask[i] = true; // can be implemented with map()
161+
}
162+
return_values.push(test_mask);
163+
}
164+
return_values
165+
}
166+
167+
fn split<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Vec<(Vec<usize>, Vec<usize>)> {
168+
let n_samples: usize = x.shape().0;
169+
let indices: Vec<usize> = (0..n_samples).collect();
170+
171+
let mut return_values: Vec<(Vec<usize>, Vec<usize>)> = Vec::with_capacity(self.n_splits); // TODO: init nested vecs with capacities by getting the length of test_index vecs
172+
173+
for test_index in self.test_masks(x).drain(..) {
174+
let train_index = indices
175+
.clone()
176+
.iter()
177+
.enumerate()
178+
.filter(|&(idx, _)| test_index[idx] == false)
179+
.map(|(idx, _)| idx)
180+
.collect::<Vec<usize>>(); // filter train indices out according to mask
181+
let test_index = indices
182+
.iter()
183+
.enumerate()
184+
.filter(|&(idx, _)| test_index[idx] == true)
185+
.map(|(idx, _)| idx)
186+
.collect::<Vec<usize>>(); // filter tests indices out according to mask
187+
return_values.push((train_index, test_index))
188+
}
189+
return_values
190+
}
191+
}
192+
84193
#[cfg(test)]
85194
mod tests {
86195

@@ -106,4 +215,127 @@ mod tests {
106215
assert_eq!(x_train.shape().0, y_train.len());
107216
assert_eq!(x_test.shape().0, y_test.len());
108217
}
218+
219+
#[test]
fn run_kfold_return_test_indices_simple() {
    // 33 rows over 3 unshuffled folds: 11 consecutive rows per fold.
    let x: DenseMatrix<f64> = DenseMatrix::rand(33, 100);
    let splitter = KFold {
        shuffle: false,
        n_splits: 3,
    };

    let folds = splitter.test_indices(&x);

    let expected_ranges = vec![0..11, 11..22, 22..33];
    for (fold, range) in expected_ranges.into_iter().enumerate() {
        assert_eq!(folds[fold], range.collect::<Vec<usize>>());
    }
}
232+
233+
#[test]
fn run_kfold_return_test_indices_odd() {
    // 34 rows do not divide evenly by 3: the first fold takes the extra row.
    let x: DenseMatrix<f64> = DenseMatrix::rand(34, 100);
    let splitter = KFold {
        shuffle: false,
        n_splits: 3,
    };

    let folds = splitter.test_indices(&x);

    let expected_ranges = vec![0..12, 12..23, 23..34];
    for (fold, range) in expected_ranges.into_iter().enumerate() {
        assert_eq!(folds[fold], range.collect::<Vec<usize>>());
    }
}
246+
247+
#[test]
fn run_kfold_return_test_mask_simple() {
    // 22 rows over 2 unshuffled folds: each mask covers one half of the rows.
    let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
    let splitter = KFold {
        shuffle: false,
        n_splits: 2,
    };

    let masks = splitter.test_masks(&x);

    // first fold: lower half is the test set
    assert!(masks[0][0..11].iter().all(|&flag| flag));
    assert!(masks[0][11..22].iter().all(|&flag| !flag));

    // second fold: upper half is the test set
    assert!(masks[1][0..11].iter().all(|&flag| !flag));
    assert!(masks[1][11..22].iter().all(|&flag| flag));
}
271+
272+
#[test]
fn run_kfold_return_split_simple() {
    // 22 rows over 2 unshuffled folds: train and test swap halves between folds.
    let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
    let splitter = KFold {
        shuffle: false,
        n_splits: 2,
    };

    let splits = splitter.split(&x);

    let lower_half: Vec<usize> = (0..11).collect();
    let upper_half: Vec<usize> = (11..22).collect();

    let (train_0, test_0) = &splits[0];
    let (train_1, test_1) = &splits[1];

    assert_eq!(test_0, &lower_half);
    assert_eq!(train_0, &upper_half);
    assert_eq!(train_1, &lower_half);
    assert_eq!(test_1, &upper_half);
}
286+
287+
#[test]
fn run_kfold_return_split_simple_shuffle() {
    // With shuffling only the sizes are predictable: 23 rows over 2 folds
    // give test sets of 12 and 11 rows.
    let x: DenseMatrix<f64> = DenseMatrix::rand(23, 100);
    let splitter = KFold {
        n_splits: 2,
        ..KFold::default()
    };

    // (train size, test size) for every fold
    let sizes: Vec<(usize, usize)> = splitter
        .split(&x)
        .iter()
        .map(|(train, test)| (train.len(), test.len()))
        .collect();

    assert_eq!(sizes[0], (11, 12));
    assert_eq!(sizes[1], (12, 11));
}
301+
302+
#[test]
fn numpy_parity_test() {
    let k = KFold {
        n_splits: 3,
        shuffle: false,
    };
    let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);

    // Reference (train, test) folds for 10 rows split 3 ways without shuffling.
    // NOTE(review): per the test name these mirror the Python reference
    // implementation -- presumably scikit-learn's KFold; confirm.
    let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
        (vec![4, 5, 6, 7, 8, 9], vec![0, 1, 2, 3]),
        (vec![0, 1, 2, 3, 7, 8, 9], vec![4, 5, 6]),
        (vec![0, 1, 2, 3, 4, 5, 6], vec![7, 8, 9]),
    ];

    // Compare the whole result at once: zipping the two lists would stop at
    // the shorter one and let a missing fold go unnoticed.
    assert_eq!(k.split(&x), expected);
}
321+
322+
#[test]
fn numpy_parity_test_shuffle() {
    let k = KFold {
        n_splits: 3,
        ..KFold::default()
    };
    let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);

    let splits = k.split(&x);

    // Shuffling makes the exact indices unpredictable, but the
    // (train size, test size) of every fold is fixed. Comparing the full list
    // also pins the number of folds, which a `zip` over the folds would not.
    let expected_sizes: Vec<(usize, usize)> = vec![(6, 4), (7, 3), (7, 3)];
    let sizes: Vec<(usize, usize)> = splits
        .iter()
        .map(|(train, test)| (train.len(), test.len()))
        .collect();
    assert_eq!(sizes, expected_sizes);

    // Every fold must partition the rows: train and test together contain
    // each row index exactly once.
    let all_rows: Vec<usize> = (0..10).collect();
    for (train, test) in &splits {
        let mut combined: Vec<usize> = train.iter().chain(test.iter()).copied().collect();
        combined.sort_unstable();
        assert_eq!(combined, all_rows);
    }
}
109341
}

0 commit comments

Comments
 (0)