15
15
16
16
from ..base import MultiOutputMixin , RegressorMixin , _fit_context
17
17
from ..model_selection import check_cv
18
- from ..utils import as_float_array , check_array
18
+ from ..utils import Bunch , as_float_array , check_array
19
19
from ..utils ._param_validation import Hidden , Interval , StrOptions , validate_params
20
- from ..utils .metadata_routing import _RoutingNotSupportedMixin
20
+ from ..utils .metadata_routing import (
21
+ MetadataRouter ,
22
+ MethodMapping ,
23
+ _raise_for_params ,
24
+ _routing_enabled ,
25
+ process_routing ,
26
+ )
21
27
from ..utils .parallel import Parallel , delayed
22
28
from ._base import LinearModel , _deprecate_normalize , _pre_fit
23
29
@@ -904,9 +910,7 @@ def _omp_path_residues(
904
910
return np .dot (coefs .T , X_test .T ) - y_test
905
911
906
912
907
- class OrthogonalMatchingPursuitCV (
908
- _RoutingNotSupportedMixin , RegressorMixin , LinearModel
909
- ):
913
+ class OrthogonalMatchingPursuitCV (RegressorMixin , LinearModel ):
910
914
"""Cross-validated Orthogonal Matching Pursuit model (OMP).
911
915
912
916
See glossary entry for :term:`cross-validation estimator`.
@@ -1060,7 +1064,7 @@ def __init__(
1060
1064
self .verbose = verbose
1061
1065
1062
1066
@_fit_context (prefer_skip_nested_validation = True )
1063
- def fit (self , X , y ):
1067
+ def fit (self , X , y , ** fit_params ):
1064
1068
"""Fit the model using X, y as training data.
1065
1069
1066
1070
Parameters
@@ -1071,18 +1075,36 @@ def fit(self, X, y):
1071
1075
y : array-like of shape (n_samples,)
1072
1076
Target values. Will be cast to X's dtype if necessary.
1073
1077
1078
+ **fit_params : dict
1079
+ Parameters to pass to the underlying splitter.
1080
+
1081
+ .. versionadded:: 1.4
1082
+ Only available if `enable_metadata_routing=True`,
1083
+ which can be set by using
1084
+ ``sklearn.set_config(enable_metadata_routing=True)``.
1085
+ See :ref:`Metadata Routing User Guide <metadata_routing>` for
1086
+ more details.
1087
+
1074
1088
Returns
1075
1089
-------
1076
1090
self : object
1077
1091
Returns an instance of self.
1078
1092
"""
1093
+ _raise_for_params (fit_params , self , "fit" )
1094
+
1079
1095
_normalize = _deprecate_normalize (
1080
1096
self .normalize , estimator_name = self .__class__ .__name__
1081
1097
)
1082
1098
1083
1099
X , y = self ._validate_data (X , y , y_numeric = True , ensure_min_features = 2 )
1084
1100
X = as_float_array (X , copy = False , force_all_finite = False )
1085
1101
cv = check_cv (self .cv , classifier = False )
1102
+ if _routing_enabled ():
1103
+ routed_params = process_routing (self , "fit" , ** fit_params )
1104
+ else :
1105
+ # TODO(SLEP6): remove when metadata routing cannot be disabled.
1106
+ routed_params = Bunch ()
1107
+ routed_params .splitter = Bunch (split = {})
1086
1108
max_iter = (
1087
1109
min (max (int (0.1 * X .shape [1 ]), 5 ), X .shape [1 ])
1088
1110
if not self .max_iter
@@ -1099,7 +1121,7 @@ def fit(self, X, y):
1099
1121
_normalize ,
1100
1122
max_iter ,
1101
1123
)
1102
- for train , test in cv .split (X )
1124
+ for train , test in cv .split (X , ** routed_params . splitter . split )
1103
1125
)
1104
1126
1105
1127
min_early_stop = min (fold .shape [0 ] for fold in cv_paths )
@@ -1123,3 +1145,24 @@ def fit(self, X, y):
1123
1145
self .intercept_ = omp .intercept_
1124
1146
self .n_iter_ = omp .n_iter_
1125
1147
return self
1148
+
1149
+ def get_metadata_routing (self ):
1150
+ """Get metadata routing of this object.
1151
+
1152
+ Please check :ref:`User Guide <metadata_routing>` on how the routing
1153
+ mechanism works.
1154
+
1155
+ .. versionadded:: 1.4
1156
+
1157
+ Returns
1158
+ -------
1159
+ routing : MetadataRouter
1160
+ A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
1161
+ routing information.
1162
+ """
1163
+
1164
+ router = MetadataRouter (owner = self .__class__ .__name__ ).add (
1165
+ splitter = self .cv ,
1166
+ method_mapping = MethodMapping ().add (callee = "split" , caller = "fit" ),
1167
+ )
1168
+ return router
0 commit comments