Skip to content

Commit 4985e69

Browse files
rprkh, glemaitre, adrinjalali, betatim
authored
ENH check_classification_targets raises a warning when unique classes > 50% of n_samples (scikit-learn#26335)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: adrinjalali <adrin.jalali@gmail.com> Co-authored-by: Tim Head <betatim@gmail.com>
1 parent d042d68 commit 4985e69

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- |Enhancement| :func:`utils.multiclass.type_of_target` raises a warning when the number
2+
of unique classes is greater than 50% of the number of samples. This warning is raised
3+
only if `y` has more than 20 samples.
4+
By :user:`Rahil Parikh <rprkh>`.

sklearn/utils/multiclass.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,16 @@ def _raise_or_return():
413413
# Check multiclass
414414
if issparse(first_row_or_val):
415415
first_row_or_val = first_row_or_val.data
416-
if cached_unique(y).shape[0] > 2 or (y.ndim == 2 and len(first_row_or_val) > 1):
416+
classes = cached_unique(y)
417+
if y.shape[0] > 20 and classes.shape[0] > round(0.5 * y.shape[0]):
418+
# Only raise the warning when we have more than 20 samples.
419+
warnings.warn(
420+
"The number of unique classes is greater than 50% of the number "
421+
"of samples.",
422+
UserWarning,
423+
stacklevel=2,
424+
)
425+
if classes.shape[0] > 2 or (y.ndim == 2 and len(first_row_or_val) > 1):
417426
# [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
418427
return "multiclass" + suffix
419428
else:

sklearn/utils/tests/test_multiclass.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from itertools import product
23

34
import numpy as np
@@ -294,6 +295,25 @@ def test_unique_labels():
294295
assert_array_equal(unique_labels(np.ones((4, 5)), np.ones((5, 5))), np.arange(5))
295296

296297

298+
def test_type_of_target_too_many_unique_classes():
299+
"""Check that we raise a warning when the number of unique classes is greater than
300+
50% of the number of samples.
301+
302+
We also check that no warning is raised when we have 20 samples or fewer.
303+
"""
304+
305+
y = np.arange(25)
306+
msg = r"The number of unique classes is greater than 50% of the number of samples."
307+
with pytest.warns(UserWarning, match=msg):
308+
type_of_target(y)
309+
310+
# 20 samples or fewer, no warning should be raised
311+
y = np.arange(10)
312+
with warnings.catch_warnings():
313+
warnings.simplefilter("error")
314+
type_of_target(y)
315+
316+
297317
def test_unique_labels_non_specific():
298318
# Test unique_labels with a variety of collected examples
299319

0 commit comments

Comments
 (0)