Skip to content

Commit 6092846

Browse files
ENH Improve error message in check_requires_y_none (scikit-learn#31481)
1 parent 082eb5d commit 6092846

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4405,7 +4405,14 @@ def check_requires_y_none(name, estimator_orig):
44054405
estimator.fit(X, None)
44064406
except ValueError as ve:
44074407
if not any(msg in str(ve) for msg in expected_err_msgs):
4408-
raise ve
4408+
raise ValueError(
4409+
"Your estimator raised a ValueError, but with the incorrect or "
4410+
"incomplete error message to be considered a graceful fail. "
4411+
"The expected message in the ValueError should contain one of "
4412+
f"these literal strings:\n{expected_err_msgs}. "
4413+
f"For example, you could have `ValueError('{expected_err_msgs[0]}')`.\n"
4414+
f"This is the error message in your exception:\n{ve}"
4415+
)
44094416

44104417

44114418
@ignore_warnings(category=FutureWarning)

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,27 @@ def fit(self, X, y):
14371437
# no warnings are raised
14381438
assert not [r.message for r in record]
14391439

1440+
# Make an estimator that throws the wrong error to make sure we catch it
1441+
class EstimatorWithWrongError(BaseEstimator):
1442+
def fit(self, X, y):
1443+
try:
1444+
X, y = check_X_y(X, y)
1445+
except ValueError as ve:
1446+
# This assertion is just to make sure we are catching the value error
1447+
# that comes from wrong y (=None) and not some other value error
1448+
assert str(ve) == (
1449+
"estimator requires y to be passed, but the target y is None"
1450+
)
1451+
# Override the error message force fail
1452+
raise ValueError("This is the wrong message that raises error")
1453+
1454+
err_msg = (
1455+
"Your estimator raised a ValueError, but with the incorrect or "
1456+
"incomplete error message to be considered a graceful fail."
1457+
)
1458+
with raises(ValueError, match=err_msg):
1459+
check_requires_y_none("estimator", EstimatorWithWrongError())
1460+
14401461

14411462
def test_non_deterministic_estimator_skip_tests():
14421463
# check estimators with non_deterministic tag set to True

0 commit comments

Comments
 (0)