Skip to content

Commit 3d5e243

Browse files
authored
MAINT Remove assert_no_warnings from tests (scikit-learn#29525)
1 parent 70a84ea commit 3d5e243

File tree

7 files changed

+177
-166
lines changed

7 files changed

+177
-166
lines changed

sklearn/compose/tests/test_target.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sklearn.linear_model import LinearRegression, OrthogonalMatchingPursuit
1111
from sklearn.pipeline import Pipeline
1212
from sklearn.preprocessing import FunctionTransformer, StandardScaler
13-
from sklearn.utils._testing import assert_allclose, assert_no_warnings
13+
from sklearn.utils._testing import assert_allclose
1414

1515
friedman = datasets.make_friedman1(random_state=0)
1616

@@ -66,17 +66,17 @@ def test_transform_target_regressor_invertible():
6666
)
6767
with pytest.warns(
6868
UserWarning,
69-
match=(
70-
"The provided functions or"
71-
" transformer are not strictly inverse of each other."
72-
),
69+
match=(r"The provided functions.* are not strictly inverse of each other"),
7370
):
7471
regr.fit(X, y)
7572
regr = TransformedTargetRegressor(
7673
regressor=LinearRegression(), func=np.sqrt, inverse_func=np.log
7774
)
7875
regr.set_params(check_inverse=False)
79-
assert_no_warnings(regr.fit, X, y)
76+
77+
with warnings.catch_warnings():
78+
warnings.simplefilter("error", UserWarning)
79+
regr.fit(X, y)
8080

8181

8282
def _check_standard_scaled(y, y_pred):

sklearn/metrics/tests/test_classification.py

Lines changed: 109 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
assert_almost_equal,
4646
assert_array_almost_equal,
4747
assert_array_equal,
48-
assert_no_warnings,
4948
ignore_warnings,
5049
)
5150
from sklearn.utils.extmath import _nanaverage
@@ -266,24 +265,24 @@ def test_precision_recall_f1_score_binary():
266265
# individual scoring function that can be used for grid search: in the
267266
# binary class case the score is the value of the measure for the positive
268267
# class (e.g. label == 1). This is deprecated for average != 'binary'.
269-
for kwargs, my_assert in [
270-
({}, assert_no_warnings),
271-
({"average": "binary"}, assert_no_warnings),
272-
]:
273-
ps = my_assert(precision_score, y_true, y_pred, **kwargs)
274-
assert_array_almost_equal(ps, 0.85, 2)
268+
for kwargs in [{}, {"average": "binary"}]:
269+
with warnings.catch_warnings():
270+
warnings.simplefilter("error")
275271

276-
rs = my_assert(recall_score, y_true, y_pred, **kwargs)
277-
assert_array_almost_equal(rs, 0.68, 2)
272+
ps = precision_score(y_true, y_pred, **kwargs)
273+
assert_array_almost_equal(ps, 0.85, 2)
278274

279-
fs = my_assert(f1_score, y_true, y_pred, **kwargs)
280-
assert_array_almost_equal(fs, 0.76, 2)
275+
rs = recall_score(y_true, y_pred, **kwargs)
276+
assert_array_almost_equal(rs, 0.68, 2)
281277

282-
assert_almost_equal(
283-
my_assert(fbeta_score, y_true, y_pred, beta=2, **kwargs),
284-
(1 + 2**2) * ps * rs / (2**2 * ps + rs),
285-
2,
286-
)
278+
fs = f1_score(y_true, y_pred, **kwargs)
279+
assert_array_almost_equal(fs, 0.76, 2)
280+
281+
assert_almost_equal(
282+
fbeta_score(y_true, y_pred, beta=2, **kwargs),
283+
(1 + 2**2) * ps * rs / (2**2 * ps + rs),
284+
2,
285+
)
287286

288287

289288
@ignore_warnings
@@ -1919,22 +1918,23 @@ def test_precision_recall_f1_no_labels(beta, average, zero_division):
19191918
y_true = np.zeros((20, 3))
19201919
y_pred = np.zeros_like(y_true)
19211920

1922-
p, r, f, s = assert_no_warnings(
1923-
precision_recall_fscore_support,
1924-
y_true,
1925-
y_pred,
1926-
average=average,
1927-
beta=beta,
1928-
zero_division=zero_division,
1929-
)
1930-
fbeta = assert_no_warnings(
1931-
fbeta_score,
1932-
y_true,
1933-
y_pred,
1934-
beta=beta,
1935-
average=average,
1936-
zero_division=zero_division,
1937-
)
1921+
with warnings.catch_warnings():
1922+
warnings.simplefilter("error")
1923+
1924+
p, r, f, s = precision_recall_fscore_support(
1925+
y_true,
1926+
y_pred,
1927+
average=average,
1928+
beta=beta,
1929+
zero_division=zero_division,
1930+
)
1931+
fbeta = fbeta_score(
1932+
y_true,
1933+
y_pred,
1934+
beta=beta,
1935+
average=average,
1936+
zero_division=zero_division,
1937+
)
19381938
assert s is None
19391939

19401940
# if zero_division = nan, check that all metrics are nan and exit
@@ -1984,17 +1984,20 @@ def test_precision_recall_f1_no_labels_average_none(zero_division):
19841984
# |y_i| = [0, 0, 0]
19851985
# |y_hat_i| = [0, 0, 0]
19861986

1987-
p, r, f, s = assert_no_warnings(
1988-
precision_recall_fscore_support,
1989-
y_true,
1990-
y_pred,
1991-
average=None,
1992-
beta=1.0,
1993-
zero_division=zero_division,
1994-
)
1995-
fbeta = assert_no_warnings(
1996-
fbeta_score, y_true, y_pred, beta=1.0, average=None, zero_division=zero_division
1997-
)
1987+
with warnings.catch_warnings():
1988+
warnings.simplefilter("error")
1989+
1990+
p, r, f, s = precision_recall_fscore_support(
1991+
y_true,
1992+
y_pred,
1993+
average=None,
1994+
beta=1.0,
1995+
zero_division=zero_division,
1996+
)
1997+
fbeta = fbeta_score(
1998+
y_true, y_pred, beta=1.0, average=None, zero_division=zero_division
1999+
)
2000+
19982001
zero_division = np.float64(zero_division)
19992002
assert_array_almost_equal(p, [zero_division, zero_division, zero_division], 2)
20002003
assert_array_almost_equal(r, [zero_division, zero_division, zero_division], 2)
@@ -2138,59 +2141,57 @@ def test_prf_warnings():
21382141

21392142
@pytest.mark.parametrize("zero_division", [0, 1, np.nan])
21402143
def test_prf_no_warnings_if_zero_division_set(zero_division):
2141-
# average of per-label scores
2142-
f = precision_recall_fscore_support
2143-
for average in [None, "weighted", "macro"]:
2144-
assert_no_warnings(
2145-
f, [0, 1, 2], [1, 1, 2], average=average, zero_division=zero_division
2146-
)
2144+
with warnings.catch_warnings():
2145+
warnings.simplefilter("error")
21472146

2148-
assert_no_warnings(
2149-
f, [1, 1, 2], [0, 1, 2], average=average, zero_division=zero_division
2150-
)
2147+
# average of per-label scores
2148+
for average in [None, "weighted", "macro"]:
2149+
precision_recall_fscore_support(
2150+
[0, 1, 2], [1, 1, 2], average=average, zero_division=zero_division
2151+
)
21512152

2152-
# average of per-sample scores
2153-
assert_no_warnings(
2154-
f,
2155-
np.array([[1, 0], [1, 0]]),
2156-
np.array([[1, 0], [0, 0]]),
2157-
average="samples",
2158-
zero_division=zero_division,
2159-
)
2153+
precision_recall_fscore_support(
2154+
[1, 1, 2], [0, 1, 2], average=average, zero_division=zero_division
2155+
)
21602156

2161-
assert_no_warnings(
2162-
f,
2163-
np.array([[1, 0], [0, 0]]),
2164-
np.array([[1, 0], [1, 0]]),
2165-
average="samples",
2166-
zero_division=zero_division,
2167-
)
2157+
# average of per-sample scores
2158+
precision_recall_fscore_support(
2159+
np.array([[1, 0], [1, 0]]),
2160+
np.array([[1, 0], [0, 0]]),
2161+
average="samples",
2162+
zero_division=zero_division,
2163+
)
21682164

2169-
# single score: micro-average
2170-
assert_no_warnings(
2171-
f,
2172-
np.array([[1, 1], [1, 1]]),
2173-
np.array([[0, 0], [0, 0]]),
2174-
average="micro",
2175-
zero_division=zero_division,
2176-
)
2165+
precision_recall_fscore_support(
2166+
np.array([[1, 0], [0, 0]]),
2167+
np.array([[1, 0], [1, 0]]),
2168+
average="samples",
2169+
zero_division=zero_division,
2170+
)
21772171

2178-
assert_no_warnings(
2179-
f,
2180-
np.array([[0, 0], [0, 0]]),
2181-
np.array([[1, 1], [1, 1]]),
2182-
average="micro",
2183-
zero_division=zero_division,
2184-
)
2172+
# single score: micro-average
2173+
precision_recall_fscore_support(
2174+
np.array([[1, 1], [1, 1]]),
2175+
np.array([[0, 0], [0, 0]]),
2176+
average="micro",
2177+
zero_division=zero_division,
2178+
)
21852179

2186-
# single positive label
2187-
assert_no_warnings(
2188-
f, [1, 1], [-1, -1], average="binary", zero_division=zero_division
2189-
)
2180+
precision_recall_fscore_support(
2181+
np.array([[0, 0], [0, 0]]),
2182+
np.array([[1, 1], [1, 1]]),
2183+
average="micro",
2184+
zero_division=zero_division,
2185+
)
21902186

2191-
assert_no_warnings(
2192-
f, [-1, -1], [1, 1], average="binary", zero_division=zero_division
2193-
)
2187+
# single positive label
2188+
precision_recall_fscore_support(
2189+
[1, 1], [-1, -1], average="binary", zero_division=zero_division
2190+
)
2191+
2192+
precision_recall_fscore_support(
2193+
[-1, -1], [1, 1], average="binary", zero_division=zero_division
2194+
)
21942195

21952196
with warnings.catch_warnings(record=True) as record:
21962197
warnings.simplefilter("always")
@@ -2202,13 +2203,16 @@ def test_prf_no_warnings_if_zero_division_set(zero_division):
22022203

22032204
@pytest.mark.parametrize("zero_division", ["warn", 0, 1, np.nan])
22042205
def test_recall_warnings(zero_division):
2205-
assert_no_warnings(
2206-
recall_score,
2207-
np.array([[1, 1], [1, 1]]),
2208-
np.array([[0, 0], [0, 0]]),
2209-
average="micro",
2210-
zero_division=zero_division,
2211-
)
2206+
with warnings.catch_warnings():
2207+
warnings.simplefilter("error")
2208+
2209+
recall_score(
2210+
np.array([[1, 1], [1, 1]]),
2211+
np.array([[0, 0], [0, 0]]),
2212+
average="micro",
2213+
zero_division=zero_division,
2214+
)
2215+
22122216
with warnings.catch_warnings(record=True) as record:
22132217
warnings.simplefilter("always")
22142218
recall_score(
@@ -2266,13 +2270,15 @@ def test_precision_warnings(zero_division):
22662270
" this behavior."
22672271
)
22682272

2269-
assert_no_warnings(
2270-
precision_score,
2271-
np.array([[0, 0], [0, 0]]),
2272-
np.array([[1, 1], [1, 1]]),
2273-
average="micro",
2274-
zero_division=zero_division,
2275-
)
2273+
with warnings.catch_warnings():
2274+
warnings.simplefilter("error")
2275+
2276+
precision_score(
2277+
np.array([[0, 0], [0, 0]]),
2278+
np.array([[1, 1], [1, 1]]),
2279+
average="micro",
2280+
zero_division=zero_division,
2281+
)
22762282

22772283

22782284
@pytest.mark.parametrize("zero_division", ["warn", 0, 1, np.nan])

sklearn/preprocessing/tests/test_encoders.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
import warnings
23

34
import numpy as np
45
import pytest
@@ -796,7 +797,9 @@ def test_encoder_dtypes_pandas():
796797
def test_one_hot_encoder_warning():
797798
enc = OneHotEncoder()
798799
X = [["Male", 1], ["Female", 3]]
799-
np.testing.assert_no_warnings(enc.fit_transform, X)
800+
with warnings.catch_warnings():
801+
warnings.simplefilter("error")
802+
enc.fit_transform(X)
800803

801804

802805
@pytest.mark.parametrize("missing_value", [np.nan, None, float("nan")])

sklearn/tests/test_base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from sklearn.utils._testing import (
3535
_convert_container,
3636
assert_array_equal,
37-
assert_no_warnings,
3837
ignore_warnings,
3938
)
4039

@@ -472,7 +471,10 @@ def test_pickle_version_warning_is_not_raised_with_matching_version():
472471
tree = DecisionTreeClassifier().fit(iris.data, iris.target)
473472
tree_pickle = pickle.dumps(tree)
474473
assert b"_sklearn_version" in tree_pickle
475-
tree_restored = assert_no_warnings(pickle.loads, tree_pickle)
474+
475+
with warnings.catch_warnings():
476+
warnings.simplefilter("error")
477+
tree_restored = pickle.loads(tree_pickle)
476478

477479
# test that we can predict with the restored decision tree classifier
478480
score_of_original = tree.score(iris.data, iris.target)
@@ -542,7 +544,11 @@ def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator():
542544
try:
543545
module_backup = TreeNoVersion.__module__
544546
TreeNoVersion.__module__ = "notsklearn"
545-
assert_no_warnings(pickle.loads, tree_pickle_noversion)
547+
548+
with warnings.catch_warnings():
549+
warnings.simplefilter("error")
550+
551+
pickle.loads(tree_pickle_noversion)
546552
finally:
547553
TreeNoVersion.__module__ = module_backup
548554

sklearn/utils/_testing.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
assert_array_almost_equal,
3333
assert_array_equal,
3434
assert_array_less,
35-
assert_no_warnings,
3635
)
3736

3837
import sklearn
@@ -61,7 +60,6 @@
6160
"assert_approx_equal",
6261
"assert_allclose",
6362
"assert_run_python_script_without_output",
64-
"assert_no_warnings",
6563
"SkipTest",
6664
]
6765

0 commit comments

Comments
 (0)