Skip to content

Commit b578371

Browse files
AishwaryaRK authored and jnothman committed
[MRG] Fixes scikit-learn#8736 add get_n_splits for RepeatedKFold and RepeatedStratifiedKFold (scikit-learn#8802)
1 parent ee82c3f commit b578371

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

sklearn/model_selection/_split.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,33 @@ def split(self, X, y=None, groups=None):
997997
for train_index, test_index in cv.split(X, y, groups):
998998
yield train_index, test_index
999999

1000+
def get_n_splits(self, X=None, y=None, groups=None):
    """Return the total number of splitting iterations over all repeats.

    A single instance of the wrapped cross-validator is built and asked
    for its own split count; that count is then multiplied by
    ``self.n_repeats``.

    Parameters
    ----------
    X : object
        Forwarded as-is to the wrapped cross-validator's
        ``get_n_splits``. Kept for API compatibility;
        ``np.zeros(n_samples)`` may be used as a placeholder.

    y : object
        Forwarded as-is to the wrapped cross-validator's
        ``get_n_splits``. Kept for API compatibility;
        ``np.zeros(n_samples)`` may be used as a placeholder.

    groups : array-like, with shape (n_samples,), optional
        Group labels for the samples used while splitting the dataset
        into train/test set. Forwarded as-is to the wrapped
        cross-validator's ``get_n_splits``.

    Returns
    -------
    n_splits : int
        Split count reported by the wrapped cross-validator multiplied
        by ``self.n_repeats``.
    """
    seeded_rng = check_random_state(self.random_state)
    # Instantiate the wrapped splitter once, with shuffling enabled,
    # purely to query how many splits one repetition produces.
    single_pass_cv = self.cv(shuffle=True, random_state=seeded_rng,
                             **self.cvargs)
    splits_per_repeat = single_pass_cv.get_n_splits(X, y, groups)
    return splits_per_repeat * self.n_repeats
1026+
10001027

10011028
class RepeatedKFold(_RepeatedSplits):
10021029
"""Repeated K-Fold cross validator.

sklearn/model_selection/tests/test_split.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,22 @@ def test_repeated_kfold_determinstic_split():
844844
assert_raises(StopIteration, next, splits)
845845

846846

847+
def test_get_n_splits_for_repeated_kfold():
    # get_n_splits() on RepeatedKFold must report folds * repetitions.
    folds, repetitions = 3, 4
    splitter = RepeatedKFold(folds, repetitions)
    assert_equal(folds * repetitions, splitter.get_n_splits())
853+
854+
855+
def test_get_n_splits_for_repeated_stratified_kfold():
    # get_n_splits() on RepeatedStratifiedKFold must report
    # folds * repetitions.
    folds, repetitions = 3, 4
    splitter = RepeatedStratifiedKFold(folds, repetitions)
    assert_equal(folds * repetitions, splitter.get_n_splits())
861+
862+
847863
def test_repeated_stratified_kfold_determinstic_split():
848864
X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
849865
y = [1, 1, 1, 0, 0]

0 commit comments

Comments
 (0)