Skip to content

Commit f685f57

Browse files
Volodymyr OrlovVolodymyr Orlov
authored and committed
feat: + cross_val_predict
1 parent 9b22197 commit f685f57

File tree

1 file changed

+48
-57
lines changed

1 file changed

+48
-57
lines changed

src/model_selection/mod.rs

Lines changed: 48 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ where
125125
let mut test_score = Vec::with_capacity(k);
126126
let mut train_score = Vec::with_capacity(k);
127127

128-
for (test_idx, train_idx) in cv.split(x) {
128+
for (train_idx, test_idx) in cv.split(x) {
129129
let train_x = x.take(&train_idx, 0);
130130
let train_y = y.take(&train_idx);
131131
let test_x = x.take(&test_idx, 0);
@@ -143,6 +143,46 @@ where
143143
})
144144
}
145145

146+
/// Generate cross-validated estimates for each input data point.
147+
/// The data is split according to the cv parameter. Each sample belongs to exactly one test set, and its prediction is computed with an estimator fitted on the corresponding training set.
148+
/// * `fit_estimator` - a `fit` function of an estimator
149+
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
150+
/// * `y` - target values, should be of size _N_
151+
/// * `parameters` - parameters of selected estimator. Use `Default::default()` for default parameters.
152+
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
153+
pub fn cross_val_predict<T, M, H, E, K, F>(
154+
fit_estimator: F,
155+
x: &M,
156+
y: &M::RowVector,
157+
parameters: H,
158+
cv: K
159+
) -> Result<M::RowVector, Failed>
160+
where
161+
T: RealNumber,
162+
M: Matrix<T>,
163+
H: Clone,
164+
E: Predictor<M, M::RowVector>,
165+
K: BaseKFold,
166+
F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>
167+
{
168+
let mut y_hat = M::RowVector::zeros(y.len());
169+
170+
for (train_idx, test_idx) in cv.split(x) {
171+
let train_x = x.take(&train_idx, 0);
172+
let train_y = y.take(&train_idx);
173+
let test_x = x.take(&test_idx, 0);
174+
175+
let estimator = fit_estimator(&train_x, &train_y, parameters.clone())?;
176+
177+
let y_test_hat = estimator.predict(&test_x)?;
178+
for (i, &idx) in test_idx.iter().enumerate() {
179+
y_hat.set(idx, y_test_hat.get(i));
180+
}
181+
}
182+
183+
Ok(y_hat)
184+
}
185+
146186
#[cfg(test)]
147187
mod tests {
148188

@@ -278,10 +318,8 @@ mod tests {
278318
assert!(results.mean_train_score() < results.mean_test_score());
279319
}
280320

281-
use crate::tree::decision_tree_regressor::*;
282-
283321
#[test]
284-
fn test_some_regressor() {
322+
fn test_cross_val_predict_knn() {
285323
let x = DenseMatrix::from_2d_array(&[
286324
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
287325
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
@@ -305,68 +343,21 @@ mod tests {
305343
114.2, 115.7, 116.9,
306344
];
307345

308-
let cv = KFold::default().with_n_splits(2);
309-
310-
let results = cross_validate(
311-
DecisionTreeRegressor::fit,
312-
&x,
313-
&y,
314-
Default::default(),
315-
cv,
316-
&mean_absolute_error,
317-
)
318-
.unwrap();
319-
320-
println!("{}", results.mean_test_score());
321-
println!("{}", results.mean_train_score());
322-
}
323-
324-
use crate::tree::decision_tree_classifier::*;
325-
326-
#[test]
327-
fn test_some_classifier() {
328-
let x = DenseMatrix::from_2d_array(&[
329-
&[5.1, 3.5, 1.4, 0.2],
330-
&[4.9, 3.0, 1.4, 0.2],
331-
&[4.7, 3.2, 1.3, 0.2],
332-
&[4.6, 3.1, 1.5, 0.2],
333-
&[5.0, 3.6, 1.4, 0.2],
334-
&[5.4, 3.9, 1.7, 0.4],
335-
&[4.6, 3.4, 1.4, 0.3],
336-
&[5.0, 3.4, 1.5, 0.2],
337-
&[4.4, 2.9, 1.4, 0.2],
338-
&[4.9, 3.1, 1.5, 0.1],
339-
&[7.0, 3.2, 4.7, 1.4],
340-
&[6.4, 3.2, 4.5, 1.5],
341-
&[6.9, 3.1, 4.9, 1.5],
342-
&[5.5, 2.3, 4.0, 1.3],
343-
&[6.5, 2.8, 4.6, 1.5],
344-
&[5.7, 2.8, 4.5, 1.3],
345-
&[6.3, 3.3, 4.7, 1.6],
346-
&[4.9, 2.4, 3.3, 1.0],
347-
&[6.6, 2.9, 4.6, 1.3],
348-
&[5.2, 2.7, 3.9, 1.4],
349-
]);
350-
let y = vec![
351-
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
352-
];
353-
354346
let cv = KFold {
355347
n_splits: 2,
356348
..KFold::default()
357349
};
358350

359-
let results = cross_validate(
360-
DecisionTreeClassifier::fit,
351+
let y_hat = cross_val_predict(
352+
KNNRegressor::fit,
361353
&x,
362354
&y,
363355
Default::default(),
364-
cv,
365-
&accuracy,
356+
cv
366357
)
367-
.unwrap();
358+
.unwrap();
368359

369-
println!("{}", results.mean_test_score());
370-
println!("{}", results.mean_train_score());
360+
assert!(mean_absolute_error(&y, &y_hat) < 10.0);
371361
}
362+
372363
}

0 commit comments

Comments
 (0)