Skip to content

Commit 20c7bd0

Browse files
Anurag-Varma, lesteve, lucyleeow
authored
FIX Improve error message when RepeatedStratifiedKFold.split is called without a y argument (scikit-learn#29402)
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
Co-authored-by: Lucy Liu <jliu176@gmail.com>
1 parent 7eb7eff commit 20c7bd0

File tree

3 files changed

+55
-0
lines changed

3 files changed

+55
-0
lines changed

doc/whats_new/v1.6.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ Changelog
210210
estimator without re-fitting it.
211211
:pr:`29067` by :user:`Guillaume Lemaitre <glemaitre>`.
212212

213+
- |Fix| Improve error message when :func:`model_selection.RepeatedStratifiedKFold.split` is called without a `y` argument.
214+
:pr:`29402` by :user:`Anurag Varma <Anurag-Varma>`.
215+
213216
:mod:`sklearn.neighbors`
214217
........................
215218

sklearn/model_selection/_split.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1769,6 +1769,43 @@ def __init__(self, *, n_splits=5, n_repeats=10, random_state=None):
17691769
n_splits=n_splits,
17701770
)
17711771

1772+
def split(self, X, y, groups=None):
1773+
"""Generate indices to split data into training and test set.
1774+
1775+
Parameters
1776+
----------
1777+
X : array-like of shape (n_samples, n_features)
1778+
Training data, where `n_samples` is the number of samples
1779+
and `n_features` is the number of features.
1780+
1781+
Note that providing ``y`` is sufficient to generate the splits and
1782+
hence ``np.zeros(n_samples)`` may be used as a placeholder for
1783+
``X`` instead of actual training data.
1784+
1785+
y : array-like of shape (n_samples,)
1786+
The target variable for supervised learning problems.
1787+
Stratification is done based on the y labels.
1788+
1789+
groups : object
1790+
Always ignored, exists for compatibility.
1791+
1792+
Yields
1793+
------
1794+
train : ndarray
1795+
The training set indices for that split.
1796+
1797+
test : ndarray
1798+
The testing set indices for that split.
1799+
1800+
Notes
1801+
-----
1802+
Randomized CV splitters may return different results for each call of
1803+
split. You can make the results identical by setting `random_state`
1804+
to an integer.
1805+
"""
1806+
y = check_array(y, input_name="y", ensure_2d=False, dtype=None)
1807+
return super().split(X, y, groups=groups)
1808+
17721809

17731810
class BaseShuffleSplit(_MetadataRequester, metaclass=ABCMeta):
17741811
"""Base class for *ShuffleSplit.

sklearn/model_selection/tests/test_split.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@
8686

8787
ALL_SPLITTERS = NO_GROUP_SPLITTERS + GROUP_SPLITTERS # type: ignore
8888

89+
SPLITTERS_REQUIRING_TARGET = [
90+
StratifiedKFold(),
91+
StratifiedShuffleSplit(),
92+
RepeatedStratifiedKFold(),
93+
]
94+
8995
X = np.ones(10)
9096
y = np.arange(10) // 2
9197
test_groups = (
@@ -2054,3 +2060,12 @@ def test_no_group_splitters_warns_with_groups(cv):
20542060

20552061
with pytest.warns(UserWarning, match=msg):
20562062
cv.split(X, y, groups=groups)
2063+
2064+
2065+
@pytest.mark.parametrize(
2066+
"cv", SPLITTERS_REQUIRING_TARGET, ids=[str(cv) for cv in SPLITTERS_REQUIRING_TARGET]
2067+
)
2068+
def test_stratified_splitter_without_y(cv):
2069+
msg = "missing 1 required positional argument: 'y'"
2070+
with pytest.raises(TypeError, match=msg):
2071+
cv.split(X)

0 commit comments

Comments
 (0)