Skip to content

Commit eb29207

Browse files
PERF speedup classification_report by attaching unique values to dtype.metadata (scikit-learn#29738)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
1 parent 7d0bec5 commit eb29207

File tree

6 files changed

+194
-13
lines changed

6 files changed

+194
-13
lines changed

doc/whats_new/v1.6.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,10 @@ Changelog
277277
:pr:`29210` by :user:`Marc Torrellas Socastro <marctorsoc>` and
278278
:user:`Stefanie Senger <StefanieSenger>`.
279279

280+
- |Efficiency| :func:`sklearn.metrics.classification_report` is now faster by caching
281+
classification labels.
282+
:pr:`29738` by `Adrin Jalali`_.
283+
280284
- |API| scoring="neg_max_error" should be used instead of
281285
scoring="max_error" which is now deprecated.
282286
:pr:`29462` by :user:`Farid "Freddie" Taba <artificialfintelligence>`.

sklearn/metrics/_classification.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
StrOptions,
4242
validate_params,
4343
)
44+
from ..utils._unique import attach_unique
4445
from ..utils.extmath import _nanaverage
4546
from ..utils.multiclass import type_of_target, unique_labels
4647
from ..utils.sparsefuncs import count_nonzero
@@ -216,6 +217,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
216217
"""
217218
xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)
218219
# Compute accuracy for each possible representation
220+
y_true, y_pred = attach_unique(y_true, y_pred)
219221
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
220222
check_consistent_length(y_true, y_pred, sample_weight)
221223
if y_type.startswith("multilabel"):
@@ -327,6 +329,7 @@ def confusion_matrix(
327329
>>> (tn, fp, fn, tp)
328330
(np.int64(0), np.int64(2), np.int64(1), np.int64(1))
329331
"""
332+
y_true, y_pred = attach_unique(y_true, y_pred)
330333
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
331334
if y_type not in ("binary", "multiclass"):
332335
raise ValueError("%s is not supported" % y_type)
@@ -516,6 +519,7 @@ def multilabel_confusion_matrix(
516519
[[2, 1],
517520
[1, 2]]])
518521
"""
522+
y_true, y_pred = attach_unique(y_true, y_pred)
519523
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
520524
if sample_weight is not None:
521525
sample_weight = column_or_1d(sample_weight)
@@ -1054,6 +1058,7 @@ def matthews_corrcoef(y_true, y_pred, *, sample_weight=None):
10541058
>>> matthews_corrcoef(y_true, y_pred)
10551059
np.float64(-0.33...)
10561060
"""
1061+
y_true, y_pred = attach_unique(y_true, y_pred)
10571062
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
10581063
check_consistent_length(y_true, y_pred, sample_weight)
10591064
if y_type not in {"binary", "multiclass"}:
@@ -1612,6 +1617,7 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
16121617
if average not in average_options and average != "binary":
16131618
raise ValueError("average has to be one of " + str(average_options))
16141619

1620+
y_true, y_pred = attach_unique(y_true, y_pred)
16151621
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
16161622
# Convert to Python primitive type to avoid NumPy type / Python str
16171623
# comparison. See https://github.com/numpy/numpy/issues/6784
@@ -2031,7 +2037,7 @@ class after being classified as negative. This is the case when the
20312037
>>> class_likelihood_ratios(y_true, y_pred, labels=["non-cat", "cat"])
20322038
(np.float64(1.5), np.float64(0.75))
20332039
"""
2034-
2040+
y_true, y_pred = attach_unique(y_true, y_pred)
20352041
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
20362042
if y_type != "binary":
20372043
raise ValueError(
@@ -2681,6 +2687,7 @@ class 2 1.00 0.67 0.80 3
26812687
<BLANKLINE>
26822688
"""
26832689

2690+
y_true, y_pred = attach_unique(y_true, y_pred)
26842691
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
26852692

26862693
if labels is None:
@@ -2869,7 +2876,7 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None):
28692876
>>> hamming_loss(np.array([[0, 1], [1, 1]]), np.zeros((2, 2)))
28702877
0.75
28712878
"""
2872-
2879+
y_true, y_pred = attach_unique(y_true, y_pred)
28732880
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
28742881
check_consistent_length(y_true, y_pred, sample_weight)
28752882

sklearn/utils/_array_api.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,11 @@ def _is_numpy_namespace(xp):
208208

209209
def _union1d(a, b, xp):
210210
if _is_numpy_namespace(xp):
211-
return xp.asarray(numpy.union1d(a, b))
211+
# avoid circular import
212+
from ._unique import cached_unique
213+
214+
a_unique, b_unique = cached_unique(a, b, xp=xp)
215+
return xp.asarray(numpy.union1d(a_unique, b_unique))
212216
assert a.ndim == b.ndim == 1
213217
return xp.unique_values(xp.concat([xp.unique_values(a), xp.unique_values(b)]))
214218

sklearn/utils/_unique.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Authors: The scikit-learn developers
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
import numpy as np
5+
6+
from sklearn.utils._array_api import get_namespace
7+
8+
9+
def _attach_unique(y):
10+
"""Attach unique values of y to y and return the result.
11+
12+
The result is a view of y, and the metadata (unique) is not attached to y.
13+
"""
14+
if not isinstance(y, np.ndarray):
15+
return y
16+
try:
17+
# avoid recalculating unique in nested calls.
18+
if "unique" in y.dtype.metadata:
19+
return y
20+
except (AttributeError, TypeError):
21+
pass
22+
23+
unique = np.unique(y)
24+
unique_dtype = np.dtype(y.dtype, metadata={"unique": unique})
25+
return y.view(dtype=unique_dtype)
26+
27+
28+
def attach_unique(*ys, return_tuple=False):
29+
"""Attach unique values of ys to ys and return the results.
30+
31+
The result is a view of y, and the metadata (unique) is not attached to y.
32+
33+
IMPORTANT: The output of this function should NEVER be returned in functions.
34+
This is to avoid this pattern:
35+
36+
.. code:: python
37+
38+
y = np.array([1, 2, 3])
39+
y = attach_unique(y)
40+
y[1] = -1
41+
# now np.unique(y) will be different from cached_unique(y)
42+
43+
Parameters
44+
----------
45+
*ys : sequence of array-like
46+
Input data arrays.
47+
48+
return_tuple : bool, default=False
49+
If True, always return a tuple even if there is only one array.
50+
51+
Returns
52+
-------
53+
ys : tuple of array-like or array-like
54+
Input data with unique values attached.
55+
"""
56+
res = tuple(_attach_unique(y) for y in ys)
57+
if len(res) == 1 and not return_tuple:
58+
return res[0]
59+
return res
60+
61+
62+
def _cached_unique(y, xp=None):
63+
"""Return the unique values of y.
64+
65+
Use the cached values from dtype.metadata if present.
66+
67+
This function does NOT cache the values in y, i.e. it doesn't change y.
68+
69+
Call `attach_unique` to attach the unique values to y.
70+
"""
71+
try:
72+
if y.dtype.metadata is not None and "unique" in y.dtype.metadata:
73+
return y.dtype.metadata["unique"]
74+
except AttributeError:
75+
# in case y is not a numpy array
76+
pass
77+
xp, _ = get_namespace(y, xp=xp)
78+
return xp.unique_values(y)
79+
80+
81+
def cached_unique(*ys, xp=None):
82+
"""Return the unique values of ys.
83+
84+
Use the cached values from dtype.metadata if present.
85+
86+
This function does NOT cache the values in y, i.e. it doesn't change y.
87+
88+
Call `attach_unique` to attach the unique values to y.
89+
90+
Parameters
91+
----------
92+
*ys : sequence of array-like
93+
Input data arrays.
94+
95+
xp : module, default=None
96+
Precomputed array namespace module. When passed, typically from a caller
97+
that has already performed inspection of its own inputs, skips array
98+
namespace inspection.
99+
100+
Returns
101+
-------
102+
res : tuple of array-like or array-like
103+
Unique values of ys.
104+
"""
105+
res = tuple(_cached_unique(y, xp=xp) for y in ys)
106+
if len(res) == 1:
107+
return res[0]
108+
return res

sklearn/utils/multiclass.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,20 @@
1212

1313
from ..utils._array_api import get_namespace
1414
from ..utils.fixes import VisibleDeprecationWarning
15+
from ._unique import attach_unique, cached_unique
1516
from .validation import _assert_all_finite, check_array
1617

1718

18-
def _unique_multiclass(y):
19-
xp, is_array_api_compliant = get_namespace(y)
19+
def _unique_multiclass(y, xp=None):
20+
xp, is_array_api_compliant = get_namespace(y, xp=xp)
2021
if hasattr(y, "__array__") or is_array_api_compliant:
21-
return xp.unique_values(xp.asarray(y))
22+
return cached_unique(xp.asarray(y), xp=xp)
2223
else:
2324
return set(y)
2425

2526

26-
def _unique_indicator(y):
27-
xp, _ = get_namespace(y)
27+
def _unique_indicator(y, xp=None):
28+
xp, _ = get_namespace(y, xp=xp)
2829
return xp.arange(
2930
check_array(y, input_name="y", accept_sparse=["csr", "csc", "coo"]).shape[1]
3031
)
@@ -69,8 +70,9 @@ def unique_labels(*ys):
6970
>>> unique_labels([1, 2, 10], [5, 11])
7071
array([ 1, 2, 5, 10, 11])
7172
"""
73+
ys = attach_unique(*ys, return_tuple=True)
7274
xp, is_array_api_compliant = get_namespace(*ys)
73-
if not ys:
75+
if len(ys) == 0:
7476
raise ValueError("No argument has been passed.")
7577
# Check that we don't mix label format
7678

@@ -104,10 +106,12 @@ def unique_labels(*ys):
104106

105107
if is_array_api_compliant:
106108
# array_api does not allow for mixed dtypes
107-
unique_ys = xp.concat([_unique_labels(y) for y in ys])
109+
unique_ys = xp.concat([_unique_labels(y, xp=xp) for y in ys])
108110
return xp.unique_values(unique_ys)
109111

110-
ys_labels = set(chain.from_iterable((i for i in _unique_labels(y)) for y in ys))
112+
ys_labels = set(
113+
chain.from_iterable((i for i in _unique_labels(y, xp=xp)) for y in ys)
114+
)
111115
# Check that we don't mix string type with number type
112116
if len(set(isinstance(label, str) for label in ys_labels)) > 1:
113117
raise ValueError("Mix of label input types (string and number)")
@@ -187,7 +191,7 @@ def is_multilabel(y):
187191
and (y.dtype.kind in "biu" or _is_integral_float(labels)) # bool, int, uint
188192
)
189193
else:
190-
labels = xp.unique_values(y)
194+
labels = cached_unique(y, xp=xp)
191195

192196
return labels.shape[0] < 3 and (
193197
xp.isdtype(y.dtype, ("bool", "signed integer", "unsigned integer"))
@@ -400,7 +404,7 @@ def type_of_target(y, input_name=""):
400404
# Check multiclass
401405
if issparse(first_row_or_val):
402406
first_row_or_val = first_row_or_val.data
403-
if xp.unique_values(y).shape[0] > 2 or (y.ndim == 2 and len(first_row_or_val) > 1):
407+
if cached_unique(y).shape[0] > 2 or (y.ndim == 2 and len(first_row_or_val) > 1):
404408
# [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
405409
return "multiclass" + suffix
406410
else:

sklearn/utils/tests/test_unique.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import numpy as np
2+
from numpy.testing import assert_array_equal
3+
4+
from sklearn.utils._unique import attach_unique, cached_unique
5+
from sklearn.utils.validation import check_array
6+
7+
8+
def test_attach_unique_attaches_unique_to_array():
9+
arr = np.array([1, 2, 2, 3, 4, 4, 5])
10+
arr_ = attach_unique(arr)
11+
assert_array_equal(arr_.dtype.metadata["unique"], np.array([1, 2, 3, 4, 5]))
12+
assert_array_equal(arr_, arr)
13+
14+
15+
def test_cached_unique_returns_cached_unique():
16+
my_dtype = np.dtype(np.float64, metadata={"unique": np.array([1, 2])})
17+
arr = np.array([1, 2, 2, 3, 4, 4, 5], dtype=my_dtype)
18+
assert_array_equal(cached_unique(arr), np.array([1, 2]))
19+
20+
21+
def test_attach_unique_not_ndarray():
22+
"""Test that when not np.ndarray, we don't touch the array."""
23+
arr = [1, 2, 2, 3, 4, 4, 5]
24+
arr_ = attach_unique(arr)
25+
assert arr_ is arr
26+
27+
28+
def test_attach_unique_returns_view():
29+
"""Test that attach_unique returns a view of the array."""
30+
arr = np.array([1, 2, 2, 3, 4, 4, 5])
31+
arr_ = attach_unique(arr)
32+
assert arr_.base is arr
33+
34+
35+
def test_attach_unique_return_tuple():
36+
"""Test return_tuple argument of the function."""
37+
arr = np.array([1, 2, 2, 3, 4, 4, 5])
38+
arr_tuple = attach_unique(arr, return_tuple=True)
39+
assert isinstance(arr_tuple, tuple)
40+
assert len(arr_tuple) == 1
41+
assert_array_equal(arr_tuple[0], arr)
42+
43+
arr_single = attach_unique(arr, return_tuple=False)
44+
assert isinstance(arr_single, np.ndarray)
45+
assert_array_equal(arr_single, arr)
46+
47+
48+
def test_check_array_keeps_unique():
49+
"""Test that check_array keeps the unique metadata."""
50+
arr = np.array([[1, 2, 2, 3, 4, 4, 5]])
51+
arr_ = attach_unique(arr)
52+
arr_ = check_array(arr_)
53+
assert_array_equal(arr_.dtype.metadata["unique"], np.array([1, 2, 3, 4, 5]))
54+
assert_array_equal(arr_, arr)

0 commit comments

Comments
 (0)