@@ -155,22 +155,22 @@ pub fn cross_val_predict<T, M, H, E, K, F>(
     x: &M,
     y: &M::RowVector,
     parameters: H,
-    cv: K
+    cv: K,
 ) -> Result<M::RowVector, Failed>
 where
     T: RealNumber,
     M: Matrix<T>,
     H: Clone,
     E: Predictor<M, M::RowVector>,
     K: BaseKFold,
-    F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>
-{
-    let mut y_hat = M::RowVector::zeros(y.len());
-
+    F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>,
+{
+    let mut y_hat = M::RowVector::zeros(y.len());
+
     for (train_idx, test_idx) in cv.split(x) {
         let train_x = x.take(&train_idx, 0);
         let train_y = y.take(&train_idx);
-        let test_x = x.take(&test_idx, 0);
+        let test_x = x.take(&test_idx, 0);

         let estimator = fit_estimator(&train_x, &train_y, parameters.clone())?;

@@ -348,16 +348,8 @@ mod tests {
             ..KFold::default()
         };

-        let y_hat = cross_val_predict(
-            KNNRegressor::fit,
-            &x,
-            &y,
-            Default::default(),
-            cv
-        )
-        .unwrap();
+        let y_hat = cross_val_predict(KNNRegressor::fit, &x, &y, Default::default(), cv).unwrap();

         assert!(mean_absolute_error(&y, &y_hat) < 10.0);
     }
-
 }
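
For context on how the signature above is meant to be used: the `F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>` bound lets the caller pass an estimator's fitting function directly, which is exactly what the test in this diff does with `KNNRegressor::fit`. Below is a minimal, self-contained sketch of the same pattern outside the test module. The module paths, the `n_splits` field name, and the toy data are assumptions based on the smartcore 0.2-era API this code appears to come from; they are for illustration only and are not part of this change.

use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::metrics::mean_absolute_error;
use smartcore::model_selection::{cross_val_predict, KFold};
use smartcore::neighbors::knn_regressor::KNNRegressor;

fn main() {
    // Toy regression data: 8 samples, 2 features (made up for illustration).
    let x = DenseMatrix::from_2d_array(&[
        &[1., 1.],
        &[2., 2.],
        &[3., 3.],
        &[4., 4.],
        &[5., 5.],
        &[6., 6.],
        &[7., 7.],
        &[8., 8.],
    ]);
    let y = vec![2., 4., 6., 8., 10., 12., 14., 16.];

    // 2-fold splitter: every sample receives exactly one out-of-fold prediction.
    let cv = KFold {
        n_splits: 2,
        ..KFold::default()
    };

    // Pass the estimator's fitting function as `F`; `Default::default()` resolves
    // to the KNN regressor's parameter type through that function's signature.
    let y_hat = cross_val_predict(KNNRegressor::fit, &x, &y, Default::default(), cv).unwrap();

    println!("out-of-fold MAE = {}", mean_absolute_error(&y, &y_hat));
}

The same call works with any estimator whose fitting function matches the `F` bound, since `H` and `E` are inferred from the function that is passed in.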