27
27
from ..base import BaseEstimator , MetaEstimatorMixin , _fit_context , clone , is_classifier
28
28
from ..exceptions import NotFittedError
29
29
from ..metrics import check_scoring
30
- from ..metrics ._scorer import _check_multimetric_scoring , get_scorer_names
31
- from ..utils import check_random_state
30
+ from ..metrics ._scorer import (
31
+ _check_multimetric_scoring ,
32
+ _MultimetricScorer ,
33
+ get_scorer_names ,
34
+ )
35
+ from ..utils import Bunch , check_random_state
32
36
from ..utils ._param_validation import HasMethods , Interval , StrOptions
33
37
from ..utils ._tags import _safe_tags
38
+ from ..utils .metadata_routing import (
39
+ MetadataRouter ,
40
+ MethodMapping ,
41
+ _raise_for_params ,
42
+ _routing_enabled ,
43
+ process_routing ,
44
+ )
34
45
from ..utils .metaestimators import available_if
35
46
from ..utils .parallel import Parallel , delayed
36
47
from ..utils .random import sample_without_replacement
@@ -429,7 +440,7 @@ def _more_tags(self):
429
440
},
430
441
}
431
442
432
- def score (self , X , y = None ):
443
+ def score (self , X , y = None , ** params ):
433
444
"""Return the score on the given data, if the estimator has been refit.
434
445
435
446
This uses the score defined by ``scoring`` where provided, and the
@@ -446,6 +457,14 @@ def score(self, X, y=None):
446
457
Target relative to X for classification or regression;
447
458
None for unsupervised learning.
448
459
460
+ **params : dict
461
+ Parameters to be passed to the underlying scorer(s).
462
+
463
+ ..versionadded:: 1.4
464
+ Only available if `enable_metadata_routing=True`. See
465
+ :ref:`Metadata Routing User Guide <metadata_routing>` for more
466
+ details.
467
+
449
468
Returns
450
469
-------
451
470
score : float
@@ -454,6 +473,14 @@ def score(self, X, y=None):
454
473
"""
455
474
_check_refit (self , "score" )
456
475
check_is_fitted (self )
476
+
477
+ _raise_for_params (params , self , "score" )
478
+
479
+ if _routing_enabled ():
480
+ score_params = process_routing (self , "score" , ** params ).scorer ["score" ]
481
+ else :
482
+ score_params = dict ()
483
+
457
484
if self .scorer_ is None :
458
485
raise ValueError (
459
486
"No score function explicitly defined, "
@@ -465,10 +492,10 @@ def score(self, X, y=None):
465
492
scorer = self .scorer_ [self .refit ]
466
493
else :
467
494
scorer = self .scorer_
468
- return scorer (self .best_estimator_ , X , y )
495
+ return scorer (self .best_estimator_ , X , y , ** score_params )
469
496
470
497
# callable
471
- score = self .scorer_ (self .best_estimator_ , X , y )
498
+ score = self .scorer_ (self .best_estimator_ , X , y , ** score_params )
472
499
if self .multimetric_ :
473
500
score = score [self .refit ]
474
501
return score
@@ -754,11 +781,62 @@ def _select_best_index(refit, refit_metric, results):
754
781
best_index = results [f"rank_test_{ refit_metric } " ].argmin ()
755
782
return best_index
756
783
784
+ def _get_scorers (self , convert_multimetric ):
785
+ """Get the scorer(s) to be used.
786
+
787
+ This is used in ``fit`` and ``get_metadata_routing``.
788
+
789
+ Parameters
790
+ ----------
791
+ convert_multimetric : bool
792
+ Whether to convert a dict of scorers to a _MultimetricScorer. This
793
+ is used in ``get_metadata_routing`` to include the routing info for
794
+ multiple scorers.
795
+
796
+ Returns
797
+ -------
798
+ scorers, refit_metric
799
+ """
800
+ refit_metric = "score"
801
+
802
+ if callable (self .scoring ):
803
+ scorers = self .scoring
804
+ elif self .scoring is None or isinstance (self .scoring , str ):
805
+ scorers = check_scoring (self .estimator , self .scoring )
806
+ else :
807
+ scorers = _check_multimetric_scoring (self .estimator , self .scoring )
808
+ self ._check_refit_for_multimetric (scorers )
809
+ refit_metric = self .refit
810
+ if convert_multimetric and isinstance (scorers , dict ):
811
+ scorers = _MultimetricScorer (
812
+ scorers = scorers , raise_exc = (self .error_score == "raise" )
813
+ )
814
+
815
+ return scorers , refit_metric
816
+
817
def _get_routed_params_for_fit(self, params):
    """Build the per-consumer parameter buckets used during ``fit``.

    Factored into a method (rather than inlined in ``fit``) because
    ``HalvingRandomSearchCV.fit`` needs exactly the same logic.
    """
    if _routing_enabled():
        return process_routing(self, "fit", **params)

    # Legacy (non-routing) path: everything except ``groups`` is passed
    # to the estimator's ``fit``; ``groups`` goes to the splitter; the
    # scorer receives nothing.
    remaining = dict(params)
    group_labels = remaining.pop("groups", None)
    return Bunch(
        estimator=Bunch(fit=remaining),
        splitter=Bunch(split={"groups": group_labels}),
        scorer=Bunch(score={}),
    )
834
+
757
835
@_fit_context (
758
836
# *SearchCV.estimator is not validated yet
759
837
prefer_skip_nested_validation = False
760
838
)
761
- def fit (self , X , y = None , * , groups = None , ** fit_params ):
839
+ def fit (self , X , y = None , ** params ):
762
840
"""Run fit with all sets of parameters.
763
841
764
842
Parameters
@@ -773,13 +851,9 @@ def fit(self, X, y=None, *, groups=None, **fit_params):
773
851
Target relative to X for classification or regression;
774
852
None for unsupervised learning.
775
853
776
- groups : array-like of shape (n_samples,), default=None
777
- Group labels for the samples used while splitting the dataset into
778
- train/test set. Only used in conjunction with a "Group" :term:`cv`
779
- instance (e.g., :class:`~sklearn.model_selection.GroupKFold`).
780
-
781
- **fit_params : dict of str -> object
782
- Parameters passed to the `fit` method of the estimator.
854
+ **params : dict of str -> object
855
+ Parameters passed to the ``fit`` method of the estimator, the scorer,
856
+ and the CV splitter.
783
857
784
858
If a fit parameter is an array-like whose length is equal to
785
859
`num_samples` then it will be split across CV groups along with `X`
@@ -792,32 +866,27 @@ def fit(self, X, y=None, *, groups=None, **fit_params):
792
866
Instance of fitted estimator.
793
867
"""
794
868
estimator = self .estimator
795
- refit_metric = "score"
869
+ # Here we keep a dict of scorers as is, and only convert to a
870
+ # _MultimetricScorer at a later stage. Issue:
871
+ # https://github.com/scikit-learn/scikit-learn/issues/27001
872
+ scorers , refit_metric = self ._get_scorers (convert_multimetric = False )
796
873
797
- if callable (self .scoring ):
798
- scorers = self .scoring
799
- elif self .scoring is None or isinstance (self .scoring , str ):
800
- scorers = check_scoring (self .estimator , self .scoring )
801
- else :
802
- scorers = _check_multimetric_scoring (self .estimator , self .scoring )
803
- self ._check_refit_for_multimetric (scorers )
804
- refit_metric = self .refit
874
+ X , y = indexable (X , y )
875
+ params = _check_method_params (X , params = params )
805
876
806
- X , y , groups = indexable (X , y , groups )
807
- fit_params = _check_method_params (X , params = fit_params )
877
+ routed_params = self ._get_routed_params_for_fit (params )
808
878
809
879
cv_orig = check_cv (self .cv , y , classifier = is_classifier (estimator ))
810
- n_splits = cv_orig .get_n_splits (X , y , groups )
880
+ n_splits = cv_orig .get_n_splits (X , y , ** routed_params . splitter . split )
811
881
812
882
base_estimator = clone (self .estimator )
813
883
814
884
parallel = Parallel (n_jobs = self .n_jobs , pre_dispatch = self .pre_dispatch )
815
885
816
886
fit_and_score_kwargs = dict (
817
887
scorer = scorers ,
818
- fit_params = fit_params ,
819
- # TODO(SLEP6): pass score params along
820
- score_params = None ,
888
+ fit_params = routed_params .estimator .fit ,
889
+ score_params = routed_params .scorer .score ,
821
890
return_train_score = self .return_train_score ,
822
891
return_n_test_samples = True ,
823
892
return_times = True ,
@@ -857,7 +926,8 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
857
926
** fit_and_score_kwargs ,
858
927
)
859
928
for (cand_idx , parameters ), (split_idx , (train , test )) in product (
860
- enumerate (candidate_params ), enumerate (cv .split (X , y , groups ))
929
+ enumerate (candidate_params ),
930
+ enumerate (cv .split (X , y , ** routed_params .splitter .split )),
861
931
)
862
932
)
863
933
@@ -935,9 +1005,9 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
935
1005
936
1006
refit_start_time = time .time ()
937
1007
if y is not None :
938
- self .best_estimator_ .fit (X , y , ** fit_params )
1008
+ self .best_estimator_ .fit (X , y , ** routed_params . estimator . fit )
939
1009
else :
940
- self .best_estimator_ .fit (X , ** fit_params )
1010
+ self .best_estimator_ .fit (X , ** routed_params . estimator . fit )
941
1011
refit_end_time = time .time ()
942
1012
self .refit_time_ = refit_end_time - refit_start_time
943
1013
@@ -1057,6 +1127,39 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
1057
1127
1058
1128
return results
1059
1129
1130
def get_metadata_routing(self):
    """Get metadata routing of this object.

    Please check :ref:`User Guide <metadata_routing>` on how the routing
    mechanism works.

    .. versionadded:: 1.4

    Returns
    -------
    routing : MetadataRouter
        A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
        routing information.
    """
    fit_to_fit = MethodMapping().add(caller="fit", callee="fit")
    score_mapping = (
        MethodMapping()
        .add(caller="score", callee="score")
        .add(caller="fit", callee="score")
    )
    fit_to_split = MethodMapping().add(caller="fit", callee="split")

    # convert_multimetric=True wraps a dict of scorers into a
    # _MultimetricScorer so each sub-scorer's routing is exposed.
    scorer_obj, _ = self._get_scorers(convert_multimetric=True)

    router = MetadataRouter(owner=self.__class__.__name__)
    router.add(estimator=self.estimator, method_mapping=fit_to_fit)
    router.add(scorer=scorer_obj, method_mapping=score_mapping)
    router.add(splitter=self.cv, method_mapping=fit_to_split)
    return router
1162
+
1060
1163
1061
1164
class GridSearchCV (BaseSearchCV ):
1062
1165
"""Exhaustive search over specified parameter values for an estimator.
0 commit comments