Skip to content

Commit 5ced13c

Browse files
authored
ENH Add Array API compatibility for entropy (scikit-learn#29141)
1 parent a1027b5 commit 5ced13c

File tree

5 files changed

+31
-10
lines changed

5 files changed

+31
-10
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ base estimator also does:
112112
Metrics
113113
-------
114114

115+
- :func:`sklearn.metrics.cluster.entropy`
115116
- :func:`sklearn.metrics.accuracy_score`
116117
- :func:`sklearn.metrics.d2_tweedie_score`
117118
- :func:`sklearn.metrics.max_error`

doc/whats_new/v1.6.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ See :ref:`array_api` for more details.
3232

3333
**Functions:**
3434

35+
- :func:`sklearn.metrics.cluster.entropy` :pr:`29141` by :user:`Yaroslav Korobko <Tialo>`;
3536
- :func:`sklearn.metrics.d2_tweedie_score` :pr:`29207` by :user:`Emily Chen <EmilyXinyi>`;
3637
- :func:`sklearn.metrics.max_error` :pr:`29212` by :user:`Edoardo Abati <EdAbati>`;
3738
- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`;

sklearn/metrics/cluster/_supervised.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import numpy as np
2424
from scipy import sparse as sp
2525

26+
from ...utils._array_api import get_namespace
2627
from ...utils._param_validation import Interval, StrOptions, validate_params
2728
from ...utils.multiclass import type_of_target
2829
from ...utils.validation import check_array, check_consistent_length
@@ -1282,17 +1283,20 @@ def entropy(labels):
12821283
-----
12831284
The logarithm used is the natural logarithm (base-e).
12841285
"""
1285-
if len(labels) == 0:
1286+
xp, is_array_api_compliant = get_namespace(labels)
1287+
labels_len = labels.shape[0] if is_array_api_compliant else len(labels)
1288+
if labels_len == 0:
12861289
return 1.0
1287-
label_idx = np.unique(labels, return_inverse=True)[1]
1288-
pi = np.bincount(label_idx).astype(np.float64)
1289-
pi = pi[pi > 0]
1290+
1291+
pi = xp.astype(xp.unique_counts(labels)[1], xp.float64)
12901292

12911293
# single cluster => zero entropy
12921294
if pi.size == 1:
12931295
return 0.0
12941296

1295-
pi_sum = np.sum(pi)
1297+
pi_sum = xp.sum(pi)
12961298
# log(a / b) should be calculated as log(a) - log(b) for
12971299
# possible loss of precision
1298-
return -np.sum((pi / pi_sum) * (np.log(pi) - log(pi_sum)))
1300+
# Always convert the result to a Python scalar (on CPU) instead of a
1301+
# device-specific scalar array.
1302+
return float(-xp.sum((pi / pi_sum) * (xp.log(pi) - log(pi_sum))))

sklearn/metrics/cluster/tests/test_supervised.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal
66

7+
from sklearn.base import config_context
78
from sklearn.metrics.cluster import (
89
adjusted_mutual_info_score,
910
adjusted_rand_score,
@@ -22,7 +23,8 @@
2223
)
2324
from sklearn.metrics.cluster._supervised import _generalized_average, check_clusterings
2425
from sklearn.utils import assert_all_finite
25-
from sklearn.utils._testing import assert_almost_equal
26+
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
27+
from sklearn.utils._testing import _array_api_for_tests, assert_almost_equal
2628

2729
score_funcs = [
2830
adjusted_rand_score,
@@ -254,12 +256,25 @@ def test_int_overflow_mutual_info_fowlkes_mallows_score():
254256

255257

256258
def test_entropy():
257-
ent = entropy([0, 0, 42.0])
258-
assert_almost_equal(ent, 0.6365141, 5)
259+
assert_almost_equal(entropy([0, 0, 42.0]), 0.6365141, 5)
259260
assert_almost_equal(entropy([]), 1)
260261
assert entropy([1, 1, 1, 1]) == 0
261262

262263

264+
@pytest.mark.parametrize(
265+
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
266+
)
267+
def test_entropy_array_api(array_namespace, device, dtype_name):
268+
xp = _array_api_for_tests(array_namespace, device)
269+
float_labels = xp.asarray(np.asarray([0, 0, 42.0], dtype=dtype_name), device=device)
270+
empty_int32_labels = xp.asarray([], dtype=xp.int32, device=device)
271+
int_labels = xp.asarray([1, 1, 1, 1], device=device)
272+
with config_context(array_api_dispatch=True):
273+
assert entropy(float_labels) == pytest.approx(0.6365141, abs=1e-5)
274+
assert entropy(empty_int32_labels) == 1
275+
assert entropy(int_labels) == 0
276+
277+
263278
def test_contingency_matrix():
264279
labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
265280
labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2])

sklearn/utils/_array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def supported_float_dtypes(xp):
250250
def ensure_common_namespace_device(reference, *arrays):
251251
"""Ensure that all arrays use the same namespace and device as reference.
252252
253-
If neccessary the arrays are moved to the same namespace and device as
253+
If necessary the arrays are moved to the same namespace and device as
254254
the reference array.
255255
256256
Parameters

0 commit comments

Comments
 (0)