Skip to content

Commit 53ed13c

Browse files
EdAbatiOmarManzoor
andauthored
FIX: accuracy and zero_loss support for multilabel with Array API (scikit-learn#29336)
Co-authored-by: Omar Salman <omar.salman2007@gmail.com> Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
1 parent 191f969 commit 53ed13c

File tree

4 files changed

+72
-7
lines changed

4 files changed

+72
-7
lines changed

doc/whats_new/v1.5.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ Changelog
4848
instead of implicitly converting those inputs as regular NumPy arrays.
4949
:pr:`29119` by :user:`Olivier Grisel`.
5050

51-
- |Fix| Fix a regression in :func:`metrics.zero_one_loss` causing an error
52-
for Array API dispatch with multilabel inputs.
53-
:pr:`29269` by :user:`Yaroslav Korobko <Tialo>`.
51+
- |Fix| Fix a regression in :func:`metrics.accuracy_score` and in :func:`metrics.zero_one_loss`
52+
causing an error for Array API dispatch with multilabel inputs.
53+
:pr:`29269` by :user:`Yaroslav Korobko <Tialo>` and :pr:`29336` by :user:`Edoardo Abati <EdAbati>`.
5454

5555
:mod:`sklearn.model_selection`
5656
..............................

sklearn/metrics/_classification.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
)
2929
from ..utils._array_api import (
3030
_average,
31+
_count_nonzero,
32+
_is_numpy_namespace,
3133
_union1d,
3234
get_namespace,
3335
get_namespace_and_device,
@@ -85,6 +87,7 @@ def _check_targets(y_true, y_pred):
8587
8688
y_pred : array or indicator matrix
8789
"""
90+
xp, _ = get_namespace(y_true, y_pred)
8891
check_consistent_length(y_true, y_pred)
8992
type_true = type_of_target(y_true, input_name="y_true")
9093
type_pred = type_of_target(y_pred, input_name="y_pred")
@@ -130,8 +133,13 @@ def _check_targets(y_true, y_pred):
130133
y_type = "multiclass"
131134

132135
if y_type.startswith("multilabel"):
133-
y_true = csr_matrix(y_true)
134-
y_pred = csr_matrix(y_pred)
136+
if _is_numpy_namespace(xp):
137+
# XXX: do we really want to sparse-encode multilabel indicators when
138+
# they are passed as a dense arrays? This is not possible for array
139+
# API inputs in general hence we only do it for NumPy inputs. But even
140+
# for NumPy the usefulness is questionable.
141+
y_true = csr_matrix(y_true)
142+
y_pred = csr_matrix(y_pred)
135143
y_type = "multilabel-indicator"
136144

137145
return y_type, y_true, y_pred
@@ -211,7 +219,12 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
211219
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
212220
check_consistent_length(y_true, y_pred, sample_weight)
213221
if y_type.startswith("multilabel"):
214-
differing_labels = count_nonzero(y_true - y_pred, axis=1)
222+
if _is_numpy_namespace(xp):
223+
differing_labels = count_nonzero(y_true - y_pred, axis=1)
224+
else:
225+
differing_labels = _count_nonzero(
226+
y_true - y_pred, xp=xp, device=device, axis=1
227+
)
215228
score = xp.asarray(differing_labels == 0, device=device)
216229
else:
217230
score = y_true == y_pred

sklearn/utils/_array_api.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,3 +967,20 @@ def _in1d(ar1, ar2, xp, assume_unique=False, invert=False):
967967
return ret[: ar1.shape[0]]
968968
else:
969969
return xp.take(ret, rev_idx, axis=0)
970+
971+
972+
def _count_nonzero(X, xp, device, axis=None, sample_weight=None):
973+
"""A variant of `sklearn.utils.sparsefuncs.count_nonzero` for the Array API.
974+
975+
It only supports 2D arrays.
976+
"""
977+
assert X.ndim == 2
978+
979+
weights = xp.ones_like(X, device=device)
980+
if sample_weight is not None:
981+
sample_weight = xp.asarray(sample_weight, device=device)
982+
sample_weight = xp.reshape(sample_weight, (sample_weight.shape[0], 1))
983+
weights = xp.astype(weights, sample_weight.dtype) * sample_weight
984+
985+
zero_scalar = xp.asarray(0, device=device, dtype=weights.dtype)
986+
return xp.sum(xp.where(X != 0, weights, zero_scalar), axis=axis)

sklearn/utils/tests/test_array_api.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
_atol_for_type,
1414
_average,
1515
_convert_to_numpy,
16+
_count_nonzero,
1617
_estimator_with_converted_arrays,
1718
_is_numpy_namespace,
1819
_isin,
@@ -32,7 +33,7 @@
3233
assert_array_equal,
3334
skip_if_array_api_compat_not_configured,
3435
)
35-
from sklearn.utils.fixes import _IS_32BIT
36+
from sklearn.utils.fixes import _IS_32BIT, CSR_CONTAINERS
3637

3738

3839
@pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]])
@@ -566,3 +567,37 @@ def test_get_namespace_and_device():
566567
assert namespace is xp_torch
567568
assert is_array_api
568569
assert device == some_torch_tensor.device
570+
571+
572+
@pytest.mark.parametrize(
573+
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
574+
)
575+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
576+
@pytest.mark.parametrize("axis", [0, 1, None, -1, -2])
577+
@pytest.mark.parametrize("sample_weight_type", [None, "int", "float"])
578+
def test_count_nonzero(
579+
array_namespace, device, dtype_name, csr_container, axis, sample_weight_type
580+
):
581+
582+
from sklearn.utils.sparsefuncs import count_nonzero as sparse_count_nonzero
583+
584+
xp = _array_api_for_tests(array_namespace, device)
585+
array = numpy.array([[0, 3, 0], [2, -1, 0], [0, 0, 0], [9, 8, 7], [4, 0, 5]])
586+
if sample_weight_type == "int":
587+
sample_weight = numpy.asarray([1, 2, 2, 3, 1])
588+
elif sample_weight_type == "float":
589+
sample_weight = numpy.asarray([0.5, 1.5, 0.8, 3.2, 2.4], dtype=dtype_name)
590+
else:
591+
sample_weight = None
592+
expected = sparse_count_nonzero(
593+
csr_container(array), axis=axis, sample_weight=sample_weight
594+
)
595+
array_xp = xp.asarray(array, device=device)
596+
597+
with config_context(array_api_dispatch=True):
598+
result = _count_nonzero(
599+
array_xp, xp=xp, device=device, axis=axis, sample_weight=sample_weight
600+
)
601+
602+
assert_allclose(_convert_to_numpy(result, xp=xp), expected)
603+
assert getattr(array_xp, "device", None) == getattr(result, "device", None)

0 commit comments

Comments
 (0)