40
40
]
41
41
42
42
43
- def test_error_messages_on_wrong_input ():
44
- for score_func in score_funcs :
45
- expected = (
46
- r"Found input variables with inconsistent numbers of samples: \[2, 3\]"
47
- )
48
- with pytest .raises (ValueError , match = expected ):
49
- score_func ([0 , 1 ], [1 , 1 , 1 ])
43
+ @pytest .mark .parametrize ("score_func" , score_funcs )
44
+ def test_error_messages_on_wrong_input (score_func ):
45
+ expected = r"Found input variables with inconsistent numbers of samples: \[2, 3\]"
46
+ with pytest .raises (ValueError , match = expected ):
47
+ score_func ([0 , 1 ], [1 , 1 , 1 ])
50
48
51
- expected = r"labels_true must be 1D: shape is \(2"
52
- with pytest .raises (ValueError , match = expected ):
53
- score_func ([[0 , 1 ], [1 , 0 ]], [1 , 1 , 1 ])
49
+ expected = r"labels_true must be 1D: shape is \(2"
50
+ with pytest .raises (ValueError , match = expected ):
51
+ score_func ([[0 , 1 ], [1 , 0 ]], [1 , 1 , 1 ])
54
52
55
- expected = r"labels_pred must be 1D: shape is \(2"
56
- with pytest .raises (ValueError , match = expected ):
57
- score_func ([0 , 1 , 0 ], [[1 , 1 ], [0 , 0 ]])
53
+ expected = r"labels_pred must be 1D: shape is \(2"
54
+ with pytest .raises (ValueError , match = expected ):
55
+ score_func ([0 , 1 , 0 ], [[1 , 1 ], [0 , 0 ]])
58
56
59
57
60
58
def test_generalized_average ():
@@ -67,39 +65,50 @@ def test_generalized_average():
67
65
assert means [0 ] == means [1 ] == means [2 ] == means [3 ]
68
66
69
67
70
- def test_perfect_matches ():
71
- for score_func in score_funcs :
72
- assert score_func ([], []) == pytest .approx (1.0 )
73
- assert score_func ([0 ], [1 ]) == pytest .approx (1.0 )
74
- assert score_func ([0 , 0 , 0 ], [0 , 0 , 0 ]) == pytest .approx (1.0 )
75
- assert score_func ([0 , 1 , 0 ], [42 , 7 , 42 ]) == pytest .approx (1.0 )
76
- assert score_func ([0.0 , 1.0 , 0.0 ], [42.0 , 7.0 , 42.0 ]) == pytest .approx (1.0 )
77
- assert score_func ([0.0 , 1.0 , 2.0 ], [42.0 , 7.0 , 2.0 ]) == pytest .approx (1.0 )
78
- assert score_func ([0 , 1 , 2 ], [42 , 7 , 2 ]) == pytest .approx (1.0 )
79
- score_funcs_with_changing_means = [
68
+ @pytest .mark .parametrize ("score_func" , score_funcs )
69
+ def test_perfect_matches (score_func ):
70
+ assert score_func ([], []) == pytest .approx (1.0 )
71
+ assert score_func ([0 ], [1 ]) == pytest .approx (1.0 )
72
+ assert score_func ([0 , 0 , 0 ], [0 , 0 , 0 ]) == pytest .approx (1.0 )
73
+ assert score_func ([0 , 1 , 0 ], [42 , 7 , 42 ]) == pytest .approx (1.0 )
74
+ assert score_func ([0.0 , 1.0 , 0.0 ], [42.0 , 7.0 , 42.0 ]) == pytest .approx (1.0 )
75
+ assert score_func ([0.0 , 1.0 , 2.0 ], [42.0 , 7.0 , 2.0 ]) == pytest .approx (1.0 )
76
+ assert score_func ([0 , 1 , 2 ], [42 , 7 , 2 ]) == pytest .approx (1.0 )
77
+
78
+
79
+ @pytest .mark .parametrize (
80
+ "score_func" ,
81
+ [
80
82
normalized_mutual_info_score ,
81
83
adjusted_mutual_info_score ,
82
- ]
83
- means = {"min" , "geometric" , "arithmetic" , "max" }
84
- for score_func in score_funcs_with_changing_means :
85
- for mean in means :
86
- assert score_func ([], [], average_method = mean ) == pytest .approx (1.0 )
87
- assert score_func ([0 ], [1 ], average_method = mean ) == pytest .approx (1.0 )
88
- assert score_func (
89
- [0 , 0 , 0 ], [0 , 0 , 0 ], average_method = mean
90
- ) == pytest .approx (1.0 )
91
- assert score_func (
92
- [0 , 1 , 0 ], [42 , 7 , 42 ], average_method = mean
93
- ) == pytest .approx (1.0 )
94
- assert score_func (
95
- [0.0 , 1.0 , 0.0 ], [42.0 , 7.0 , 42.0 ], average_method = mean
96
- ) == pytest .approx (1.0 )
97
- assert score_func (
98
- [0.0 , 1.0 , 2.0 ], [42.0 , 7.0 , 2.0 ], average_method = mean
99
- ) == pytest .approx (1.0 )
100
- assert score_func (
101
- [0 , 1 , 2 ], [42 , 7 , 2 ], average_method = mean
102
- ) == pytest .approx (1.0 )
84
+ ],
85
+ )
86
+ @pytest .mark .parametrize ("average_method" , ["min" , "geometric" , "arithmetic" , "max" ])
87
+ def test_perfect_matches_with_changing_means (score_func , average_method ):
88
+ assert score_func ([], [], average_method = average_method ) == pytest .approx (1.0 )
89
+ assert score_func ([0 ], [1 ], average_method = average_method ) == pytest .approx (1.0 )
90
+ assert score_func (
91
+ [0 , 0 , 0 ], [0 , 0 , 0 ], average_method = average_method
92
+ ) == pytest .approx (1.0 )
93
+ assert score_func (
94
+ [0 , 1 , 0 ], [42 , 7 , 42 ], average_method = average_method
95
+ ) == pytest .approx (1.0 )
96
+ assert score_func (
97
+ [0.0 , 1.0 , 0.0 ], [42.0 , 7.0 , 42.0 ], average_method = average_method
98
+ ) == pytest .approx (1.0 )
99
+ assert score_func (
100
+ [0.0 , 1.0 , 2.0 ], [42.0 , 7.0 , 2.0 ], average_method = average_method
101
+ ) == pytest .approx (1.0 )
102
+ assert score_func (
103
+ [0 , 1 , 2 ], [42 , 7 , 2 ], average_method = average_method
104
+ ) == pytest .approx (1.0 )
105
+ # Non-regression tests for: https://github.com/scikit-learn/scikit-learn/issues/30950
106
+ assert score_func ([0 , 1 ], [0 , 1 ], average_method = average_method ) == pytest .approx (
107
+ 1.0
108
+ )
109
+ assert score_func (
110
+ [0 , 1 , 2 , 3 ], [0 , 1 , 2 , 3 ], average_method = average_method
111
+ ) == pytest .approx (1.0 )
103
112
104
113
105
114
def test_homogeneous_but_not_complete_labeling ():
@@ -306,12 +315,13 @@ def test_exactly_zero_info_score():
306
315
labels_a , labels_b = (np .ones (i , dtype = int ), np .arange (i , dtype = int ))
307
316
assert normalized_mutual_info_score (labels_a , labels_b ) == pytest .approx (0.0 )
308
317
assert v_measure_score (labels_a , labels_b ) == pytest .approx (0.0 )
309
- assert adjusted_mutual_info_score (labels_a , labels_b ) == pytest . approx ( 0.0 )
318
+ assert adjusted_mutual_info_score (labels_a , labels_b ) == 0.0
310
319
assert normalized_mutual_info_score (labels_a , labels_b ) == pytest .approx (0.0 )
311
320
for method in ["min" , "geometric" , "arithmetic" , "max" ]:
312
- assert adjusted_mutual_info_score (
313
- labels_a , labels_b , average_method = method
314
- ) == pytest .approx (0.0 )
321
+ assert (
322
+ adjusted_mutual_info_score (labels_a , labels_b , average_method = method )
323
+ == 0.0
324
+ )
315
325
assert normalized_mutual_info_score (
316
326
labels_a , labels_b , average_method = method
317
327
) == pytest .approx (0.0 )
0 commit comments