Skip to content

Commit e94c35d

Browse files
ogriseljnothman
authored andcommitted
ENH make validation message easier to understand (scikit-learn#7294)
1 parent a4f632f commit e94c35d

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

sklearn/metrics/tests/test_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,8 +1394,8 @@ def test_log_loss():
13941394
assert_raise_message(ValueError, error_str, log_loss, y_true, y_pred)
13951395

13961396
y_pred = [[0.2, 0.7], [0.6, 0.5], [0.2, 0.3]]
1397-
error_str = ('Found arrays with inconsistent numbers of '
1398-
'samples: [2 3]')
1397+
error_str = ('Found input variables with inconsistent numbers of samples: '
1398+
'[3, 2]')
13991399
assert_raise_message(ValueError, error_str, log_loss, y_true, y_pred)
14001400

14011401
# works when the labels argument is used

sklearn/utils/validation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,11 @@ def check_consistent_length(*arrays):
174174
Objects that will be checked for consistent length.
175175
"""
176176

177-
uniques = np.unique([_num_samples(X) for X in arrays if X is not None])
177+
lengths = [_num_samples(X) for X in arrays if X is not None]
178+
uniques = np.unique(lengths)
178179
if len(uniques) > 1:
179-
raise ValueError("Found arrays with inconsistent numbers of samples: "
180-
"%s" % str(uniques))
180+
raise ValueError("Found input variables with inconsistent numbers of"
181+
" samples: %r" % [int(l) for l in lengths])
181182

182183

183184
def indexable(*iterables):

0 commit comments

Comments
 (0)