Skip to content

Commit d9212de

Browse files
authored
MAINT Param validation: better message when common test fails to raise (scikit-learn#26702)
1 parent 9ab298a commit d9212de

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

sklearn/tests/test_public_functions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,16 @@ def _check_function_param_validation(
9292
rf"The '{param_name}' parameter of {func_name} must be .* Got .* instead."
9393
)
9494

95+
err_msg = (
96+
f"{func_name} does not raise an informative error message when the "
97+
f"parameter {param_name} does not have a valid type. If any Python type "
98+
"is valid, the constraint should be 'no_validation'."
99+
)
100+
95101
# First, check that the error is raised if param doesn't match any valid type.
96102
with pytest.raises(InvalidParameterError, match=match):
97103
func(**{**valid_required_params, param_name: param_with_bad_type})
104+
pytest.fail(err_msg)
98105

99106
# Then, for constraints that are more than a type constraint, check that the
100107
# error is raised if param does match a valid type but does not match any valid
@@ -107,8 +114,19 @@ def _check_function_param_validation(
107114
except NotImplementedError:
108115
continue
109116

117+
err_msg = (
118+
f"{func_name} does not raise an informative error message when the "
119+
f"parameter {param_name} does not have a valid value.\n"
120+
"Constraints should be disjoint. For instance "
121+
"[StrOptions({'a_string'}), str] is not a acceptable set of "
122+
"constraint because generating an invalid string for the first "
123+
"constraint will always produce a valid string for the second "
124+
"constraint."
125+
)
126+
110127
with pytest.raises(InvalidParameterError, match=match):
111128
func(**{**valid_required_params, param_name: bad_value})
129+
pytest.fail(err_msg)
112130

113131

114132
PARAM_VALIDATION_FUNCTION_LIST = [

sklearn/utils/estimator_checks.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4287,6 +4287,12 @@ def check_param_validation(name, estimator_orig):
42874287
# the method is not accessible with the current set of parameters
42884288
continue
42894289

4290+
err_msg = (
4291+
f"{name} does not raise an informative error message when the parameter"
4292+
f" {param_name} does not have a valid type. If any Python type is"
4293+
" valid, the constraint should be 'no_validation'."
4294+
)
4295+
42904296
with raises(InvalidParameterError, match=match, err_msg=err_msg):
42914297
if any(
42924298
isinstance(X_type, str) and X_type.endswith("labels")
@@ -4315,6 +4321,16 @@ def check_param_validation(name, estimator_orig):
43154321
# the method is not accessible with the current set of parameters
43164322
continue
43174323

4324+
err_msg = (
4325+
f"{name} does not raise an informative error message when the "
4326+
f"parameter {param_name} does not have a valid value.\n"
4327+
"Constraints should be disjoint. For instance "
4328+
"[StrOptions({'a_string'}), str] is not a acceptable set of "
4329+
"constraint because generating an invalid string for the first "
4330+
"constraint will always produce a valid string for the second "
4331+
"constraint."
4332+
)
4333+
43184334
with raises(InvalidParameterError, match=match, err_msg=err_msg):
43194335
if any(
43204336
X_type.endswith("labels")

0 commit comments

Comments
 (0)