Skip to content

Commit 0aa8932

Browse files
FEA Add metadata routing to OrthogonalMatchingPursuitCV (scikit-learn#27500)
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent 33e6a8d commit 0aa8932

File tree

4 files changed

+63
-9
lines changed

4 files changed

+63
-9
lines changed

doc/metadata_routing.rst

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -266,6 +266,7 @@ Meta-estimators and functions supporting metadata routing:
266266
- :class:`sklearn.multioutput.ClassifierChain`
267267
- :class:`sklearn.multioutput.MultiOutputClassifier`
268268
- :class:`sklearn.multioutput.MultiOutputRegressor`
269+
- :class:`sklearn.linear_model.OrthogonalMatchingPursuitCV`
269270
- :class:`sklearn.multioutput.RegressorChain`
270271
- :class:`sklearn.pipeline.Pipeline`
271272

@@ -292,7 +293,6 @@ Meta-estimators and tools not supporting metadata routing yet:
292293
- :class:`sklearn.linear_model.LassoLarsCV`
293294
- :class:`sklearn.linear_model.MultiTaskElasticNetCV`
294295
- :class:`sklearn.linear_model.MultiTaskLassoCV`
295-
- :class:`sklearn.linear_model.OrthogonalMatchingPursuitCV`
296296
- :class:`sklearn.linear_model.RANSACRegressor`
297297
- :class:`sklearn.linear_model.RidgeClassifierCV`
298298
- :class:`sklearn.linear_model.RidgeCV`

doc/whats_new/v1.4.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ more details.
9696
``**score_params`` which are passed to the underlying scorer.
9797
:pr:`26525` by :user:`Omar Salman <OmarManzoor>`.
9898

99+
- |Feature| :class:`linear_model.OrthogonalMatchingPursuitCV` now supports
100+
metadata routing. Its `fit` now accepts ``**fit_params``, which are passed to
101+
the underlying splitter. :pr:`27500` by :user:`Stefanie Senger
102+
<StefanieSenger>`.
103+
99104
- |Fix| All meta-estimators for which metadata routing is not yet implemented
100105
now raise a `NotImplementedError` on `get_metadata_routing` and on `fit` if
101106
metadata routing is enabled and any metadata is passed to them. :pr:`27389`

sklearn/linear_model/_omp.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,15 @@
1515

1616
from ..base import MultiOutputMixin, RegressorMixin, _fit_context
1717
from ..model_selection import check_cv
18-
from ..utils import as_float_array, check_array
18+
from ..utils import Bunch, as_float_array, check_array
1919
from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params
20-
from ..utils.metadata_routing import _RoutingNotSupportedMixin
20+
from ..utils.metadata_routing import (
21+
MetadataRouter,
22+
MethodMapping,
23+
_raise_for_params,
24+
_routing_enabled,
25+
process_routing,
26+
)
2127
from ..utils.parallel import Parallel, delayed
2228
from ._base import LinearModel, _deprecate_normalize, _pre_fit
2329

@@ -904,9 +910,7 @@ def _omp_path_residues(
904910
return np.dot(coefs.T, X_test.T) - y_test
905911

906912

907-
class OrthogonalMatchingPursuitCV(
908-
_RoutingNotSupportedMixin, RegressorMixin, LinearModel
909-
):
913+
class OrthogonalMatchingPursuitCV(RegressorMixin, LinearModel):
910914
"""Cross-validated Orthogonal Matching Pursuit model (OMP).
911915
912916
See glossary entry for :term:`cross-validation estimator`.
@@ -1060,7 +1064,7 @@ def __init__(
10601064
self.verbose = verbose
10611065

10621066
@_fit_context(prefer_skip_nested_validation=True)
1063-
def fit(self, X, y):
1067+
def fit(self, X, y, **fit_params):
10641068
"""Fit the model using X, y as training data.
10651069
10661070
Parameters
@@ -1071,18 +1075,36 @@ def fit(self, X, y):
10711075
y : array-like of shape (n_samples,)
10721076
Target values. Will be cast to X's dtype if necessary.
10731077
1078+
**fit_params : dict
1079+
Parameters to pass to the underlying splitter.
1080+
1081+
.. versionadded:: 1.4
1082+
Only available if `enable_metadata_routing=True`,
1083+
which can be set by using
1084+
``sklearn.set_config(enable_metadata_routing=True)``.
1085+
See :ref:`Metadata Routing User Guide <metadata_routing>` for
1086+
more details.
1087+
10741088
Returns
10751089
-------
10761090
self : object
10771091
Returns an instance of self.
10781092
"""
1093+
_raise_for_params(fit_params, self, "fit")
1094+
10791095
_normalize = _deprecate_normalize(
10801096
self.normalize, estimator_name=self.__class__.__name__
10811097
)
10821098

10831099
X, y = self._validate_data(X, y, y_numeric=True, ensure_min_features=2)
10841100
X = as_float_array(X, copy=False, force_all_finite=False)
10851101
cv = check_cv(self.cv, classifier=False)
1102+
if _routing_enabled():
1103+
routed_params = process_routing(self, "fit", **fit_params)
1104+
else:
1105+
# TODO(SLEP6): remove when metadata routing cannot be disabled.
1106+
routed_params = Bunch()
1107+
routed_params.splitter = Bunch(split={})
10861108
max_iter = (
10871109
min(max(int(0.1 * X.shape[1]), 5), X.shape[1])
10881110
if not self.max_iter
@@ -1099,7 +1121,7 @@ def fit(self, X, y):
10991121
_normalize,
11001122
max_iter,
11011123
)
1102-
for train, test in cv.split(X)
1124+
for train, test in cv.split(X, **routed_params.splitter.split)
11031125
)
11041126

11051127
min_early_stop = min(fold.shape[0] for fold in cv_paths)
@@ -1123,3 +1145,24 @@ def fit(self, X, y):
11231145
self.intercept_ = omp.intercept_
11241146
self.n_iter_ = omp.n_iter_
11251147
return self
1148+
1149+
def get_metadata_routing(self):
1150+
"""Get metadata routing of this object.
1151+
1152+
Please check :ref:`User Guide <metadata_routing>` on how the routing
1153+
mechanism works.
1154+
1155+
.. versionadded:: 1.4
1156+
1157+
Returns
1158+
-------
1159+
routing : MetadataRouter
1160+
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
1161+
routing information.
1162+
"""
1163+
1164+
router = MetadataRouter(owner=self.__class__.__name__).add(
1165+
splitter=self.cv,
1166+
method_mapping=MethodMapping().add(callee="split", caller="fit"),
1167+
)
1168+
return router

sklearn/tests/test_metaestimators_metadata_routing.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -228,6 +228,13 @@ def enable_slep006():
228228
"y": y,
229229
"estimator_routing_methods": ["fit"],
230230
},
231+
{
232+
"metaestimator": OrthogonalMatchingPursuitCV,
233+
"X": X,
234+
"y": y,
235+
"cv_name": "cv",
236+
"cv_routing_methods": ["fit"],
237+
},
231238
]
232239
"""List containing all metaestimators to be tested and their settings
233240
@@ -277,7 +284,6 @@ def enable_slep006():
277284
LassoLarsCV(),
278285
MultiTaskElasticNetCV(),
279286
MultiTaskLassoCV(),
280-
OrthogonalMatchingPursuitCV(),
281287
RANSACRegressor(),
282288
RFE(ConsumingClassifier()),
283289
RFECV(ConsumingClassifier()),

0 commit comments

Comments
 (0)