Skip to content

Commit afee65a

Browse files
authored
FEAT SLEP006 permutation_test_score to support metadata routing (scikit-learn#29266)
1 parent 20c7bd0 commit afee65a

File tree

4 files changed

+113
-13
lines changed

4 files changed

+113
-13
lines changed

doc/metadata_routing.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ Meta-estimators and functions supporting metadata routing:
301301
- :class:`sklearn.model_selection.HalvingGridSearchCV`
302302
- :class:`sklearn.model_selection.HalvingRandomSearchCV`
303303
- :class:`sklearn.model_selection.RandomizedSearchCV`
304+
- :func:`sklearn.model_selection.permutation_test_score`
304305
- :func:`sklearn.model_selection.cross_validate`
305306
- :func:`sklearn.model_selection.cross_val_score`
306307
- :func:`sklearn.model_selection.cross_val_predict`
@@ -324,4 +325,3 @@ Meta-estimators and tools not supporting metadata routing yet:
324325
- :class:`sklearn.feature_selection.RFE`
325326
- :class:`sklearn.feature_selection.RFECV`
326327
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
327-
- :class:`sklearn.model_selection.permutation_test_score`

doc/whats_new/v1.6.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ more details.
8989
passed to the underlying estimators via their respective methods.
9090
:pr:`28494` by :user:`Adam Li <adam2392>`.
9191

92+
- |Feature| :func:`model_selection.permutation_test_score` now supports metadata routing
93+
for the `fit` method of its estimator and for its underlying CV splitter and scorer.
94+
:pr:`29266` by :user:`Adam Li <adam2392>`.
95+
9296
Dropping support for building with setuptools
9397
---------------------------------------------
9498

sklearn/model_selection/_validation.py

Lines changed: 100 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,6 +1493,7 @@ def _check_is_permutation(indices, n_samples):
14931493
"verbose": ["verbose"],
14941494
"scoring": [StrOptions(set(get_scorer_names())), callable, None],
14951495
"fit_params": [dict, None],
1496+
"params": [dict, None],
14961497
},
14971498
prefer_skip_nested_validation=False, # estimator is not validated yet
14981499
)
@@ -1509,6 +1510,7 @@ def permutation_test_score(
15091510
verbose=0,
15101511
scoring=None,
15111512
fit_params=None,
1513+
params=None,
15121514
):
15131515
"""Evaluate the significance of a cross-validated score with permutations.
15141516
@@ -1548,6 +1550,13 @@ def permutation_test_score(
15481550
cross-validator uses them for grouping the samples while splitting
15491551
the dataset into train/test set.
15501552
1553+
.. versionchanged:: 1.6
1554+
``groups`` can only be passed if metadata routing is not enabled
1555+
via ``sklearn.set_config(enable_metadata_routing=True)``. When routing
1556+
is enabled, pass ``groups`` alongside other metadata via the ``params``
1557+
argument instead. E.g.:
1558+
``permutation_test_score(..., params={'groups': groups})``.
1559+
15511560
cv : int, cross-validation generator or an iterable, default=None
15521561
Determines the cross-validation splitting strategy.
15531562
Possible inputs for cv are:
@@ -1594,7 +1603,24 @@ def permutation_test_score(
15941603
fit_params : dict, default=None
15951604
Parameters to pass to the fit method of the estimator.
15961605
1597-
.. versionadded:: 0.24
1606+
.. deprecated:: 1.6
1607+
This parameter is deprecated and will be removed in version 1.8. Use
1608+
``params`` instead.
1609+
1610+
params : dict, default=None
1611+
Parameters to pass to the `fit` method of the estimator, the scorer
1612+
and the cv splitter.
1613+
1614+
- If `enable_metadata_routing=False` (default):
1615+
Parameters directly passed to the `fit` method of the estimator.
1616+
1617+
- If `enable_metadata_routing=True`:
1618+
Parameters safely routed to the `fit` method of the estimator,
1619+
`cv` object and `scorer`.
1620+
See :ref:`Metadata Routing User Guide <metadata_routing>` for more
1621+
details.
1622+
1623+
.. versionadded:: 1.6
15981624
15991625
Returns
16001626
-------
@@ -1643,26 +1669,86 @@ def permutation_test_score(
16431669
>>> print(f"P-value: {pvalue:.3f}")
16441670
P-value: 0.010
16451671
"""
1672+
params = _check_params_groups_deprecation(fit_params, params, groups, "1.8")
1673+
16461674
X, y, groups = indexable(X, y, groups)
16471675

16481676
cv = check_cv(cv, y, classifier=is_classifier(estimator))
16491677
scorer = check_scoring(estimator, scoring=scoring)
16501678
random_state = check_random_state(random_state)
16511679

1680+
if _routing_enabled():
1681+
router = (
1682+
MetadataRouter(owner="permutation_test_score")
1683+
.add(
1684+
estimator=estimator,
1685+
# TODO(SLEP6): also pass metadata to the predict method for
1686+
# scoring?
1687+
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
1688+
)
1689+
.add(
1690+
splitter=cv,
1691+
method_mapping=MethodMapping().add(caller="fit", callee="split"),
1692+
)
1693+
.add(
1694+
scorer=scorer,
1695+
method_mapping=MethodMapping().add(caller="fit", callee="score"),
1696+
)
1697+
)
1698+
1699+
try:
1700+
routed_params = process_routing(router, "fit", **params)
1701+
except UnsetMetadataPassedError as e:
1702+
# The default exception would mention `fit` since in the above
1703+
# `process_routing` code, we pass `fit` as the caller. However,
1704+
# the user is not calling `fit` directly, so we change the message
1705+
# to make it more suitable for this case.
1706+
unrequested_params = sorted(e.unrequested_params)
1707+
raise UnsetMetadataPassedError(
1708+
message=(
1709+
f"{unrequested_params} are passed to `permutation_test_score`"
1710+
" but are not explicitly set as requested or not requested"
1711+
" for permutation_test_score's"
1712+
f" estimator: {estimator.__class__.__name__}. Call"
1713+
" `.set_fit_request({{metadata}}=True)` on the estimator for"
1714+
f" each metadata in {unrequested_params} that you"
1715+
" want to use and `metadata=False` for not using it. See the"
1716+
" Metadata Routing User guide"
1717+
" <https://scikit-learn.org/stable/metadata_routing.html> for more"
1718+
" information."
1719+
),
1720+
unrequested_params=e.unrequested_params,
1721+
routed_params=e.routed_params,
1722+
)
1723+
1724+
else:
1725+
routed_params = Bunch()
1726+
routed_params.estimator = Bunch(fit=params)
1727+
routed_params.splitter = Bunch(split={"groups": groups})
1728+
routed_params.scorer = Bunch(score={})
1729+
16521730
# We clone the estimator to make sure that all the folds are
16531731
# independent, and that it is pickle-able.
16541732
score = _permutation_test_score(
1655-
clone(estimator), X, y, groups, cv, scorer, fit_params=fit_params
1733+
clone(estimator),
1734+
X,
1735+
y,
1736+
cv,
1737+
scorer,
1738+
split_params=routed_params.splitter.split,
1739+
fit_params=routed_params.estimator.fit,
1740+
score_params=routed_params.scorer.score,
16561741
)
16571742
permutation_scores = Parallel(n_jobs=n_jobs, verbose=verbose)(
16581743
delayed(_permutation_test_score)(
16591744
clone(estimator),
16601745
X,
16611746
_shuffle(y, groups, random_state),
1662-
groups,
16631747
cv,
16641748
scorer,
1665-
fit_params=fit_params,
1749+
split_params=routed_params.splitter.split,
1750+
fit_params=routed_params.estimator.fit,
1751+
score_params=routed_params.scorer.score,
16661752
)
16671753
for _ in range(n_permutations)
16681754
)
@@ -1671,17 +1757,22 @@ def permutation_test_score(
16711757
return score, permutation_scores, pvalue
16721758

16731759

1674-
def _permutation_test_score(estimator, X, y, groups, cv, scorer, fit_params):
1760+
def _permutation_test_score(
1761+
estimator, X, y, cv, scorer, split_params, fit_params, score_params
1762+
):
16751763
"""Auxiliary function for permutation_test_score"""
16761764
# Adjust length of sample weights
16771765
fit_params = fit_params if fit_params is not None else {}
1766+
score_params = score_params if score_params is not None else {}
1767+
16781768
avg_score = []
1679-
for train, test in cv.split(X, y, groups):
1769+
for train, test in cv.split(X, y, **split_params):
16801770
X_train, y_train = _safe_split(estimator, X, y, train)
16811771
X_test, y_test = _safe_split(estimator, X, y, test, train)
1682-
fit_params = _check_method_params(X, params=fit_params, indices=train)
1683-
estimator.fit(X_train, y_train, **fit_params)
1684-
avg_score.append(scorer(estimator, X_test, y_test))
1772+
fit_params_train = _check_method_params(X, params=fit_params, indices=train)
1773+
score_params_test = _check_method_params(X, params=score_params, indices=test)
1774+
estimator.fit(X_train, y_train, **fit_params_train)
1775+
avg_score.append(scorer(estimator, X_test, y_test, **score_params_test))
16851776
return np.mean(avg_score)
16861777

16871778

sklearn/model_selection/tests/test_validation.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,7 @@ def test_permutation_test_score_allow_nans():
862862
permutation_test_score(p, X, y)
863863

864864

865-
def test_permutation_test_score_fit_params():
865+
def test_permutation_test_score_params():
866866
X = np.arange(100).reshape(10, 10)
867867
y = np.array([0] * 5 + [1] * 5)
868868
clf = CheckingClassifier(expected_sample_weight=True)
@@ -873,8 +873,8 @@ def test_permutation_test_score_fit_params():
873873

874874
err_msg = r"sample_weight.shape == \(1,\), expected \(8,\)!"
875875
with pytest.raises(ValueError, match=err_msg):
876-
permutation_test_score(clf, X, y, fit_params={"sample_weight": np.ones(1)})
877-
permutation_test_score(clf, X, y, fit_params={"sample_weight": np.ones(10)})
876+
permutation_test_score(clf, X, y, params={"sample_weight": np.ones(1)})
877+
permutation_test_score(clf, X, y, params={"sample_weight": np.ones(10)})
878878

879879

880880
def test_cross_val_score_allow_nans():
@@ -2495,6 +2495,7 @@ def test_cross_validate_return_indices(global_random_seed):
24952495
(cross_val_score, {}),
24962496
(cross_val_predict, {}),
24972497
(learning_curve, {}),
2498+
(permutation_test_score, {}),
24982499
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
24992500
],
25002501
)
@@ -2526,6 +2527,7 @@ def test_fit_param_deprecation(func, extra_args):
25262527
(cross_val_score, {}),
25272528
(cross_val_predict, {}),
25282529
(learning_curve, {}),
2530+
(permutation_test_score, {}),
25292531
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
25302532
],
25312533
)
@@ -2551,6 +2553,7 @@ def test_groups_with_routing_validation(func, extra_args):
25512553
(cross_val_score, {}),
25522554
(cross_val_predict, {}),
25532555
(learning_curve, {}),
2556+
(permutation_test_score, {}),
25542557
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
25552558
],
25562559
)
@@ -2576,6 +2579,7 @@ def test_passed_unrequested_metadata(func, extra_args):
25762579
(cross_val_score, {}),
25772580
(cross_val_predict, {}),
25782581
(learning_curve, {}),
2582+
(permutation_test_score, {}),
25792583
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
25802584
],
25812585
)
@@ -2609,6 +2613,7 @@ def test_validation_functions_routing(func, extra_args):
26092613
cross_val_score: dict(scoring=scorer),
26102614
learning_curve: dict(scoring=scorer),
26112615
validation_curve: dict(scoring=scorer),
2616+
permutation_test_score: dict(scoring=scorer),
26122617
cross_val_predict: dict(),
26132618
}
26142619

0 commit comments

Comments
 (0)