Skip to content

Commit 215be2e

Browse files
authored
MAINT Clean-up some testing utils (scikit-learn#29528)
1 parent c55d064 commit 215be2e

File tree

5 files changed

+5
-113
lines changed

5 files changed

+5
-113
lines changed

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,7 @@ def test_lasso_cv():
264264
)
265265
# check that they also give a similar MSE
266266
mse_lars = interpolate.interp1d(lars.cv_alphas_, lars.mse_path_.T)
267-
np.testing.assert_approx_equal(
268-
mse_lars(clf.alphas_[5]).mean(), clf.mse_path_[5].mean(), significant=2
269-
)
267+
assert_allclose(mse_lars(clf.alphas_[5]).mean(), clf.mse_path_[5].mean(), rtol=1e-2)
270268

271269
# test set
272270
assert clf.score(X_test, y_test) > 0.99

sklearn/tests/test_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,8 @@ def test_set_params():
326326

327327
# we don't currently catch if the things in pipeline are estimators
328328
# bad_pipeline = Pipeline([("bad", NoEstimator())])
329-
# assert_raises(AttributeError, bad_pipeline.set_params,
330-
# bad__stupid_param=True)
329+
# with pytest.raises(AttributeError):
330+
# bad_pipeline.set_params(bad__stupid_param=True)
331331

332332

333333
def test_set_params_passes_all_parameters():

sklearn/utils/_testing.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@
2020
from functools import wraps
2121
from inspect import signature
2222
from subprocess import STDOUT, CalledProcessError, TimeoutExpired, check_output
23-
from unittest import TestCase
2423

2524
import joblib
2625
import numpy as np
2726
import scipy as sp
2827
from numpy.testing import assert_allclose as np_assert_allclose
2928
from numpy.testing import (
3029
assert_almost_equal,
31-
assert_approx_equal,
3230
assert_array_almost_equal,
3331
assert_array_equal,
3432
assert_array_less,
@@ -51,28 +49,16 @@
5149
)
5250

5351
__all__ = [
54-
"assert_raises",
55-
"assert_raises_regexp",
5652
"assert_array_equal",
5753
"assert_almost_equal",
5854
"assert_array_almost_equal",
5955
"assert_array_less",
60-
"assert_approx_equal",
6156
"assert_allclose",
6257
"assert_run_python_script_without_output",
6358
"SkipTest",
6459
]
6560

66-
_dummy = TestCase("__init__")
67-
assert_raises = _dummy.assertRaises
6861
SkipTest = unittest.case.SkipTest
69-
assert_dict_equal = _dummy.assertDictEqual
70-
71-
assert_raises_regex = _dummy.assertRaisesRegex
72-
# assert_raises_regexp is deprecated in Python 3.4 in favor of
73-
# assert_raises_regex but lets keep the backward compat in scikit-learn with
74-
# the old name for now
75-
assert_raises_regexp = assert_raises_regex
7662

7763

7864
def ignore_warnings(obj=None, category=Warning):
@@ -176,48 +162,6 @@ def __exit__(self, *exc_info):
176162
self.log[:] = []
177163

178164

179-
def assert_raise_message(exceptions, message, function, *args, **kwargs):
180-
"""Helper function to test the message raised in an exception.
181-
182-
Given an exception, a callable to raise the exception, and
183-
a message string, tests that the correct exception is raised and
184-
that the message is a substring of the error thrown. Used to test
185-
that the specific message thrown during an exception is correct.
186-
187-
Parameters
188-
----------
189-
exceptions : exception or tuple of exception
190-
An Exception object.
191-
192-
message : str
193-
The error message or a substring of the error message.
194-
195-
function : callable
196-
Callable object to raise error.
197-
198-
*args : the positional arguments to `function`.
199-
200-
**kwargs : the keyword arguments to `function`.
201-
"""
202-
try:
203-
function(*args, **kwargs)
204-
except exceptions as e:
205-
error_message = str(e)
206-
if message not in error_message:
207-
raise AssertionError(
208-
"Error message does not include the expected"
209-
" string: %r. Observed error message: %r" % (message, error_message)
210-
)
211-
else:
212-
# concatenate exception names
213-
if isinstance(exceptions, tuple):
214-
names = " or ".join(e.__name__ for e in exceptions)
215-
else:
216-
names = exceptions.__name__
217-
218-
raise AssertionError("%s not raised by %s" % (names, function.__name__))
219-
220-
221165
def assert_allclose(
222166
actual, desired, rtol=None, atol=0.0, equal_nan=True, err_msg="", verbose=True
223167
):

sklearn/utils/estimator_checks.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@
7575
assert_array_almost_equal,
7676
assert_array_equal,
7777
assert_array_less,
78-
assert_raise_message,
7978
create_memmap_backed_data,
8079
ignore_warnings,
8180
raises,
@@ -1489,9 +1488,8 @@ def check_fit2d_predict1d(name, estimator_orig):
14891488

14901489
for method in ["predict", "transform", "decision_function", "predict_proba"]:
14911490
if hasattr(estimator, method):
1492-
assert_raise_message(
1493-
ValueError, "Reshape your data", getattr(estimator, method), X[0]
1494-
)
1491+
with raises(ValueError, match="Reshape your data"):
1492+
getattr(estimator, method)(X[0])
14951493

14961494

14971495
def _apply_on_subsets(func, X):

sklearn/utils/tests/test_testing.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
_get_warnings_filters_info_list,
1616
assert_allclose,
1717
assert_allclose_dense_sparse,
18-
assert_raise_message,
19-
assert_raises,
20-
assert_raises_regex,
2118
assert_run_python_script_without_output,
2219
check_docstring_parameters,
2320
create_memmap_backed_data,
@@ -66,51 +63,6 @@ def test_assert_allclose_dense_sparse(csr_container):
6663
assert_allclose_dense_sparse(B, A)
6764

6865

69-
def test_assert_raises_msg():
70-
with assert_raises_regex(AssertionError, "Hello world"):
71-
with assert_raises(ValueError, msg="Hello world"):
72-
pass
73-
74-
75-
def test_assert_raise_message():
76-
def _raise_ValueError(message):
77-
raise ValueError(message)
78-
79-
def _no_raise():
80-
pass
81-
82-
assert_raise_message(ValueError, "test", _raise_ValueError, "test")
83-
84-
assert_raises(
85-
AssertionError,
86-
assert_raise_message,
87-
ValueError,
88-
"something else",
89-
_raise_ValueError,
90-
"test",
91-
)
92-
93-
assert_raises(
94-
ValueError,
95-
assert_raise_message,
96-
TypeError,
97-
"something else",
98-
_raise_ValueError,
99-
"test",
100-
)
101-
102-
assert_raises(AssertionError, assert_raise_message, ValueError, "test", _no_raise)
103-
104-
# multiple exceptions in a tuple
105-
assert_raises(
106-
AssertionError,
107-
assert_raise_message,
108-
(ValueError, AttributeError),
109-
"test",
110-
_no_raise,
111-
)
112-
113-
11466
def test_ignore_warning():
11567
# This check that ignore_warning decorator and context manager are working
11668
# as expected

0 commit comments

Comments
 (0)