Skip to content

Commit 844b087

Browse files
authored
ENH Add Array API compatibility to zero_one_loss and accuracy_score (scikit-learn#27137)
1 parent dcf88e9 commit 844b087

File tree

9 files changed

+179
-49
lines changed

9 files changed

+179
-49
lines changed

doc/modules/array_api.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ Estimators
9999
- :class:`preprocessing.MaxAbsScaler`
100100
- :class:`preprocessing.MinMaxScaler`
101101

102+
Metrics
103+
-------
104+
105+
- :func:`sklearn.metrics.accuracy_score`
106+
- :func:`sklearn.metrics.zero_one_loss`
107+
102108
Tools
103109
-----
104110

doc/whats_new/v1.4.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ Changelog
235235
:pr:`26931` by `Thomas Fan`_.
236236

237237
- |MajorFeature| :class:`preprocessing.MinMaxScaler` and :class:`preprocessing.MaxAbsScaler` now
238-
supports the `Array API <https://data-apis.org/array-api/latest/>`_. Array API
238+
support the `Array API <https://data-apis.org/array-api/latest/>`_. Array API
239239
support is considered experimental and might evolve without being subject to
240240
our usual rolling deprecation cycle policy. See
241241
:ref:`array_api` for more details. :pr:`26243` by `Tim Head`_ and :pr:`27110` by :user:`Edoardo Abati <EdAbati>`.
@@ -279,6 +279,9 @@ Changelog
279279
both axis is set to be 1 to get a square plot.
280280
:pr:`26366` by :user:`Mojdeh Rastgoo <mrastgoo>`.
281281

282+
- |Enhancement| :func:`sklearn.metrics.accuracy_score` and :func:`sklearn.metrics.zero_one_loss` now support
283+
Array API compatible inputs. :pr:`27137` by :user:`Edoardo Abati <EdAbati>`.
284+
282285
:mod:`sklearn.utils`
283286
....................
284287

sklearn/decomposition/tests/test_pca.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
from sklearn.decomposition import PCA
1313
from sklearn.decomposition._pca import _assess_dimension, _infer_dimension
1414
from sklearn.utils._array_api import (
15+
_atol_for_type,
1516
_convert_to_numpy,
1617
yield_namespace_device_dtype_combinations,
1718
)
18-
from sklearn.utils._testing import assert_allclose
19+
from sklearn.utils._testing import _array_api_for_tests, assert_allclose
1920
from sklearn.utils.estimator_checks import (
20-
_array_api_for_tests,
2121
_get_check_estimator_ids,
2222
check_array_api_input_and_values,
2323
)
@@ -717,7 +717,7 @@ def check_array_api_get_precision(name, estimator, array_namepsace, device, dtyp
717717
assert_allclose(
718718
_convert_to_numpy(precision_xp, xp=xp),
719719
precision_np,
720-
atol=np.finfo(dtype).eps * 100,
720+
atol=_atol_for_type(dtype),
721721
)
722722
covariance_xp = estimator_xp.get_covariance()
723723
assert covariance_xp.shape == (4, 4)
@@ -726,7 +726,7 @@ def check_array_api_get_precision(name, estimator, array_namepsace, device, dtyp
726726
assert_allclose(
727727
_convert_to_numpy(covariance_xp, xp=xp),
728728
covariance_np,
729-
atol=np.finfo(dtype).eps * 100,
729+
atol=_atol_for_type(dtype),
730730
)
731731

732732

sklearn/metrics/_classification.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1046,6 +1046,7 @@ def zero_one_loss(y_true, y_pred, *, normalize=True, sample_weight=None):
10461046
>>> zero_one_loss(np.array([[0, 1], [1, 1]]), np.ones((2, 2)))
10471047
0.5
10481048
"""
1049+
xp, _ = get_namespace(y_true, y_pred)
10491050
score = accuracy_score(
10501051
y_true, y_pred, normalize=normalize, sample_weight=sample_weight
10511052
)
@@ -1054,7 +1055,7 @@ def zero_one_loss(y_true, y_pred, *, normalize=True, sample_weight=None):
10541055
return 1 - score
10551056
else:
10561057
if sample_weight is not None:
1057-
n_samples = np.sum(sample_weight)
1058+
n_samples = xp.sum(sample_weight)
10581059
else:
10591060
n_samples = _num_samples(y_true)
10601061
return n_samples - score

sklearn/metrics/tests/test_common.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77
import scipy.sparse as sp
88

9+
from sklearn._config import config_context
910
from sklearn.datasets import make_multilabel_classification
1011
from sklearn.metrics import (
1112
accuracy_score,
@@ -53,7 +54,12 @@
5354
from sklearn.metrics._base import _average_binary_score
5455
from sklearn.preprocessing import LabelBinarizer
5556
from sklearn.utils import shuffle
57+
from sklearn.utils._array_api import (
58+
_atol_for_type,
59+
yield_namespace_device_dtype_combinations,
60+
)
5661
from sklearn.utils._testing import (
62+
_array_api_for_tests,
5763
assert_allclose,
5864
assert_almost_equal,
5965
assert_array_equal,
@@ -1723,3 +1729,74 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
17231729
err_msg = err_msg_pos_label_1 if pos_label_default == 1 else err_msg_pos_label_None
17241730
with pytest.raises(ValueError, match=err_msg):
17251731
metric(y1, y2)
1732+
1733+
1734+
def check_array_api_metric(
1735+
metric, array_namespace, device, dtype, y_true_np, y_pred_np
1736+
):
1737+
xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)
1738+
y_true_xp = xp.asarray(y_true_np, device=device)
1739+
y_pred_xp = xp.asarray(y_pred_np, device=device)
1740+
1741+
metric_np = metric(y_true_np, y_pred_np)
1742+
1743+
with config_context(array_api_dispatch=True):
1744+
metric_xp = metric(y_true_xp, y_pred_xp)
1745+
1746+
assert_allclose(
1747+
metric_xp,
1748+
metric_np,
1749+
atol=_atol_for_type(dtype),
1750+
)
1751+
1752+
1753+
def check_array_api_binary_classification_metric(
1754+
metric, array_namespace, device, dtype
1755+
):
1756+
return check_array_api_metric(
1757+
metric,
1758+
array_namespace,
1759+
device,
1760+
dtype,
1761+
y_true_np=np.array([0, 0, 1, 1]),
1762+
y_pred_np=np.array([0, 1, 0, 1]),
1763+
)
1764+
1765+
1766+
def check_array_api_multiclass_classification_metric(
1767+
metric, array_namespace, device, dtype
1768+
):
1769+
return check_array_api_metric(
1770+
metric,
1771+
array_namespace,
1772+
device,
1773+
dtype,
1774+
y_true_np=np.array([0, 1, 2, 3]),
1775+
y_pred_np=np.array([0, 1, 0, 2]),
1776+
)
1777+
1778+
1779+
metric_checkers = {
1780+
accuracy_score: [
1781+
check_array_api_binary_classification_metric,
1782+
check_array_api_multiclass_classification_metric,
1783+
],
1784+
zero_one_loss: [
1785+
check_array_api_binary_classification_metric,
1786+
check_array_api_multiclass_classification_metric,
1787+
],
1788+
}
1789+
1790+
1791+
def yield_metric_checker_combinations(metric_checkers=metric_checkers):
1792+
for metric, checkers in metric_checkers.items():
1793+
for checker in checkers:
1794+
yield metric, checker
1795+
1796+
1797+
@pytest.mark.parametrize(
1798+
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
1799+
)
1800+
@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
1801+
def test_array_api_compliance(metric, array_namespace, device, dtype, check_func):
1802+
check_func(metric, array_namespace, device, dtype)

sklearn/utils/_array_api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
468468
sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64)
469469

470470
if sample_weight is not None:
471-
sample_weight = xp.asarray(sample_weight)
471+
sample_weight = xp.asarray(sample_weight, dtype=sample_score.dtype)
472472
if not xp.isdtype(sample_weight.dtype, "real floating"):
473473
sample_weight = xp.astype(sample_weight, xp.float64)
474474

@@ -590,3 +590,8 @@ def _estimator_with_converted_arrays(estimator, converter):
590590
attribute = converter(attribute)
591591
setattr(new_estimator, key, attribute)
592592
return new_estimator
593+
594+
595+
def _atol_for_type(dtype):
596+
"""Return the absolute tolerance for a given dtype."""
597+
return numpy.finfo(dtype).eps * 100

sklearn/utils/_testing.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import atexit
1414
import contextlib
1515
import functools
16+
import importlib
1617
import inspect
1718
import os
1819
import os.path as op
@@ -1047,3 +1048,43 @@ def transform(self, X, y=None):
10471048

10481049
def fit_transform(self, X, y=None):
10491050
return self.fit(X, y).transform(X, y)
1051+
1052+
1053+
def _array_api_for_tests(array_namespace, device, dtype):
1054+
try:
1055+
array_mod = importlib.import_module(array_namespace)
1056+
except ModuleNotFoundError:
1057+
raise SkipTest(
1058+
f"{array_namespace} is not installed: not checking array_api input"
1059+
)
1060+
try:
1061+
import array_api_compat # noqa
1062+
except ImportError:
1063+
raise SkipTest(
1064+
"array_api_compat is not installed: not checking array_api input"
1065+
)
1066+
1067+
# First create an array using the chosen array module and then get the
1068+
# corresponding (compatibility wrapped) array namespace based on it.
1069+
# This is because `cupy` is not the same as the compatibility wrapped
1070+
# namespace of a CuPy array.
1071+
xp = array_api_compat.get_namespace(array_mod.asarray(1))
1072+
if array_namespace == "torch" and device == "cuda" and not xp.has_cuda:
1073+
raise SkipTest("PyTorch test requires cuda, which is not available")
1074+
elif array_namespace == "torch" and device == "mps" and not xp.has_mps:
1075+
if not xp.backends.mps.is_built():
1076+
raise SkipTest(
1077+
"MPS is not available because the current PyTorch install was not "
1078+
"built with MPS enabled."
1079+
)
1080+
else:
1081+
raise SkipTest(
1082+
"MPS is not available because the current MacOS version is not 12.3+ "
1083+
"and/or you do not have an MPS-enabled device on this machine."
1084+
)
1085+
elif array_namespace in {"cupy", "cupy.array_api"}: # pragma: nocover
1086+
import cupy
1087+
1088+
if cupy.cuda.runtime.getDeviceCount() == 0:
1089+
raise SkipTest("CuPy test requires cuda, which is not available")
1090+
return xp, device, dtype

sklearn/utils/estimator_checks.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import importlib
21
import pickle
32
import re
43
import warnings
@@ -67,6 +66,7 @@
6766
)
6867
from ._testing import (
6968
SkipTest,
69+
_array_api_for_tests,
7070
_get_args,
7171
assert_allclose,
7272
assert_allclose_dense_sparse,
@@ -849,46 +849,6 @@ def _generate_sparse_matrix(X_csr):
849849
yield sparse_format + "_64", X
850850

851851

852-
def _array_api_for_tests(array_namespace, device, dtype):
853-
try:
854-
array_mod = importlib.import_module(array_namespace)
855-
except ModuleNotFoundError:
856-
raise SkipTest(
857-
f"{array_namespace} is not installed: not checking array_api input"
858-
)
859-
try:
860-
import array_api_compat # noqa
861-
except ImportError:
862-
raise SkipTest(
863-
"array_api_compat is not installed: not checking array_api input"
864-
)
865-
866-
# First create an array using the chosen array module and then get the
867-
# corresponding (compatibility wrapped) array namespace based on it.
868-
# This is because `cupy` is not the same as the compatibility wrapped
869-
# namespace of a CuPy array.
870-
xp = array_api_compat.get_namespace(array_mod.asarray(1))
871-
if array_namespace == "torch" and device == "cuda" and not xp.has_cuda:
872-
raise SkipTest("PyTorch test requires cuda, which is not available")
873-
elif array_namespace == "torch" and device == "mps" and not xp.has_mps:
874-
if not xp.backends.mps.is_built():
875-
raise SkipTest(
876-
"MPS is not available because the current PyTorch install was not "
877-
"built with MPS enabled."
878-
)
879-
else:
880-
raise SkipTest(
881-
"MPS is not available because the current MacOS version is not 12.3+ "
882-
"and/or you do not have an MPS-enabled device on this machine."
883-
)
884-
elif array_namespace in {"cupy", "cupy.array_api"}: # pragma: nocover
885-
import cupy
886-
887-
if cupy.cuda.runtime.getDeviceCount() == 0:
888-
raise SkipTest("CuPy test requires cuda, which is not available")
889-
return xp, device, dtype
890-
891-
892852
def check_array_api_input(
893853
name,
894854
estimator_orig,

sklearn/utils/tests/test_array_api.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,21 @@
99
from sklearn.utils._array_api import (
1010
_ArrayAPIWrapper,
1111
_asarray_with_order,
12+
_atol_for_type,
1213
_convert_to_numpy,
1314
_estimator_with_converted_arrays,
1415
_nanmax,
1516
_nanmin,
1617
_NumPyAPIWrapper,
18+
_weighted_sum,
1719
get_namespace,
1820
supported_float_dtypes,
21+
yield_namespace_device_dtype_combinations,
22+
)
23+
from sklearn.utils._testing import (
24+
_array_api_for_tests,
25+
skip_if_array_api_compat_not_configured,
1926
)
20-
from sklearn.utils._testing import skip_if_array_api_compat_not_configured
2127

2228
pytestmark = pytest.mark.filterwarnings(
2329
"ignore:The numpy.array_api submodule:UserWarning"
@@ -164,6 +170,37 @@ def test_asarray_with_order_ignored():
164170
assert not X_new_np.flags["F_CONTIGUOUS"]
165171

166172

173+
@pytest.mark.parametrize(
174+
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
175+
)
176+
@pytest.mark.parametrize(
177+
"sample_weight, normalize, expected",
178+
[
179+
(None, False, 10.0),
180+
(None, True, 2.5),
181+
([0.4, 0.4, 0.5, 0.7], False, 5.5),
182+
([0.4, 0.4, 0.5, 0.7], True, 2.75),
183+
([1, 2, 3, 4], False, 30.0),
184+
([1, 2, 3, 4], True, 3.0),
185+
],
186+
)
187+
def test_weighted_sum(
188+
array_namespace, device, dtype, sample_weight, normalize, expected
189+
):
190+
xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)
191+
sample_score = numpy.asarray([1, 2, 3, 4], dtype=dtype)
192+
sample_score = xp.asarray(sample_score, device=device)
193+
if sample_weight is not None:
194+
sample_weight = numpy.asarray(sample_weight, dtype=dtype)
195+
sample_weight = xp.asarray(sample_weight, device=device)
196+
197+
with config_context(array_api_dispatch=True):
198+
result = _weighted_sum(sample_score, sample_weight, normalize)
199+
200+
assert isinstance(result, float)
201+
assert_allclose(result, expected, atol=_atol_for_type(dtype))
202+
203+
167204
@skip_if_array_api_compat_not_configured
168205
@pytest.mark.parametrize(
169206
"library", ["numpy", "numpy.array_api", "cupy", "cupy.array_api", "torch"]

0 commit comments

Comments
 (0)