@@ -884,6 +884,30 @@ def _score(estimator, X_test, y_test, scorer, error_score="raise"):
884
884
return scores
885
885
886
886
887
+ @validate_params (
888
+ {
889
+ "estimator" : [HasMethods (["fit" , "predict" ])],
890
+ "X" : ["array-like" , "sparse matrix" ],
891
+ "y" : ["array-like" , None ],
892
+ "groups" : ["array-like" , None ],
893
+ "cv" : ["cv_object" ],
894
+ "n_jobs" : [Integral , None ],
895
+ "verbose" : ["verbose" ],
896
+ "fit_params" : [dict , None ],
897
+ "pre_dispatch" : [Integral , str , None ],
898
+ "method" : [
899
+ StrOptions (
900
+ {
901
+ "predict" ,
902
+ "predict_proba" ,
903
+ "predict_log_proba" ,
904
+ "decision_function" ,
905
+ }
906
+ )
907
+ ],
908
+ },
909
+ prefer_skip_nested_validation = False , # estimator is not validated yet
910
+ )
887
911
def cross_val_predict (
888
912
estimator ,
889
913
X ,
@@ -912,10 +936,11 @@ def cross_val_predict(
912
936
913
937
Parameters
914
938
----------
915
- estimator : estimator object implementing 'fit' and 'predict'
916
- The object to use to fit the data.
939
+ estimator : estimator
940
+ The estimator instance to use to fit the data. It must implement a `fit`
941
+ method and the method given by the `method` parameter.
917
942
918
- X : array-like of shape (n_samples, n_features)
943
+ X : { array-like, sparse matrix} of shape (n_samples, n_features)
919
944
The data to fit. Can be, for example a list, or an array at least 2d.
920
945
921
946
y : array-like of shape (n_samples,) or (n_samples, n_outputs), \
0 commit comments