@@ -1493,6 +1493,7 @@ def _check_is_permutation(indices, n_samples):
1493
1493
"verbose" : ["verbose" ],
1494
1494
"scoring" : [StrOptions (set (get_scorer_names ())), callable , None ],
1495
1495
"fit_params" : [dict , None ],
1496
+ "params" : [dict , None ],
1496
1497
},
1497
1498
prefer_skip_nested_validation = False , # estimator is not validated yet
1498
1499
)
@@ -1509,6 +1510,7 @@ def permutation_test_score(
1509
1510
verbose = 0 ,
1510
1511
scoring = None ,
1511
1512
fit_params = None ,
1513
+ params = None ,
1512
1514
):
1513
1515
"""Evaluate the significance of a cross-validated score with permutations.
1514
1516
@@ -1548,6 +1550,13 @@ def permutation_test_score(
1548
1550
cross-validator uses them for grouping the samples while splitting
1549
1551
the dataset into train/test set.
1550
1552
1553
+ .. versionchanged:: 1.6
1554
+ ``groups`` can only be passed if metadata routing is not enabled
1555
+ via ``sklearn.set_config(enable_metadata_routing=True)``. When routing
1556
+ is enabled, pass ``groups`` alongside other metadata via the ``params``
1557
+ argument instead. E.g.:
1558
+ ``permutation_test_score(..., params={'groups': groups})``.
1559
+
1551
1560
cv : int, cross-validation generator or an iterable, default=None
1552
1561
Determines the cross-validation splitting strategy.
1553
1562
Possible inputs for cv are:
@@ -1594,7 +1603,24 @@ def permutation_test_score(
1594
1603
fit_params : dict, default=None
1595
1604
Parameters to pass to the fit method of the estimator.
1596
1605
1597
- .. versionadded:: 0.24
1606
+ .. deprecated:: 1.6
1607
+ This parameter is deprecated and will be removed in version 1.6. Use
1608
+ ``params`` instead.
1609
+
1610
+ params : dict, default=None
1611
+ Parameters to pass to the `fit` method of the estimator, the scorer
1612
+ and the cv splitter.
1613
+
1614
+ - If `enable_metadata_routing=False` (default):
1615
+ Parameters directly passed to the `fit` method of the estimator.
1616
+
1617
+ - If `enable_metadata_routing=True`:
1618
+ Parameters safely routed to the `fit` method of the estimator,
1619
+ `cv` object and `scorer`.
1620
+ See :ref:`Metadata Routing User Guide <metadata_routing>` for more
1621
+ details.
1622
+
1623
+ .. versionadded:: 1.6
1598
1624
1599
1625
Returns
1600
1626
-------
@@ -1643,26 +1669,86 @@ def permutation_test_score(
1643
1669
>>> print(f"P-value: {pvalue:.3f}")
1644
1670
P-value: 0.010
1645
1671
"""
1672
+ params = _check_params_groups_deprecation (fit_params , params , groups , "1.8" )
1673
+
1646
1674
X , y , groups = indexable (X , y , groups )
1647
1675
1648
1676
cv = check_cv (cv , y , classifier = is_classifier (estimator ))
1649
1677
scorer = check_scoring (estimator , scoring = scoring )
1650
1678
random_state = check_random_state (random_state )
1651
1679
1680
+ if _routing_enabled ():
1681
+ router = (
1682
+ MetadataRouter (owner = "permutation_test_score" )
1683
+ .add (
1684
+ estimator = estimator ,
1685
+ # TODO(SLEP6): also pass metadata to the predict method for
1686
+ # scoring?
1687
+ method_mapping = MethodMapping ().add (caller = "fit" , callee = "fit" ),
1688
+ )
1689
+ .add (
1690
+ splitter = cv ,
1691
+ method_mapping = MethodMapping ().add (caller = "fit" , callee = "split" ),
1692
+ )
1693
+ .add (
1694
+ scorer = scorer ,
1695
+ method_mapping = MethodMapping ().add (caller = "fit" , callee = "score" ),
1696
+ )
1697
+ )
1698
+
1699
+ try :
1700
+ routed_params = process_routing (router , "fit" , ** params )
1701
+ except UnsetMetadataPassedError as e :
1702
+ # The default exception would mention `fit` since in the above
1703
+ # `process_routing` code, we pass `fit` as the caller. However,
1704
+ # the user is not calling `fit` directly, so we change the message
1705
+ # to make it more suitable for this case.
1706
+ unrequested_params = sorted (e .unrequested_params )
1707
+ raise UnsetMetadataPassedError (
1708
+ message = (
1709
+ f"{ unrequested_params } are passed to `permutation_test_score`"
1710
+ " but are not explicitly set as requested or not requested"
1711
+ " for permutation_test_score's"
1712
+ f" estimator: { estimator .__class__ .__name__ } . Call"
1713
+ " `.set_fit_request({{metadata}}=True)` on the estimator for"
1714
+ f" each metadata in { unrequested_params } that you"
1715
+ " want to use and `metadata=False` for not using it. See the"
1716
+ " Metadata Routing User guide"
1717
+ " <https://scikit-learn.org/stable/metadata_routing.html> for more"
1718
+ " information."
1719
+ ),
1720
+ unrequested_params = e .unrequested_params ,
1721
+ routed_params = e .routed_params ,
1722
+ )
1723
+
1724
+ else :
1725
+ routed_params = Bunch ()
1726
+ routed_params .estimator = Bunch (fit = params )
1727
+ routed_params .splitter = Bunch (split = {"groups" : groups })
1728
+ routed_params .scorer = Bunch (score = {})
1729
+
1652
1730
# We clone the estimator to make sure that all the folds are
1653
1731
# independent, and that it is pickle-able.
1654
1732
score = _permutation_test_score (
1655
- clone (estimator ), X , y , groups , cv , scorer , fit_params = fit_params
1733
+ clone (estimator ),
1734
+ X ,
1735
+ y ,
1736
+ cv ,
1737
+ scorer ,
1738
+ split_params = routed_params .splitter .split ,
1739
+ fit_params = routed_params .estimator .fit ,
1740
+ score_params = routed_params .scorer .score ,
1656
1741
)
1657
1742
permutation_scores = Parallel (n_jobs = n_jobs , verbose = verbose )(
1658
1743
delayed (_permutation_test_score )(
1659
1744
clone (estimator ),
1660
1745
X ,
1661
1746
_shuffle (y , groups , random_state ),
1662
- groups ,
1663
1747
cv ,
1664
1748
scorer ,
1665
- fit_params = fit_params ,
1749
+ split_params = routed_params .splitter .split ,
1750
+ fit_params = routed_params .estimator .fit ,
1751
+ score_params = routed_params .scorer .score ,
1666
1752
)
1667
1753
for _ in range (n_permutations )
1668
1754
)
@@ -1671,17 +1757,22 @@ def permutation_test_score(
1671
1757
return score , permutation_scores , pvalue
1672
1758
1673
1759
1674
- def _permutation_test_score (estimator , X , y , groups , cv , scorer , fit_params ):
1760
+ def _permutation_test_score (
1761
+ estimator , X , y , cv , scorer , split_params , fit_params , score_params
1762
+ ):
1675
1763
"""Auxiliary function for permutation_test_score"""
1676
1764
# Adjust length of sample weights
1677
1765
fit_params = fit_params if fit_params is not None else {}
1766
+ score_params = score_params if score_params is not None else {}
1767
+
1678
1768
avg_score = []
1679
- for train , test in cv .split (X , y , groups ):
1769
+ for train , test in cv .split (X , y , ** split_params ):
1680
1770
X_train , y_train = _safe_split (estimator , X , y , train )
1681
1771
X_test , y_test = _safe_split (estimator , X , y , test , train )
1682
- fit_params = _check_method_params (X , params = fit_params , indices = train )
1683
- estimator .fit (X_train , y_train , ** fit_params )
1684
- avg_score .append (scorer (estimator , X_test , y_test ))
1772
+ fit_params_train = _check_method_params (X , params = fit_params , indices = train )
1773
+ score_params_test = _check_method_params (X , params = score_params , indices = test )
1774
+ estimator .fit (X_train , y_train , ** fit_params_train )
1775
+ avg_score .append (scorer (estimator , X_test , y_test , ** score_params_test ))
1685
1776
return np .mean (avg_score )
1686
1777
1687
1778
0 commit comments