@@ -45,7 +45,6 @@
     assert_almost_equal,
     assert_array_almost_equal,
     assert_array_equal,
-    assert_no_warnings,
     ignore_warnings,
 )
 from sklearn.utils.extmath import _nanaverage
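The hunks that follow all make the same mechanical change: the private assert_no_warnings helper is dropped from the test-utility imports, and every call site is rewritten with the equivalent standard-library idiom. A minimal standalone sketch of that idiom (illustrative code, not part of the diff):

import warnings


def quiet_metric():
    # Stand-in for a metric call under test; it emits no warnings.
    return 0.5


with warnings.catch_warnings():
    # simplefilter("error") promotes any warning raised inside the block
    # to an exception, so the call below fails loudly if it starts warning.
    # This replaces: result = assert_no_warnings(quiet_metric)
    warnings.simplefilter("error")
    result = quiet_metric()

assert result == 0.5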
@@ -266,24 +265,24 @@ def test_precision_recall_f1_score_binary():
     # individual scoring function that can be used for grid search: in the
     # binary class case the score is the value of the measure for the positive
     # class (e.g. label == 1). This is deprecated for average != 'binary'.
-    for kwargs, my_assert in [
-        ({}, assert_no_warnings),
-        ({"average": "binary"}, assert_no_warnings),
-    ]:
-        ps = my_assert(precision_score, y_true, y_pred, **kwargs)
-        assert_array_almost_equal(ps, 0.85, 2)
+    for kwargs in [{}, {"average": "binary"}]:
+        with warnings.catch_warnings():
+            warnings.simplefilter("error")
 
-        rs = my_assert(recall_score, y_true, y_pred, **kwargs)
-        assert_array_almost_equal(rs, 0.68, 2)
+            ps = precision_score(y_true, y_pred, **kwargs)
+            assert_array_almost_equal(ps, 0.85, 2)
 
-        fs = my_assert(f1_score, y_true, y_pred, **kwargs)
-        assert_array_almost_equal(fs, 0.76, 2)
+            rs = recall_score(y_true, y_pred, **kwargs)
+            assert_array_almost_equal(rs, 0.68, 2)
 
-        assert_almost_equal(
-            my_assert(fbeta_score, y_true, y_pred, beta=2, **kwargs),
-            (1 + 2**2) * ps * rs / (2**2 * ps + rs),
-            2,
-        )
+            fs = f1_score(y_true, y_pred, **kwargs)
+            assert_array_almost_equal(fs, 0.76, 2)
+
+            assert_almost_equal(
+                fbeta_score(y_true, y_pred, beta=2, **kwargs),
+                (1 + 2**2) * ps * rs / (2**2 * ps + rs),
+                2,
+            )
 
 
 @ignore_warnings
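The fbeta assertion in this hunk checks fbeta_score against the closed form F_beta = (1 + beta^2) * P * R / (beta^2 * P + R). A self-contained check of the same identity on synthetic labels (not the fixture data the test uses):

import numpy as np
from sklearn.metrics import fbeta_score, precision_score, recall_score

y_true = np.array([0, 1, 1, 0, 1, 1])
y_pred = np.array([0, 1, 0, 1, 1, 1])
beta = 2

p = precision_score(y_true, y_pred)  # 3 TP, 1 FP -> 0.75
r = recall_score(y_true, y_pred)  # 3 TP, 1 FN -> 0.75

# F_beta = (1 + beta^2) * P * R / (beta^2 * P + R)
expected = (1 + beta**2) * p * r / (beta**2 * p + r)
assert np.isclose(fbeta_score(y_true, y_pred, beta=beta), expected)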
@@ -1919,22 +1918,23 @@ def test_precision_recall_f1_no_labels(beta, average, zero_division):
     y_true = np.zeros((20, 3))
     y_pred = np.zeros_like(y_true)
 
-    p, r, f, s = assert_no_warnings(
-        precision_recall_fscore_support,
-        y_true,
-        y_pred,
-        average=average,
-        beta=beta,
-        zero_division=zero_division,
-    )
-    fbeta = assert_no_warnings(
-        fbeta_score,
-        y_true,
-        y_pred,
-        beta=beta,
-        average=average,
-        zero_division=zero_division,
-    )
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+
+        p, r, f, s = precision_recall_fscore_support(
+            y_true,
+            y_pred,
+            average=average,
+            beta=beta,
+            zero_division=zero_division,
+        )
+        fbeta = fbeta_score(
+            y_true,
+            y_pred,
+            beta=beta,
+            average=average,
+            zero_division=zero_division,
+        )
     assert s is None
 
     # if zero_division = nan, check that all metrics are nan and exit
@@ -1984,17 +1984,20 @@ def test_precision_recall_f1_no_labels_average_none(zero_division):
     # |y_i| = [0, 0, 0]
     # |y_hat_i| = [0, 0, 0]
 
-    p, r, f, s = assert_no_warnings(
-        precision_recall_fscore_support,
-        y_true,
-        y_pred,
-        average=None,
-        beta=1.0,
-        zero_division=zero_division,
-    )
-    fbeta = assert_no_warnings(
-        fbeta_score, y_true, y_pred, beta=1.0, average=None, zero_division=zero_division
-    )
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+
+        p, r, f, s = precision_recall_fscore_support(
+            y_true,
+            y_pred,
+            average=None,
+            beta=1.0,
+            zero_division=zero_division,
+        )
+        fbeta = fbeta_score(
+            y_true, y_pred, beta=1.0, average=None, zero_division=zero_division
+        )
+
     zero_division = np.float64(zero_division)
     assert_array_almost_equal(p, [zero_division, zero_division, zero_division], 2)
     assert_array_almost_equal(r, [zero_division, zero_division, zero_division], 2)
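Both tests above exercise inputs with no positive labels anywhere, so precision and recall reduce to 0/0 and the zero_division argument alone determines the result. A quick illustration on a trivially degenerate binary input (assumed example, easy to verify by hand):

import numpy as np
from sklearn.metrics import recall_score

# No actual positives: recall is 0/0, so zero_division decides the value.
y_true = [0, 0, 0]
y_pred = [0, 0, 0]

assert recall_score(y_true, y_pred, zero_division=0) == 0.0
assert recall_score(y_true, y_pred, zero_division=1) == 1.0
assert np.isnan(recall_score(y_true, y_pred, zero_division=np.nan))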
@@ -2138,59 +2141,57 @@ def test_prf_warnings():
 
 @pytest.mark.parametrize("zero_division", [0, 1, np.nan])
 def test_prf_no_warnings_if_zero_division_set(zero_division):
-    # average of per-label scores
-    f = precision_recall_fscore_support
-    for average in [None, "weighted", "macro"]:
-        assert_no_warnings(
-            f, [0, 1, 2], [1, 1, 2], average=average, zero_division=zero_division
-        )
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
 
-        assert_no_warnings(
-            f, [1, 1, 2], [0, 1, 2], average=average, zero_division=zero_division
-        )
+        # average of per-label scores
+        for average in [None, "weighted", "macro"]:
+            precision_recall_fscore_support(
+                [0, 1, 2], [1, 1, 2], average=average, zero_division=zero_division
+            )
 
-    # average of per-sample scores
-    assert_no_warnings(
-        f,
-        np.array([[1, 0], [1, 0]]),
-        np.array([[1, 0], [0, 0]]),
-        average="samples",
-        zero_division=zero_division,
-    )
+            precision_recall_fscore_support(
+                [1, 1, 2], [0, 1, 2], average=average, zero_division=zero_division
+            )
 
-    assert_no_warnings(
-        f,
-        np.array([[1, 0], [0, 0]]),
-        np.array([[1, 0], [1, 0]]),
-        average="samples",
-        zero_division=zero_division,
-    )
+        # average of per-sample scores
+        precision_recall_fscore_support(
+            np.array([[1, 0], [1, 0]]),
+            np.array([[1, 0], [0, 0]]),
+            average="samples",
+            zero_division=zero_division,
+        )
 
-    # single score: micro-average
-    assert_no_warnings(
-        f,
-        np.array([[1, 1], [1, 1]]),
-        np.array([[0, 0], [0, 0]]),
-        average="micro",
-        zero_division=zero_division,
-    )
+        precision_recall_fscore_support(
+            np.array([[1, 0], [0, 0]]),
+            np.array([[1, 0], [1, 0]]),
+            average="samples",
+            zero_division=zero_division,
+        )
 
-    assert_no_warnings(
-        f,
-        np.array([[0, 0], [0, 0]]),
-        np.array([[1, 1], [1, 1]]),
-        average="micro",
-        zero_division=zero_division,
-    )
+        # single score: micro-average
+        precision_recall_fscore_support(
+            np.array([[1, 1], [1, 1]]),
+            np.array([[0, 0], [0, 0]]),
+            average="micro",
+            zero_division=zero_division,
+        )
 
-    # single positive label
-    assert_no_warnings(
-        f, [1, 1], [-1, -1], average="binary", zero_division=zero_division
-    )
+        precision_recall_fscore_support(
+            np.array([[0, 0], [0, 0]]),
+            np.array([[1, 1], [1, 1]]),
+            average="micro",
+            zero_division=zero_division,
+        )
 
-    assert_no_warnings(
-        f, [-1, -1], [1, 1], average="binary", zero_division=zero_division
-    )
+        # single positive label
+        precision_recall_fscore_support(
+            [1, 1], [-1, -1], average="binary", zero_division=zero_division
+        )
+
+        precision_recall_fscore_support(
+            [-1, -1], [1, 1], average="binary", zero_division=zero_division
+        )
 
     with warnings.catch_warnings(record=True) as record:
         warnings.simplefilter("always")
@@ -2202,13 +2203,16 @@ def test_prf_no_warnings_if_zero_division_set(zero_division):
 
 @pytest.mark.parametrize("zero_division", ["warn", 0, 1, np.nan])
 def test_recall_warnings(zero_division):
-    assert_no_warnings(
-        recall_score,
-        np.array([[1, 1], [1, 1]]),
-        np.array([[0, 0], [0, 0]]),
-        average="micro",
-        zero_division=zero_division,
-    )
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+
+        recall_score(
+            np.array([[1, 1], [1, 1]]),
+            np.array([[0, 0], [0, 0]]),
+            average="micro",
+            zero_division=zero_division,
+        )
+
     with warnings.catch_warnings(record=True) as record:
         warnings.simplefilter("always")
         recall_score(
@@ -2266,13 +2270,15 @@ def test_precision_warnings(zero_division):
         " this behavior."
     )
 
-    assert_no_warnings(
-        precision_score,
-        np.array([[0, 0], [0, 0]]),
-        np.array([[1, 1], [1, 1]]),
-        average="micro",
-        zero_division=zero_division,
-    )
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+
+        precision_score(
+            np.array([[0, 0], [0, 0]]),
+            np.array([[1, 1], [1, 1]]),
+            average="micro",
+            zero_division=zero_division,
+        )
 
 
 @pytest.mark.parametrize("zero_division", ["warn", 0, 1, np.nan])
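Each of these tests sweeps zero_division via pytest parametrization, so every value runs as an independent test case. A compact sketch of the mechanism (hypothetical test name and data, not from the diff):

import numpy as np
import pytest
from sklearn.metrics import precision_score


@pytest.mark.parametrize("zero_division", [0, 1, np.nan])
def test_degenerate_precision(zero_division):
    # pytest runs this body once per parameter value. With no predicted
    # positives, precision is 0/0 and zero_division fixes the result.
    result = precision_score([1, 1], [0, 0], zero_division=zero_division)
    if np.isnan(zero_division):
        assert np.isnan(result)
    else:
        assert result == zero_division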