You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
* Add documentation and API
* Add public keyword
* Implement test_indices (debug version)
* Return indices as Vec of Vec
* Consume vector using drain()
* Use shape() to return num of samples
* Implement test_masks
* Implement KFold.split()
* Make trait public
* Add test for split
* Fix samples in shape()
* Implement shuffle
* Simplify return values
* Use usize for n_splits
Co-authored-by: VolodymyrOrlov <volodymyr.orlov@gmail.com>
let indices:Vec<usize> = (0..n_samples).collect();
170
+
171
+
letmut return_values:Vec<(Vec<usize>,Vec<usize>)> = Vec::with_capacity(self.n_splits);// TODO: init nested vecs with capacities by getting the length of test_index vecs
172
+
173
+
for test_index inself.test_masks(x).drain(..){
174
+
let train_index = indices
175
+
.clone()
176
+
.iter()
177
+
.enumerate()
178
+
.filter(|&(idx, _)| test_index[idx] == false)
179
+
.map(|(idx, _)| idx)
180
+
.collect::<Vec<usize>>();// filter train indices out according to mask
181
+
let test_index = indices
182
+
.iter()
183
+
.enumerate()
184
+
.filter(|&(idx, _)| test_index[idx] == true)
185
+
.map(|(idx, _)| idx)
186
+
.collect::<Vec<usize>>();// filter tests indices out according to mask
187
+
return_values.push((train_index, test_index))
188
+
}
189
+
return_values
190
+
}
191
+
}
192
+
84
193
#[cfg(test)]
85
194
mod tests {
86
195
@@ -106,4 +215,127 @@ mod tests {
106
215
assert_eq!(x_train.shape().0, y_train.len());
107
216
assert_eq!(x_test.shape().0, y_test.len());
108
217
}
218
+
219
+
#[test]
220
+
fnrun_kfold_return_test_indices_simple(){
221
+
let k = KFold{
222
+
n_splits:3,
223
+
shuffle:false,
224
+
};
225
+
let x:DenseMatrix<f64> = DenseMatrix::rand(33,100);
0 commit comments