Skip to content

Commit 4e67902

Browse files
CLN clean up some repeated code related to SLEP006 (scikit-learn#26836)
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
1 parent 150ff34 commit 4e67902

File tree

8 files changed

+115
-109
lines changed

8 files changed

+115
-109
lines changed

sklearn/linear_model/_logistic.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ..utils.metadata_routing import (
3939
MetadataRouter,
4040
MethodMapping,
41+
_raise_for_params,
4142
_routing_enabled,
4243
process_routing,
4344
)
@@ -1791,11 +1792,7 @@ def fit(self, X, y, sample_weight=None, **params):
17911792
self : object
17921793
Fitted LogisticRegressionCV estimator.
17931794
"""
1794-
if params and not _routing_enabled():
1795-
raise ValueError(
1796-
"params is only supported if enable_metadata_routing=True."
1797-
" See the User Guide for more information."
1798-
)
1795+
_raise_for_params(params, self, "fit")
17991796

18001797
solver = _check_solver(self.solver, self.penalty, self.dual)
18011798

@@ -2146,12 +2143,7 @@ def score(self, X, y, sample_weight=None, **score_params):
21462143
score : float
21472144
Score of self.predict(X) w.r.t. y.
21482145
"""
2149-
if score_params and not _routing_enabled():
2150-
raise ValueError(
2151-
"score_params is only supported if enable_metadata_routing=True."
2152-
" See the User Guide for more information."
2153-
" https://scikit-learn.org/stable/metadata_routing.html"
2154-
)
2146+
_raise_for_params(score_params, self, "score")
21552147

21562148
scoring = self._get_scorer()
21572149
if _routing_enabled():

sklearn/linear_model/tests/test_logistic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2167,7 +2167,7 @@ def test_passing_params_without_enabling_metadata_routing():
21672167
are passed while not supported when `enable_metadata_routing=False`."""
21682168
X, y = make_classification(n_samples=10, random_state=0)
21692169
lr_cv = LogisticRegressionCV()
2170-
msg = "params is only supported if enable_metadata_routing=True"
2170+
msg = "is only supported if enable_metadata_routing=True"
21712171

21722172
with config_context(enable_metadata_routing=False):
21732173
params = {"extra_param": 1.0}

sklearn/metrics/_scorer.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
MetadataRequest,
3636
MetadataRouter,
3737
_MetadataRequester,
38+
_raise_for_params,
3839
_routing_enabled,
3940
get_routing_for_object,
4041
process_routing,
@@ -253,11 +254,7 @@ def __call__(self, estimator, X, y_true, sample_weight=None, **kwargs):
253254
score : float
254255
Score function applied to prediction of estimator on X.
255256
"""
256-
if kwargs and not _routing_enabled():
257-
raise ValueError(
258-
"kwargs is only supported if enable_metadata_routing=True. See"
259-
" the User Guide for more information."
260-
)
257+
_raise_for_params(kwargs, self, None)
261258

262259
_kwargs = copy.deepcopy(kwargs)
263260
if sample_weight is not None:

sklearn/metrics/tests/test_score_objects.py

Lines changed: 67 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,97 +1197,95 @@ def test_scorer_no_op_multiclass_select_proba():
11971197
scorer(lr, X_test, y_test)
11981198

11991199

1200+
@pytest.mark.usefixtures("enable_slep006")
12001201
@pytest.mark.parametrize("name", get_scorer_names(), ids=get_scorer_names())
12011202
def test_scorer_metadata_request(name):
12021203
"""Testing metadata requests for scorers.
12031204
12041205
This test checks many small things in a large test, to reduce the
12051206
boilerplate required for each section.
12061207
"""
1207-
with config_context(enable_metadata_routing=True):
1208-
# Make sure they expose the routing methods.
1209-
scorer = get_scorer(name)
1210-
assert hasattr(scorer, "set_score_request")
1211-
assert hasattr(scorer, "get_metadata_routing")
1212-
1213-
# Check that by default no metadata is requested.
1214-
assert_request_is_empty(scorer.get_metadata_routing())
1215-
1216-
weighted_scorer = scorer.set_score_request(sample_weight=True)
1217-
# set_score_request should mutate the instance, rather than returning a
1218-
# new instance
1219-
assert weighted_scorer is scorer
1220-
1221-
# make sure the scorer doesn't request anything on methods other than
1222-
# `score`, and that the requested value on `score` is correct.
1223-
assert_request_is_empty(weighted_scorer.get_metadata_routing(), exclude="score")
1224-
assert (
1225-
weighted_scorer.get_metadata_routing().score.requests["sample_weight"]
1226-
is True
1227-
)
1208+
# Make sure they expose the routing methods.
1209+
scorer = get_scorer(name)
1210+
assert hasattr(scorer, "set_score_request")
1211+
assert hasattr(scorer, "get_metadata_routing")
1212+
1213+
# Check that by default no metadata is requested.
1214+
assert_request_is_empty(scorer.get_metadata_routing())
1215+
1216+
weighted_scorer = scorer.set_score_request(sample_weight=True)
1217+
# set_score_request should mutate the instance, rather than returning a
1218+
# new instance
1219+
assert weighted_scorer is scorer
1220+
1221+
# make sure the scorer doesn't request anything on methods other than
1222+
# `score`, and that the requested value on `score` is correct.
1223+
assert_request_is_empty(weighted_scorer.get_metadata_routing(), exclude="score")
1224+
assert (
1225+
weighted_scorer.get_metadata_routing().score.requests["sample_weight"] is True
1226+
)
12281227

1229-
# make sure putting the scorer in a router doesn't request anything by
1230-
# default
1231-
router = MetadataRouter(owner="test").add(
1232-
method_mapping="score", scorer=get_scorer(name)
1233-
)
1234-
# make sure `sample_weight` is refused if passed.
1235-
with pytest.raises(TypeError, match="got unexpected argument"):
1236-
router.validate_metadata(params={"sample_weight": 1}, method="score")
1237-
# make sure `sample_weight` is not routed even if passed.
1238-
routed_params = router.route_params(params={"sample_weight": 1}, caller="score")
1239-
assert not routed_params.scorer.score
1240-
1241-
# make sure putting weighted_scorer in a router requests sample_weight
1242-
router = MetadataRouter(owner="test").add(
1243-
scorer=weighted_scorer, method_mapping="score"
1244-
)
1228+
# make sure putting the scorer in a router doesn't request anything by
1229+
# default
1230+
router = MetadataRouter(owner="test").add(
1231+
method_mapping="score", scorer=get_scorer(name)
1232+
)
1233+
# make sure `sample_weight` is refused if passed.
1234+
with pytest.raises(TypeError, match="got unexpected argument"):
12451235
router.validate_metadata(params={"sample_weight": 1}, method="score")
1246-
routed_params = router.route_params(params={"sample_weight": 1}, caller="score")
1247-
assert list(routed_params.scorer.score.keys()) == ["sample_weight"]
1236+
# make sure `sample_weight` is not routed even if passed.
1237+
routed_params = router.route_params(params={"sample_weight": 1}, caller="score")
1238+
assert not routed_params.scorer.score
1239+
1240+
# make sure putting weighted_scorer in a router requests sample_weight
1241+
router = MetadataRouter(owner="test").add(
1242+
scorer=weighted_scorer, method_mapping="score"
1243+
)
1244+
router.validate_metadata(params={"sample_weight": 1}, method="score")
1245+
routed_params = router.route_params(params={"sample_weight": 1}, caller="score")
1246+
assert list(routed_params.scorer.score.keys()) == ["sample_weight"]
12481247

12491248

1249+
@pytest.mark.usefixtures("enable_slep006")
12501250
def test_metadata_kwarg_conflict():
12511251
"""This test makes sure the right warning is raised if the user passes
12521252
some metadata both as a constructor to make_scorer, and during __call__.
12531253
"""
1254-
with config_context(enable_metadata_routing=True):
1255-
X, y = make_classification(
1256-
n_classes=3, n_informative=3, n_samples=20, random_state=0
1257-
)
1258-
lr = LogisticRegression().fit(X, y)
1254+
X, y = make_classification(
1255+
n_classes=3, n_informative=3, n_samples=20, random_state=0
1256+
)
1257+
lr = LogisticRegression().fit(X, y)
12591258

1260-
scorer = make_scorer(
1261-
roc_auc_score,
1262-
needs_proba=True,
1263-
multi_class="ovo",
1264-
labels=lr.classes_,
1265-
)
1266-
with pytest.warns(UserWarning, match="already set as kwargs"):
1267-
scorer.set_score_request(labels=True)
1259+
scorer = make_scorer(
1260+
roc_auc_score,
1261+
needs_proba=True,
1262+
multi_class="ovo",
1263+
labels=lr.classes_,
1264+
)
1265+
with pytest.warns(UserWarning, match="already set as kwargs"):
1266+
scorer.set_score_request(labels=True)
12681267

1269-
with config_context(enable_metadata_routing=True):
1270-
with pytest.warns(UserWarning, match="There is an overlap"):
1271-
scorer(lr, X, y, labels=lr.classes_)
1268+
with pytest.warns(UserWarning, match="There is an overlap"):
1269+
scorer(lr, X, y, labels=lr.classes_)
12721270

12731271

1272+
@pytest.mark.usefixtures("enable_slep006")
12741273
def test_PassthroughScorer_metadata_request():
12751274
"""Test that _PassthroughScorer properly routes metadata.
12761275
12771276
_PassthroughScorer should behave like a consumer, mirroring whatever is the
12781277
underlying score method.
12791278
"""
1280-
with config_context(enable_metadata_routing=True):
1281-
scorer = _PassthroughScorer(
1282-
estimator=LinearSVC()
1283-
.set_score_request(sample_weight="alias")
1284-
.set_fit_request(sample_weight=True)
1285-
)
1286-
# test that _PassthroughScorer leaves everything other than `score` empty
1287-
assert_request_is_empty(scorer.get_metadata_routing(), exclude="score")
1288-
# test that _PassthroughScorer doesn't behave like a router and leaves
1289-
# the request as is.
1290-
assert scorer.get_metadata_routing().score.requests["sample_weight"] == "alias"
1279+
scorer = _PassthroughScorer(
1280+
estimator=LinearSVC()
1281+
.set_score_request(sample_weight="alias")
1282+
.set_fit_request(sample_weight=True)
1283+
)
1284+
# test that _PassthroughScorer leaves everything other than `score` empty
1285+
assert_request_is_empty(scorer.get_metadata_routing(), exclude="score")
1286+
# test that _PassthroughScorer doesn't behave like a router and leaves
1287+
# the request as is.
1288+
assert scorer.get_metadata_routing().score.requests["sample_weight"] == "alias"
12911289

12921290

12931291
def test_multimetric_scoring_metadata_routing():
@@ -1344,5 +1342,7 @@ def score(y_true, y_pred, param=None):
13441342
clf = DecisionTreeClassifier().fit(X, y)
13451343
scorer = make_scorer(score)
13461344
with config_context(enable_metadata_routing=False):
1347-
with pytest.raises(ValueError, match="kwargs is only supported if"):
1345+
with pytest.raises(
1346+
ValueError, match="is only supported if enable_metadata_routing=True"
1347+
):
13481348
scorer(clf, X, y, param="blah")

sklearn/multioutput.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from .utils.metadata_routing import (
3737
MetadataRouter,
3838
MethodMapping,
39+
_raise_for_params,
3940
_routing_enabled,
4041
process_routing,
4142
)
@@ -148,11 +149,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, **partial_fit_para
148149
self : object
149150
Returns a fitted instance.
150151
"""
151-
if partial_fit_params and not _routing_enabled():
152-
raise ValueError(
153-
"partial_fit_params is only supported if enable_metadata_routing=True."
154-
" See the User Guide for more information."
155-
)
152+
_raise_for_params(partial_fit_params, self, "partial_fit")
156153

157154
first_time = not hasattr(self, "estimators_")
158155

@@ -919,11 +916,7 @@ def fit(self, X, Y, **fit_params):
919916
self : object
920917
Class instance.
921918
"""
922-
if fit_params and not _routing_enabled():
923-
raise ValueError(
924-
"fit_params is only supported if enable_metadata_routing=True. "
925-
"See the User Guide for more information."
926-
)
919+
_raise_for_params(fit_params, self, "fit")
927920

928921
super().fit(X, Y, **fit_params)
929922
self.classes_ = [

sklearn/pipeline.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .utils.metadata_routing import (
2828
MetadataRouter,
2929
MethodMapping,
30+
_raise_for_params,
3031
_routing_enabled,
3132
process_routing,
3233
)
@@ -742,11 +743,7 @@ def decision_function(self, X, **params):
742743
y_score : ndarray of shape (n_samples, n_classes)
743744
Result of calling `decision_function` on the final estimator.
744745
"""
745-
if params and not _routing_enabled():
746-
raise ValueError(
747-
"params is only supported if enable_metadata_routing=True."
748-
" See the User Guide for more information."
749-
)
746+
_raise_for_params(params, self, "decision_function")
750747

751748
# not branching here since params is only available if
752749
# enable_metadata_routing=True
@@ -881,11 +878,7 @@ def transform(self, X, **params):
881878
Xt : ndarray of shape (n_samples, n_transformed_features)
882879
Transformed data.
883880
"""
884-
if not _routing_enabled() and params:
885-
raise ValueError(
886-
"params is only supported if enable_metadata_routing=True."
887-
" See the User Guide for more information."
888-
)
881+
_raise_for_params(params, self, "transform")
889882

890883
# not branching here since params is only available if
891884
# enable_metadata_routing=True
@@ -928,11 +921,7 @@ def inverse_transform(self, Xt, **params):
928921
Inverse transformed data, that is, data in the original feature
929922
space.
930923
"""
931-
if not _routing_enabled() and params:
932-
raise ValueError(
933-
"params is only supported if enable_metadata_routing=True. See"
934-
" the User Guide for more information."
935-
)
924+
_raise_for_params(params, self, "inverse_transform")
936925

937926
# we don't have to branch here, since params is only non-empty if
938927
# enable_metadata_routing=True.

sklearn/utils/_metadata_requests.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,40 @@ def _routing_enabled():
127127
return get_config().get("enable_metadata_routing", False)
128128

129129

130+
def _raise_for_params(params, owner, method):
131+
"""Raise an error if metadata routing is not enabled and params are passed.
132+
133+
.. versionadded:: 1.4
134+
135+
Parameters
136+
----------
137+
params : dict
138+
The metadata passed to a method.
139+
140+
owner : object
141+
The object to which the method belongs.
142+
143+
method : str
144+
The name of the method, e.g. "fit".
145+
146+
Raises
147+
------
148+
ValueError
149+
If metadata routing is not enabled and params are passed.
150+
"""
151+
caller = (
152+
f"{owner.__class__.__name__}.{method}" if method else owner.__class__.__name__
153+
)
154+
if not _routing_enabled() and params:
155+
raise ValueError(
156+
f"Passing extra keyword arguments to {caller} is only supported if"
157+
" enable_metadata_routing=True, which you can set using"
158+
" `sklearn.set_config`. See the User Guide"
159+
" <https://scikit-learn.org/stable/metadata_routing.html> for more"
160+
" details."
161+
)
162+
163+
130164
# Request values
131165
# ==============
132166
# Each request value needs to be one of the following values, or an alias.

sklearn/utils/metadata_routing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
from ._metadata_requests import process_routing # noqa
1717
from ._metadata_requests import _MetadataRequester # noqa
1818
from ._metadata_requests import _routing_enabled # noqa
19+
from ._metadata_requests import _raise_for_params # noqa

0 commit comments

Comments
 (0)