Skip to content

Commit 7f871fe

Browse files
MAINT Parameters validation for sklearn.model_selection.cross_val_predict (scikit-learn#26252)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent 845771f commit 7f871fe

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

sklearn/model_selection/_validation.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,30 @@ def _score(estimator, X_test, y_test, scorer, error_score="raise"):
884884
return scores
885885

886886

887+
@validate_params(
888+
{
889+
"estimator": [HasMethods(["fit", "predict"])],
890+
"X": ["array-like", "sparse matrix"],
891+
"y": ["array-like", None],
892+
"groups": ["array-like", None],
893+
"cv": ["cv_object"],
894+
"n_jobs": [Integral, None],
895+
"verbose": ["verbose"],
896+
"fit_params": [dict, None],
897+
"pre_dispatch": [Integral, str, None],
898+
"method": [
899+
StrOptions(
900+
{
901+
"predict",
902+
"predict_proba",
903+
"predict_log_proba",
904+
"decision_function",
905+
}
906+
)
907+
],
908+
},
909+
prefer_skip_nested_validation=False, # estimator is not validated yet
910+
)
887911
def cross_val_predict(
888912
estimator,
889913
X,
@@ -912,10 +936,11 @@ def cross_val_predict(
912936
913937
Parameters
914938
----------
915-
estimator : estimator object implementing 'fit' and 'predict'
916-
The object to use to fit the data.
939+
estimator : estimator
940+
The estimator instance to use to fit the data. It must implement a `fit`
941+
method and the method given by the `method` parameter.
917942
918-
X : array-like of shape (n_samples, n_features)
943+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
919944
The data to fit. Can be, for example a list, or an array at least 2d.
920945
921946
y : array-like of shape (n_samples,) or (n_samples, n_outputs), \

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ def _check_function_param_validation(
269269
"sklearn.metrics.top_k_accuracy_score",
270270
"sklearn.metrics.v_measure_score",
271271
"sklearn.metrics.zero_one_loss",
272+
"sklearn.model_selection.cross_val_predict",
272273
"sklearn.model_selection.cross_val_score",
273274
"sklearn.model_selection.cross_validate",
274275
"sklearn.model_selection.learning_curve",

0 commit comments

Comments
 (0)