Skip to content

Commit 61ba9bf

Browse files
authored
MNT SLEP6: metadata_routing.rst can now be tested (scikit-learn#27398)
1 parent 2800ed5 commit 61ba9bf

File tree

4 files changed

+20
-23
lines changed

4 files changed

+20
-23
lines changed

doc/conftest.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,6 @@ def pytest_runtest_setup(item):
145145
setup_preprocessing()
146146
elif fname.endswith("statistical_inference/unsupervised_learning.rst"):
147147
setup_unsupervised_learning()
148-
elif fname.endswith("metadata_routing.rst"):
149-
# TODO: remove this once implemented
150-
# Skip metarouting because is it is not fully implemented yet
151-
raise SkipTest(
152-
"Skipping doctest for metadata_routing.rst because it "
153-
"is not fully implemented yet"
154-
)
155148

156149
rst_files_requiring_matplotlib = [
157150
"modules/partial_dependence.rst",

doc/metadata_routing.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ metadata called ``sample_weight``::
7676
... lr,
7777
... X,
7878
... y,
79-
... props={"sample_weight": my_weights, "groups": my_groups},
79+
... params={"sample_weight": my_weights, "groups": my_groups},
8080
... cv=GroupKFold(),
8181
... scoring=weighted_acc,
8282
... )
8383

8484
Note that in this example, ``my_weights`` is passed to both the scorer and
8585
:class:`~linear_model.LogisticRegressionCV`.
8686

87-
Error handling: if ``props={"sample_weigh": my_weights, ...}`` were passed
87+
Error handling: if ``params={"sample_weigh": my_weights, ...}`` were passed
8888
(note the typo), :func:`~model_selection.cross_validate` would raise an error,
8989
since ``sample_weigh`` was not requested by any of its underlying objects.
9090

@@ -110,7 +110,7 @@ that :func:`~model_selection.cross_validate` does not pass the weights along::
110110
... X,
111111
... y,
112112
... cv=GroupKFold(),
113-
... props={"sample_weight": my_weights, "groups": my_groups},
113+
... params={"sample_weight": my_weights, "groups": my_groups},
114114
... scoring=weighted_acc,
115115
... )
116116

@@ -142,7 +142,7 @@ instance is set and ``sample_weight`` is not routed to it::
142142
... X,
143143
... y,
144144
... cv=GroupKFold(),
145-
... props={"sample_weight": my_weights, "groups": my_groups},
145+
... params={"sample_weight": my_weights, "groups": my_groups},
146146
... scoring=weighted_acc,
147147
... )
148148

@@ -166,7 +166,7 @@ consumers. In this example, we pass ``scoring_weight`` to the scorer, and
166166
... X,
167167
... y,
168168
... cv=GroupKFold(),
169-
... props={
169+
... params={
170170
... "scoring_weight": my_weights,
171171
... "fitting_weight": my_other_weights,
172172
... "groups": my_groups,

sklearn/metrics/_scorer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -547,11 +547,12 @@ def get_metadata_routing(self):
547547
routing information.
548548
"""
549549
# This scorer doesn't do any validation or routing, it only exposes the
550-
# score requests to the parent object. This object behaves as a
551-
# consumer rather than a router.
552-
res = MetadataRequest(owner=self._estimator.__class__.__name__)
553-
res.score = get_routing_for_object(self._estimator).score
554-
return res
550+
# requests of the given estimator. This object behaves as a consumer
551+
# rather than a router. Ideally it only exposes the score requests to
552+
# the parent object; however, that requires computing the routing for
553+
# meta-estimators, which would be more time consuming than simply
554+
# returning the child object's requests.
555+
return get_routing_for_object(self._estimator)
555556

556557

557558
def _check_multimetric_scoring(estimator, scoring):

sklearn/metrics/tests/test_score_objects.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@
5555
from sklearn.neighbors import KNeighborsClassifier
5656
from sklearn.pipeline import make_pipeline
5757
from sklearn.svm import LinearSVC
58-
from sklearn.tests.metadata_routing_common import assert_request_is_empty
58+
from sklearn.tests.metadata_routing_common import (
59+
assert_request_equal,
60+
assert_request_is_empty,
61+
)
5962
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
6063
from sklearn.utils._testing import (
6164
assert_almost_equal,
@@ -1305,11 +1308,11 @@ def test_PassthroughScorer_metadata_request():
13051308
.set_score_request(sample_weight="alias")
13061309
.set_fit_request(sample_weight=True)
13071310
)
1308-
# test that _PassthroughScorer leaves everything other than `score` empty
1309-
assert_request_is_empty(scorer.get_metadata_routing(), exclude="score")
1310-
# test that _PassthroughScorer doesn't behave like a router and leaves
1311-
# the request as is.
1312-
assert scorer.get_metadata_routing().score.requests["sample_weight"] == "alias"
1311+
# Test that _PassthroughScorer doesn't change estimator's routing.
1312+
assert_request_equal(
1313+
scorer.get_metadata_routing(),
1314+
{"fit": {"sample_weight": True}, "score": {"sample_weight": "alias"}},
1315+
)
13131316

13141317

13151318
@pytest.mark.usefixtures("enable_slep006")

0 commit comments

Comments
 (0)