Skip to content

Commit 75cb7c3

Browse files
glevv and lesteve authored
FIX Fix adjusted_mutual_info_score numerical issue (scikit-learn#31065)
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
1 parent 434010c commit 75cb7c3

File tree

3 files changed

+71
-50
lines changed

3 files changed

+71
-50
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- Fix :func:`metrics.adjusted_mutual_info_score` numerical issue when number of
2+
classes and samples is low.
3+
By :user:`Hleb Levitski <glevv>`

sklearn/metrics/cluster/_supervised.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,6 +1033,9 @@ def adjusted_mutual_info_score(
10331033
or classes.shape[0] == clusters.shape[0] == 0
10341034
):
10351035
return 1.0
1036+
# if there is only one class or one cluster return 0.0.
1037+
elif classes.shape[0] == 1 or clusters.shape[0] == 1:
1038+
return 0.0
10361039

10371040
contingency = contingency_matrix(labels_true, labels_pred, sparse=True)
10381041
# Calculate the MI for the two clusterings
@@ -1051,8 +1054,13 @@ def adjusted_mutual_info_score(
10511054
denominator = min(denominator, -np.finfo("float64").eps)
10521055
else:
10531056
denominator = max(denominator, np.finfo("float64").eps)
1054-
ami = (mi - emi) / denominator
1055-
return float(ami)
1057+
# The same applies analogously to mi and emi.
1058+
numerator = mi - emi
1059+
if numerator < 0:
1060+
numerator = min(numerator, -np.finfo("float64").eps)
1061+
else:
1062+
numerator = max(numerator, np.finfo("float64").eps)
1063+
return float(numerator / denominator)
10561064

10571065

10581066
@validate_params(

sklearn/metrics/cluster/tests/test_supervised.py

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,19 @@
4040
]
4141

4242

43-
def test_error_messages_on_wrong_input():
44-
for score_func in score_funcs:
45-
expected = (
46-
r"Found input variables with inconsistent numbers of samples: \[2, 3\]"
47-
)
48-
with pytest.raises(ValueError, match=expected):
49-
score_func([0, 1], [1, 1, 1])
43+
@pytest.mark.parametrize("score_func", score_funcs)
44+
def test_error_messages_on_wrong_input(score_func):
45+
expected = r"Found input variables with inconsistent numbers of samples: \[2, 3\]"
46+
with pytest.raises(ValueError, match=expected):
47+
score_func([0, 1], [1, 1, 1])
5048

51-
expected = r"labels_true must be 1D: shape is \(2"
52-
with pytest.raises(ValueError, match=expected):
53-
score_func([[0, 1], [1, 0]], [1, 1, 1])
49+
expected = r"labels_true must be 1D: shape is \(2"
50+
with pytest.raises(ValueError, match=expected):
51+
score_func([[0, 1], [1, 0]], [1, 1, 1])
5452

55-
expected = r"labels_pred must be 1D: shape is \(2"
56-
with pytest.raises(ValueError, match=expected):
57-
score_func([0, 1, 0], [[1, 1], [0, 0]])
53+
expected = r"labels_pred must be 1D: shape is \(2"
54+
with pytest.raises(ValueError, match=expected):
55+
score_func([0, 1, 0], [[1, 1], [0, 0]])
5856

5957

6058
def test_generalized_average():
@@ -67,39 +65,50 @@ def test_generalized_average():
6765
assert means[0] == means[1] == means[2] == means[3]
6866

6967

70-
def test_perfect_matches():
71-
for score_func in score_funcs:
72-
assert score_func([], []) == pytest.approx(1.0)
73-
assert score_func([0], [1]) == pytest.approx(1.0)
74-
assert score_func([0, 0, 0], [0, 0, 0]) == pytest.approx(1.0)
75-
assert score_func([0, 1, 0], [42, 7, 42]) == pytest.approx(1.0)
76-
assert score_func([0.0, 1.0, 0.0], [42.0, 7.0, 42.0]) == pytest.approx(1.0)
77-
assert score_func([0.0, 1.0, 2.0], [42.0, 7.0, 2.0]) == pytest.approx(1.0)
78-
assert score_func([0, 1, 2], [42, 7, 2]) == pytest.approx(1.0)
79-
score_funcs_with_changing_means = [
68+
@pytest.mark.parametrize("score_func", score_funcs)
69+
def test_perfect_matches(score_func):
70+
assert score_func([], []) == pytest.approx(1.0)
71+
assert score_func([0], [1]) == pytest.approx(1.0)
72+
assert score_func([0, 0, 0], [0, 0, 0]) == pytest.approx(1.0)
73+
assert score_func([0, 1, 0], [42, 7, 42]) == pytest.approx(1.0)
74+
assert score_func([0.0, 1.0, 0.0], [42.0, 7.0, 42.0]) == pytest.approx(1.0)
75+
assert score_func([0.0, 1.0, 2.0], [42.0, 7.0, 2.0]) == pytest.approx(1.0)
76+
assert score_func([0, 1, 2], [42, 7, 2]) == pytest.approx(1.0)
77+
78+
79+
@pytest.mark.parametrize(
80+
"score_func",
81+
[
8082
normalized_mutual_info_score,
8183
adjusted_mutual_info_score,
82-
]
83-
means = {"min", "geometric", "arithmetic", "max"}
84-
for score_func in score_funcs_with_changing_means:
85-
for mean in means:
86-
assert score_func([], [], average_method=mean) == pytest.approx(1.0)
87-
assert score_func([0], [1], average_method=mean) == pytest.approx(1.0)
88-
assert score_func(
89-
[0, 0, 0], [0, 0, 0], average_method=mean
90-
) == pytest.approx(1.0)
91-
assert score_func(
92-
[0, 1, 0], [42, 7, 42], average_method=mean
93-
) == pytest.approx(1.0)
94-
assert score_func(
95-
[0.0, 1.0, 0.0], [42.0, 7.0, 42.0], average_method=mean
96-
) == pytest.approx(1.0)
97-
assert score_func(
98-
[0.0, 1.0, 2.0], [42.0, 7.0, 2.0], average_method=mean
99-
) == pytest.approx(1.0)
100-
assert score_func(
101-
[0, 1, 2], [42, 7, 2], average_method=mean
102-
) == pytest.approx(1.0)
84+
],
85+
)
86+
@pytest.mark.parametrize("average_method", ["min", "geometric", "arithmetic", "max"])
87+
def test_perfect_matches_with_changing_means(score_func, average_method):
88+
assert score_func([], [], average_method=average_method) == pytest.approx(1.0)
89+
assert score_func([0], [1], average_method=average_method) == pytest.approx(1.0)
90+
assert score_func(
91+
[0, 0, 0], [0, 0, 0], average_method=average_method
92+
) == pytest.approx(1.0)
93+
assert score_func(
94+
[0, 1, 0], [42, 7, 42], average_method=average_method
95+
) == pytest.approx(1.0)
96+
assert score_func(
97+
[0.0, 1.0, 0.0], [42.0, 7.0, 42.0], average_method=average_method
98+
) == pytest.approx(1.0)
99+
assert score_func(
100+
[0.0, 1.0, 2.0], [42.0, 7.0, 2.0], average_method=average_method
101+
) == pytest.approx(1.0)
102+
assert score_func(
103+
[0, 1, 2], [42, 7, 2], average_method=average_method
104+
) == pytest.approx(1.0)
105+
# Non-regression tests for: https://github.com/scikit-learn/scikit-learn/issues/30950
106+
assert score_func([0, 1], [0, 1], average_method=average_method) == pytest.approx(
107+
1.0
108+
)
109+
assert score_func(
110+
[0, 1, 2, 3], [0, 1, 2, 3], average_method=average_method
111+
) == pytest.approx(1.0)
103112

104113

105114
def test_homogeneous_but_not_complete_labeling():
@@ -306,12 +315,13 @@ def test_exactly_zero_info_score():
306315
labels_a, labels_b = (np.ones(i, dtype=int), np.arange(i, dtype=int))
307316
assert normalized_mutual_info_score(labels_a, labels_b) == pytest.approx(0.0)
308317
assert v_measure_score(labels_a, labels_b) == pytest.approx(0.0)
309-
assert adjusted_mutual_info_score(labels_a, labels_b) == pytest.approx(0.0)
318+
assert adjusted_mutual_info_score(labels_a, labels_b) == 0.0
310319
assert normalized_mutual_info_score(labels_a, labels_b) == pytest.approx(0.0)
311320
for method in ["min", "geometric", "arithmetic", "max"]:
312-
assert adjusted_mutual_info_score(
313-
labels_a, labels_b, average_method=method
314-
) == pytest.approx(0.0)
321+
assert (
322+
adjusted_mutual_info_score(labels_a, labels_b, average_method=method)
323+
== 0.0
324+
)
315325
assert normalized_mutual_info_score(
316326
labels_a, labels_b, average_method=method
317327
) == pytest.approx(0.0)

0 commit comments

Comments
 (0)