Skip to content

Commit 83a8015

Browse files
authored
FEAT add metadata routing to *SearchCV (scikit-learn#27058)
1 parent e7ae63f commit 83a8015

File tree

7 files changed

+509
-174
lines changed

7 files changed

+509
-174
lines changed

doc/whats_new/v1.4.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,14 @@ Changelog
212212
is enabled and should be passed via the `params` parameter. :pr:`26896` by
213213
`Adrin Jalali`_.
214214

215+
- |Feature| :class:`~model_selection.GridSearchCV`,
216+
:class:`~model_selection.RandomizedSearchCV`,
217+
:class:`~model_selection.HalvingGridSearchCV`, and
218+
:class:`~model_selection.HalvingRandomSearchCV` now support metadata routing
219+
in their ``fit`` and ``score``, and route metadata to the underlying
220+
estimator's ``fit``, the CV splitter, and the scorer. :pr:`27058` by `Adrin
221+
Jalali`_.
222+
215223
- |Enhancement| :func:`sklearn.model_selection.train_test_split` now supports
216224
Array API compatible inputs. :pr:`26855` by `Tim Head`_.
217225

sklearn/model_selection/_search.py

Lines changed: 134 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,21 @@
2727
from ..base import BaseEstimator, MetaEstimatorMixin, _fit_context, clone, is_classifier
2828
from ..exceptions import NotFittedError
2929
from ..metrics import check_scoring
30-
from ..metrics._scorer import _check_multimetric_scoring, get_scorer_names
31-
from ..utils import check_random_state
30+
from ..metrics._scorer import (
31+
_check_multimetric_scoring,
32+
_MultimetricScorer,
33+
get_scorer_names,
34+
)
35+
from ..utils import Bunch, check_random_state
3236
from ..utils._param_validation import HasMethods, Interval, StrOptions
3337
from ..utils._tags import _safe_tags
38+
from ..utils.metadata_routing import (
39+
MetadataRouter,
40+
MethodMapping,
41+
_raise_for_params,
42+
_routing_enabled,
43+
process_routing,
44+
)
3445
from ..utils.metaestimators import available_if
3546
from ..utils.parallel import Parallel, delayed
3647
from ..utils.random import sample_without_replacement
@@ -429,7 +440,7 @@ def _more_tags(self):
429440
},
430441
}
431442

432-
def score(self, X, y=None):
443+
def score(self, X, y=None, **params):
433444
"""Return the score on the given data, if the estimator has been refit.
434445
435446
This uses the score defined by ``scoring`` where provided, and the
@@ -446,6 +457,14 @@ def score(self, X, y=None):
446457
Target relative to X for classification or regression;
447458
None for unsupervised learning.
448459
460+
**params : dict
461+
Parameters to be passed to the underlying scorer(s).
462+
463+
.. versionadded:: 1.4
464+
Only available if `enable_metadata_routing=True`. See
465+
:ref:`Metadata Routing User Guide <metadata_routing>` for more
466+
details.
467+
449468
Returns
450469
-------
451470
score : float
@@ -454,6 +473,14 @@ def score(self, X, y=None):
454473
"""
455474
_check_refit(self, "score")
456475
check_is_fitted(self)
476+
477+
_raise_for_params(params, self, "score")
478+
479+
if _routing_enabled():
480+
score_params = process_routing(self, "score", **params).scorer["score"]
481+
else:
482+
score_params = dict()
483+
457484
if self.scorer_ is None:
458485
raise ValueError(
459486
"No score function explicitly defined, "
@@ -465,10 +492,10 @@ def score(self, X, y=None):
465492
scorer = self.scorer_[self.refit]
466493
else:
467494
scorer = self.scorer_
468-
return scorer(self.best_estimator_, X, y)
495+
return scorer(self.best_estimator_, X, y, **score_params)
469496

470497
# callable
471-
score = self.scorer_(self.best_estimator_, X, y)
498+
score = self.scorer_(self.best_estimator_, X, y, **score_params)
472499
if self.multimetric_:
473500
score = score[self.refit]
474501
return score
@@ -754,11 +781,62 @@ def _select_best_index(refit, refit_metric, results):
754781
best_index = results[f"rank_test_{refit_metric}"].argmin()
755782
return best_index
756783

784+
def _get_scorers(self, convert_multimetric):
785+
"""Get the scorer(s) to be used.
786+
787+
This is used in ``fit`` and ``get_metadata_routing``.
788+
789+
Parameters
790+
----------
791+
convert_multimetric : bool
792+
Whether to convert a dict of scorers to a _MultimetricScorer. This
793+
is used in ``get_metadata_routing`` to include the routing info for
794+
multiple scorers.
795+
796+
Returns
797+
-------
798+
scorers, refit_metric
799+
"""
800+
refit_metric = "score"
801+
802+
if callable(self.scoring):
803+
scorers = self.scoring
804+
elif self.scoring is None or isinstance(self.scoring, str):
805+
scorers = check_scoring(self.estimator, self.scoring)
806+
else:
807+
scorers = _check_multimetric_scoring(self.estimator, self.scoring)
808+
self._check_refit_for_multimetric(scorers)
809+
refit_metric = self.refit
810+
if convert_multimetric and isinstance(scorers, dict):
811+
scorers = _MultimetricScorer(
812+
scorers=scorers, raise_exc=(self.error_score == "raise")
813+
)
814+
815+
return scorers, refit_metric
816+
817+
def _get_routed_params_for_fit(self, params):
818+
"""Get the parameters to be used for routing.
819+
820+
This is a method instead of a snippet in ``fit`` since it's used twice,
821+
here in ``fit``, and in the successive-halving ``fit`` in ``_search_successive_halving.py``.
822+
"""
823+
if _routing_enabled():
824+
routed_params = process_routing(self, "fit", **params)
825+
else:
826+
params = params.copy()
827+
groups = params.pop("groups", None)
828+
routed_params = Bunch(
829+
estimator=Bunch(fit=params),
830+
splitter=Bunch(split={"groups": groups}),
831+
scorer=Bunch(score={}),
832+
)
833+
return routed_params
834+
757835
@_fit_context(
758836
# *SearchCV.estimator is not validated yet
759837
prefer_skip_nested_validation=False
760838
)
761-
def fit(self, X, y=None, *, groups=None, **fit_params):
839+
def fit(self, X, y=None, **params):
762840
"""Run fit with all sets of parameters.
763841
764842
Parameters
@@ -773,13 +851,9 @@ def fit(self, X, y=None, *, groups=None, **fit_params):
773851
Target relative to X for classification or regression;
774852
None for unsupervised learning.
775853
776-
groups : array-like of shape (n_samples,), default=None
777-
Group labels for the samples used while splitting the dataset into
778-
train/test set. Only used in conjunction with a "Group" :term:`cv`
779-
instance (e.g., :class:`~sklearn.model_selection.GroupKFold`).
780-
781-
**fit_params : dict of str -> object
782-
Parameters passed to the `fit` method of the estimator.
854+
**params : dict of str -> object
855+
Parameters passed to the ``fit`` method of the estimator, the scorer,
856+
and the CV splitter.
783857
784858
If a fit parameter is an array-like whose length is equal to
785859
`num_samples` then it will be split across CV groups along with `X`
@@ -792,32 +866,27 @@ def fit(self, X, y=None, *, groups=None, **fit_params):
792866
Instance of fitted estimator.
793867
"""
794868
estimator = self.estimator
795-
refit_metric = "score"
869+
# Here we keep a dict of scorers as is, and only convert to a
870+
# _MultimetricScorer at a later stage. Issue:
871+
# https://github.com/scikit-learn/scikit-learn/issues/27001
872+
scorers, refit_metric = self._get_scorers(convert_multimetric=False)
796873

797-
if callable(self.scoring):
798-
scorers = self.scoring
799-
elif self.scoring is None or isinstance(self.scoring, str):
800-
scorers = check_scoring(self.estimator, self.scoring)
801-
else:
802-
scorers = _check_multimetric_scoring(self.estimator, self.scoring)
803-
self._check_refit_for_multimetric(scorers)
804-
refit_metric = self.refit
874+
X, y = indexable(X, y)
875+
params = _check_method_params(X, params=params)
805876

806-
X, y, groups = indexable(X, y, groups)
807-
fit_params = _check_method_params(X, params=fit_params)
877+
routed_params = self._get_routed_params_for_fit(params)
808878

809879
cv_orig = check_cv(self.cv, y, classifier=is_classifier(estimator))
810-
n_splits = cv_orig.get_n_splits(X, y, groups)
880+
n_splits = cv_orig.get_n_splits(X, y, **routed_params.splitter.split)
811881

812882
base_estimator = clone(self.estimator)
813883

814884
parallel = Parallel(n_jobs=self.n_jobs, pre_dispatch=self.pre_dispatch)
815885

816886
fit_and_score_kwargs = dict(
817887
scorer=scorers,
818-
fit_params=fit_params,
819-
# TODO(SLEP6): pass score params along
820-
score_params=None,
888+
fit_params=routed_params.estimator.fit,
889+
score_params=routed_params.scorer.score,
821890
return_train_score=self.return_train_score,
822891
return_n_test_samples=True,
823892
return_times=True,
@@ -857,7 +926,8 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
857926
**fit_and_score_kwargs,
858927
)
859928
for (cand_idx, parameters), (split_idx, (train, test)) in product(
860-
enumerate(candidate_params), enumerate(cv.split(X, y, groups))
929+
enumerate(candidate_params),
930+
enumerate(cv.split(X, y, **routed_params.splitter.split)),
861931
)
862932
)
863933

@@ -935,9 +1005,9 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
9351005

9361006
refit_start_time = time.time()
9371007
if y is not None:
938-
self.best_estimator_.fit(X, y, **fit_params)
1008+
self.best_estimator_.fit(X, y, **routed_params.estimator.fit)
9391009
else:
940-
self.best_estimator_.fit(X, **fit_params)
1010+
self.best_estimator_.fit(X, **routed_params.estimator.fit)
9411011
refit_end_time = time.time()
9421012
self.refit_time_ = refit_end_time - refit_start_time
9431013

@@ -1057,6 +1127,39 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
10571127

10581128
return results
10591129

1130+
def get_metadata_routing(self):
1131+
"""Get metadata routing of this object.
1132+
1133+
Please check :ref:`User Guide <metadata_routing>` on how the routing
1134+
mechanism works.
1135+
1136+
.. versionadded:: 1.4
1137+
1138+
Returns
1139+
-------
1140+
routing : MetadataRouter
1141+
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
1142+
routing information.
1143+
"""
1144+
router = MetadataRouter(owner=self.__class__.__name__)
1145+
router.add(
1146+
estimator=self.estimator,
1147+
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
1148+
)
1149+
1150+
scorer, _ = self._get_scorers(convert_multimetric=True)
1151+
router.add(
1152+
scorer=scorer,
1153+
method_mapping=MethodMapping()
1154+
.add(caller="score", callee="score")
1155+
.add(caller="fit", callee="score"),
1156+
)
1157+
router.add(
1158+
splitter=self.cv,
1159+
method_mapping=MethodMapping().add(caller="fit", callee="split"),
1160+
)
1161+
return router
1162+
10601163

10611164
class GridSearchCV(BaseSearchCV):
10621165
"""Exhaustive search over specified parameter values for an estimator.

sklearn/model_selection/_search_successive_halving.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,20 @@ def __init__(self, *, base_cv, fraction, subsample_test, random_state):
2727
self.subsample_test = subsample_test
2828
self.random_state = random_state
2929

30-
def split(self, X, y, groups=None):
31-
for train_idx, test_idx in self.base_cv.split(X, y, groups):
30+
def split(self, X, y, **kwargs):
31+
for train_idx, test_idx in self.base_cv.split(X, y, **kwargs):
3232
train_idx = resample(
3333
train_idx,
3434
replace=False,
3535
random_state=self.random_state,
36-
n_samples=int(self.fraction * train_idx.shape[0]),
36+
n_samples=int(self.fraction * len(train_idx)),
3737
)
3838
if self.subsample_test:
3939
test_idx = resample(
4040
test_idx,
4141
replace=False,
4242
random_state=self.random_state,
43-
n_samples=int(self.fraction * test_idx.shape[0]),
43+
n_samples=int(self.fraction * len(test_idx)),
4444
)
4545
yield train_idx, test_idx
4646

@@ -123,7 +123,7 @@ def __init__(
123123
self.min_resources = min_resources
124124
self.aggressive_elimination = aggressive_elimination
125125

126-
def _check_input_parameters(self, X, y, groups):
126+
def _check_input_parameters(self, X, y, split_params):
127127
# We need to enforce that successive calls to cv.split() yield the same
128128
# splits: see https://github.com/scikit-learn/scikit-learn/issues/15149
129129
if not _yields_constant_splits(self._checked_cv_orig):
@@ -154,7 +154,7 @@ def _check_input_parameters(self, X, y, groups):
154154
self.min_resources_ = self.min_resources
155155
if self.min_resources_ in ("smallest", "exhaust"):
156156
if self.resource == "n_samples":
157-
n_splits = self._checked_cv_orig.get_n_splits(X, y, groups)
157+
n_splits = self._checked_cv_orig.get_n_splits(X, y, **split_params)
158158
# please see https://gph.is/1KjihQe for a justification
159159
magic_factor = 2
160160
self.min_resources_ = n_splits * magic_factor
@@ -215,7 +215,7 @@ def _select_best_index(refit, refit_metric, results):
215215
# Halving*SearchCV.estimator is not validated yet
216216
prefer_skip_nested_validation=False
217217
)
218-
def fit(self, X, y=None, groups=None, **fit_params):
218+
def fit(self, X, y=None, **params):
219219
"""Run fit with all sets of parameters.
220220
221221
Parameters
@@ -229,12 +229,7 @@ def fit(self, X, y=None, groups=None, **fit_params):
229229
Target relative to X for classification or regression;
230230
None for unsupervised learning.
231231
232-
groups : array-like of shape (n_samples,), default=None
233-
Group labels for the samples used while splitting the dataset into
234-
train/test set. Only used in conjunction with a "Group" :term:`cv`
235-
instance (e.g., :class:`~sklearn.model_selection.GroupKFold`).
236-
237-
**fit_params : dict of string -> object
232+
**params : dict of string -> object
238233
Parameters passed to the ``fit`` method of the estimator, the scorer, and the CV splitter.
239234
240235
Returns
@@ -246,15 +241,14 @@ def fit(self, X, y=None, groups=None, **fit_params):
246241
self.cv, y, classifier=is_classifier(self.estimator)
247242
)
248243

244+
routed_params = self._get_routed_params_for_fit(params)
249245
self._check_input_parameters(
250-
X=X,
251-
y=y,
252-
groups=groups,
246+
X=X, y=y, split_params=routed_params.splitter.split
253247
)
254248

255249
self._n_samples_orig = _num_samples(X)
256250

257-
super().fit(X, y=y, groups=groups, **fit_params)
251+
super().fit(X, y=y, **params)
258252

259253
# Set best_score_: BaseSearchCV does not set it, as refit is a callable
260254
self.best_score_ = self.cv_results_["mean_test_score"][self.best_index_]

0 commit comments

Comments
 (0)