You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: src/model_selection/mod.rs
+48-57Lines changed: 48 additions & 57 deletions
Original file line number
Diff line number
Diff line change
@@ -125,7 +125,7 @@ where
125
125
letmut test_score = Vec::with_capacity(k);
126
126
letmut train_score = Vec::with_capacity(k);
127
127
128
-
for(test_idx, train_idx)in cv.split(x){
128
+
for(train_idx, test_idx)in cv.split(x){
129
129
let train_x = x.take(&train_idx,0);
130
130
let train_y = y.take(&train_idx);
131
131
let test_x = x.take(&test_idx,0);
@@ -143,6 +143,46 @@ where
143
143
})
144
144
}
145
145
146
+
/// Generate cross-validated estimates for each input data point.
147
+
/// The data is split according to the cv parameter. Each sample belongs to exactly one test set, and its prediction is computed with an estimator fitted on the corresponding training set.
148
+
/// * `fit_estimator` - a `fit` function of an estimator
149
+
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
150
+
/// * `y` - target values, should be of size _N_
151
+
/// * `parameters` - parameters of selected estimator. Use `Default::default()` for default parameters.
152
+
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
153
+
pubfncross_val_predict<T,M,H,E,K,F>(
154
+
fit_estimator:F,
155
+
x:&M,
156
+
y:&M::RowVector,
157
+
parameters:H,
158
+
cv:K
159
+
) -> Result<M::RowVector,Failed>
160
+
where
161
+
T:RealNumber,
162
+
M:Matrix<T>,
163
+
H:Clone,
164
+
E:Predictor<M,M::RowVector>,
165
+
K:BaseKFold,
166
+
F:Fn(&M,&M::RowVector,H) -> Result<E,Failed>
167
+
{
168
+
letmut y_hat = M::RowVector::zeros(y.len());
169
+
170
+
for(train_idx, test_idx)in cv.split(x){
171
+
let train_x = x.take(&train_idx,0);
172
+
let train_y = y.take(&train_idx);
173
+
let test_x = x.take(&test_idx,0);
174
+
175
+
let estimator = fit_estimator(&train_x,&train_y, parameters.clone())?;
0 commit comments