Skip to content

Commit 754109c

Browse files
tkamishimaTomDLT
authored andcommitted
[MRG+1] enable to use get_n_splits of LeaveOneGroupOut and LeavePGroupsOut with dummy parameters (scikit-learn#8794)
* remove needless argument checking * add parameter checking as in LeavePGroupsOut * add examples with dummy inputs * add unittest for a get_n_splits method in LeaveOneGroupOut and LeavePGroupsOut classes * X and y can be ommited in a get_n_splits function. * fix error messages * update examples * fix test for an error message * Revert "fix test for an error message" This reverts commit 68b9842. * fix test for an error message * fix error messages * remove tailing white spaces * add periods to messages * test for ValueError’s of get_n_splits methods of LeaveOneOut / LeavePOut classes * fix documents: * parameter name: group -> groups * modfy white space
1 parent 6413feb commit 754109c

File tree

4 files changed

+53
-22
lines changed

4 files changed

+53
-22
lines changed

sklearn/model_selection/_split.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def get_n_splits(self, X, y=None, groups=None):
188188
Returns the number of splitting iterations in the cross-validator.
189189
"""
190190
if X is None:
191-
raise ValueError("The X parameter should not be None")
191+
raise ValueError("The 'X' parameter should not be None.")
192192
return _num_samples(X)
193193

194194

@@ -259,7 +259,7 @@ def get_n_splits(self, X, y=None, groups=None):
259259
Always ignored, exists for compatibility.
260260
"""
261261
if X is None:
262-
raise ValueError("The X parameter should not be None")
262+
raise ValueError("The 'X' parameter should not be None.")
263263
return int(comb(_num_samples(X), self.p, exact=True))
264264

265265

@@ -477,7 +477,7 @@ def __init__(self, n_splits=3):
477477

478478
def _iter_test_indices(self, X, y, groups):
479479
if groups is None:
480-
raise ValueError("The groups parameter should not be None")
480+
raise ValueError("The 'groups' parameter should not be None.")
481481
groups = check_array(groups, ensure_2d=False, dtype=None)
482482

483483
unique_groups, groups = np.unique(groups, return_inverse=True)
@@ -765,6 +765,8 @@ class LeaveOneGroupOut(BaseCrossValidator):
765765
>>> logo = LeaveOneGroupOut()
766766
>>> logo.get_n_splits(X, y, groups)
767767
2
768+
>>> logo.get_n_splits(groups=groups) # 'groups' is always required
769+
2
768770
>>> print(logo)
769771
LeaveOneGroupOut()
770772
>>> for train_index, test_index in logo.split(X, y, groups):
@@ -785,7 +787,7 @@ class LeaveOneGroupOut(BaseCrossValidator):
785787

786788
def _iter_test_masks(self, X, y, groups):
787789
if groups is None:
788-
raise ValueError("The groups parameter should not be None")
790+
raise ValueError("The 'groups' parameter should not be None.")
789791
# We make a copy of groups to avoid side-effects during iteration
790792
groups = check_array(groups, copy=True, ensure_2d=False, dtype=None)
791793
unique_groups = np.unique(groups)
@@ -796,28 +798,31 @@ def _iter_test_masks(self, X, y, groups):
796798
for i in unique_groups:
797799
yield groups == i
798800

799-
def get_n_splits(self, X, y, groups):
801+
def get_n_splits(self, X=None, y=None, groups=None):
800802
"""Returns the number of splitting iterations in the cross-validator
801803
802804
Parameters
803805
----------
804-
X : object
806+
X : object, optional
805807
Always ignored, exists for compatibility.
806808
807-
y : object
809+
y : object, optional
808810
Always ignored, exists for compatibility.
809811
810812
groups : array-like, with shape (n_samples,), optional
811813
Group labels for the samples used while splitting the dataset into
812-
train/test set.
814+
train/test set. This 'groups' parameter must always be specified to
815+
calculate the number of splits, though the other parameters can be
816+
omitted.
813817
814818
Returns
815819
-------
816820
n_splits : int
817821
Returns the number of splitting iterations in the cross-validator.
818822
"""
819823
if groups is None:
820-
raise ValueError("The groups parameter should not be None")
824+
raise ValueError("The 'groups' parameter should not be None.")
825+
groups = check_array(groups, ensure_2d=False, dtype=None)
821826
return len(np.unique(groups))
822827

823828

@@ -852,6 +857,8 @@ class LeavePGroupsOut(BaseCrossValidator):
852857
>>> lpgo = LeavePGroupsOut(n_groups=2)
853858
>>> lpgo.get_n_splits(X, y, groups)
854859
3
860+
>>> lpgo.get_n_splits(groups=groups) # 'groups' is always required
861+
3
855862
>>> print(lpgo)
856863
LeavePGroupsOut(n_groups=2)
857864
>>> for train_index, test_index in lpgo.split(X, y, groups):
@@ -879,7 +886,7 @@ def __init__(self, n_groups):
879886

880887
def _iter_test_masks(self, X, y, groups):
881888
if groups is None:
882-
raise ValueError("The groups parameter should not be None")
889+
raise ValueError("The 'groups' parameter should not be None.")
883890
groups = check_array(groups, copy=True, ensure_2d=False, dtype=None)
884891
unique_groups = np.unique(groups)
885892
if self.n_groups >= len(unique_groups):
@@ -895,32 +902,31 @@ def _iter_test_masks(self, X, y, groups):
895902
test_index[groups == l] = True
896903
yield test_index
897904

898-
def get_n_splits(self, X, y, groups):
905+
def get_n_splits(self, X=None, y=None, groups=None):
899906
"""Returns the number of splitting iterations in the cross-validator
900907
901908
Parameters
902909
----------
903-
X : object
910+
X : object, optional
904911
Always ignored, exists for compatibility.
905-
``np.zeros(n_samples)`` may be used as a placeholder.
906912
907-
y : object
913+
y : object, optional
908914
Always ignored, exists for compatibility.
909-
``np.zeros(n_samples)`` may be used as a placeholder.
910915
911916
groups : array-like, with shape (n_samples,), optional
912917
Group labels for the samples used while splitting the dataset into
913-
train/test set.
918+
train/test set. This 'groups' parameter must always be specified to
919+
calculate the number of splits, though the other parameters can be
920+
omitted.
914921
915922
Returns
916923
-------
917924
n_splits : int
918925
Returns the number of splitting iterations in the cross-validator.
919926
"""
920927
if groups is None:
921-
raise ValueError("The groups parameter should not be None")
928+
raise ValueError("The 'groups' parameter should not be None.")
922929
groups = check_array(groups, ensure_2d=False, dtype=None)
923-
X, y, groups = indexable(X, y, groups)
924930
return int(comb(len(np.unique(groups)), self.n_groups, exact=True))
925931

926932

@@ -1318,7 +1324,7 @@ def __init__(self, n_splits=5, test_size=0.2, train_size=None,
13181324

13191325
def _iter_indices(self, X, y, groups):
13201326
if groups is None:
1321-
raise ValueError("The groups parameter should not be None")
1327+
raise ValueError("The 'groups' parameter should not be None.")
13221328
groups = check_array(groups, ensure_2d=False, dtype=None)
13231329
classes, group_indices = np.unique(groups, return_inverse=True)
13241330
for group_train, group_test in super(

sklearn/model_selection/tests/test_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def test_grid_search_groups():
317317
for cv in group_cvs:
318318
gs = GridSearchCV(clf, grid, cv=cv)
319319
assert_raise_message(ValueError,
320-
"The groups parameter should not be None",
320+
"The 'groups' parameter should not be None.",
321321
gs.fit, X, y)
322322
gs.fit(X, y, groups=groups)
323323

sklearn/model_selection/tests/test_split.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,13 @@ def test_cross_validator_with_default_params():
189189
# Test if the repr works without any errors
190190
assert_equal(cv_repr, repr(cv))
191191

192+
# ValueError for get_n_splits methods
193+
msg = "The 'X' parameter should not be None."
194+
assert_raise_message(ValueError, msg,
195+
loo.get_n_splits, None, y, groups)
196+
assert_raise_message(ValueError, msg,
197+
lpo.get_n_splits, None, y, groups)
198+
192199

193200
def check_valid_split(train, test, n_samples=None):
194201
# Use python sets to get more informative assertion failure messages
@@ -757,6 +764,24 @@ def test_leave_one_p_group_out():
757764
# The number of groups in test must be equal to p_groups_out
758765
assert_true(np.unique(groups_arr[test]).shape[0], p_groups_out)
759766

767+
# check get_n_splits() with dummy parameters
768+
assert_equal(logo.get_n_splits(None, None, ['a', 'b', 'c', 'b', 'c']), 3)
769+
assert_equal(logo.get_n_splits(groups=[1.0, 1.1, 1.0, 1.2]), 3)
770+
assert_equal(lpgo_2.get_n_splits(None, None, np.arange(4)), 6)
771+
assert_equal(lpgo_1.get_n_splits(groups=np.arange(4)), 4)
772+
773+
# raise ValueError if a `groups` parameter is illegal
774+
with assert_raises(ValueError):
775+
logo.get_n_splits(None, None, [0.0, np.nan, 0.0])
776+
with assert_raises(ValueError):
777+
lpgo_2.get_n_splits(None, None, [0.0, np.inf, 0.0])
778+
779+
msg = "The 'groups' parameter should not be None."
780+
assert_raise_message(ValueError, msg,
781+
logo.get_n_splits, None, None, None)
782+
assert_raise_message(ValueError, msg,
783+
lpgo_1.get_n_splits, None, None, None)
784+
760785

761786
def test_leave_group_out_changing_groups():
762787
# Check that LeaveOneGroupOut and LeavePGroupsOut work normally if

sklearn/model_selection/tests/test_validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,10 @@ def test_cross_val_score_predict_groups():
259259
GroupShuffleSplit()]
260260
for cv in group_cvs:
261261
assert_raise_message(ValueError,
262-
"The groups parameter should not be None",
262+
"The 'groups' parameter should not be None.",
263263
cross_val_score, estimator=clf, X=X, y=y, cv=cv)
264264
assert_raise_message(ValueError,
265-
"The groups parameter should not be None",
265+
"The 'groups' parameter should not be None.",
266266
cross_val_predict, estimator=clf, X=X, y=y, cv=cv)
267267

268268

0 commit comments

Comments
 (0)