
Commit dab0842

lithomas1 authored (contributors: OmarManzoor, lesteve, ogrisel)
ENH: Make roc_curve array API compatible (scikit-learn#30878)
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 107e009 commit dab0842

File tree: 6 files changed, +171 / -54 lines changed


doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
@@ -165,6 +165,7 @@ Metrics
 - :func:`sklearn.metrics.precision_recall_fscore_support`
 - :func:`sklearn.metrics.r2_score`
 - :func:`sklearn.metrics.recall_score`
+- :func:`sklearn.metrics.roc_curve`
 - :func:`sklearn.metrics.root_mean_squared_error`
 - :func:`sklearn.metrics.root_mean_squared_log_error`
 - :func:`sklearn.metrics.zero_one_loss`
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+- :func:`sklearn.metrics.roc_curve` now supports Array API compatible inputs.
+  By :user:`Thomas Li <lithomas1>`
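
With this change, roc_curve can consume Array API inputs such as PyTorch tensors directly. A minimal usage sketch, assuming PyTorch and scikit-learn's optional array API dependencies (e.g. array-api-compat) are installed:

    # Sketch: calling roc_curve on torch tensors with array API dispatch enabled.
    import torch
    from sklearn import config_context
    from sklearn.metrics import roc_curve

    y_true = torch.tensor([0, 0, 1, 1])
    y_score = torch.tensor([0.1, 0.4, 0.35, 0.8])

    with config_context(array_api_dispatch=True):
        fpr, tpr, thresholds = roc_curve(y_true, y_score)

    # The outputs stay in the input namespace (torch tensors here) and on the
    # same device as y_true / y_score, instead of being converted to NumPy.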

sklearn/metrics/_ranking.py

Lines changed: 39 additions & 15 deletions
@@ -27,9 +27,13 @@
     check_consistent_length,
     column_or_1d,
 )
+from ..utils._array_api import (
+    _max_precision_float_dtype,
+    get_namespace_and_device,
+    size,
+)
 from ..utils._encode import _encode, _unique
 from ..utils._param_validation import Interval, StrOptions, validate_params
-from ..utils.extmath import stable_cumsum
 from ..utils.multiclass import type_of_target
 from ..utils.sparsefuncs import count_nonzero
 from ..utils.validation import _check_pos_label_consistency, _check_sample_weight
@@ -862,6 +866,8 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
     if not (y_type == "binary" or (y_type == "multiclass" and pos_label is not None)):
         raise ValueError("{0} format is not supported".format(y_type))
 
+    xp, _, device = get_namespace_and_device(y_true, y_score, sample_weight)
+
     check_consistent_length(y_true, y_score, sample_weight)
     y_true = column_or_1d(y_true)
     y_score = column_or_1d(y_score)
@@ -883,7 +889,7 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
     y_true = y_true == pos_label
 
     # sort scores and corresponding truth values
-    desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
+    desc_score_indices = xp.argsort(y_score, stable=True, descending=True)
     y_score = y_score[desc_score_indices]
     y_true = y_true[desc_score_indices]
     if sample_weight is not None:
@@ -894,17 +900,27 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
     # y_score typically has many tied values. Here we extract
     # the indices associated with the distinct values. We also
     # concatenate a value for the end of the curve.
-    distinct_value_indices = np.where(np.diff(y_score))[0]
-    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
+    distinct_value_indices = xp.nonzero(xp.diff(y_score))[0]
+    threshold_idxs = xp.concat(
+        [distinct_value_indices, xp.asarray([size(y_true) - 1], device=device)]
+    )
 
     # accumulate the true positives with decreasing threshold
-    tps = stable_cumsum(y_true * weight)[threshold_idxs]
+    max_float_dtype = _max_precision_float_dtype(xp, device)
+    # Perform the weighted cumulative sum using float64 precision when possible
+    # to avoid numerical stability problem with tens of millions of very noisy
+    # predictions:
+    # https://github.com/scikit-learn/scikit-learn/issues/31533#issuecomment-2967062437
+    y_true = xp.astype(y_true, max_float_dtype)
+    tps = xp.cumulative_sum(y_true * weight, dtype=max_float_dtype)[threshold_idxs]
     if sample_weight is not None:
         # express fps as a cumsum to ensure fps is increasing even in
         # the presence of floating point errors
-        fps = stable_cumsum((1 - y_true) * weight)[threshold_idxs]
+        fps = xp.cumulative_sum((1 - y_true) * weight, dtype=max_float_dtype)[
+            threshold_idxs
+        ]
     else:
-        fps = 1 + threshold_idxs - tps
+        fps = 1 + xp.astype(threshold_idxs, max_float_dtype) - tps
     return fps, tps, y_score[threshold_idxs]
 
 
@@ -1160,6 +1176,7 @@ def roc_curve(
     >>> thresholds
     array([ inf, 0.8 , 0.4 , 0.35, 0.1 ])
     """
+    xp, _, device = get_namespace_and_device(y_true, y_score)
     fps, tps, thresholds = _binary_clf_curve(
         y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
     )
@@ -1173,27 +1190,34 @@ def roc_curve(
     # _binary_clf_curve). This keeps all cases where the point should be kept,
     # but does not drop more complicated cases like fps = [1, 3, 7],
     # tps = [1, 2, 4]; there is no harm in keeping too many thresholds.
-    if drop_intermediate and len(fps) > 2:
-        optimal_idxs = np.where(
-            np.r_[True, np.logical_or(np.diff(fps, 2), np.diff(tps, 2)), True]
+    if drop_intermediate and fps.shape[0] > 2:
+        optimal_idxs = xp.where(
+            xp.concat(
+                [
+                    xp.asarray([True], device=device),
+                    xp.logical_or(xp.diff(fps, 2), xp.diff(tps, 2)),
+                    xp.asarray([True], device=device),
+                ]
+            )
         )[0]
         fps = fps[optimal_idxs]
        tps = tps[optimal_idxs]
        thresholds = thresholds[optimal_idxs]
 
     # Add an extra threshold position
     # to make sure that the curve starts at (0, 0)
-    tps = np.r_[0, tps]
-    fps = np.r_[0, fps]
+    tps = xp.concat([xp.asarray([0.0], device=device), tps])
+    fps = xp.concat([xp.asarray([0.0], device=device), fps])
     # get dtype of `y_score` even if it is an array-like
-    thresholds = np.r_[np.inf, thresholds]
+    thresholds = xp.astype(thresholds, _max_precision_float_dtype(xp, device))
+    thresholds = xp.concat([xp.asarray([xp.inf], device=device), thresholds])
 
     if fps[-1] <= 0:
         warnings.warn(
             "No negative samples in y_true, false positive value should be meaningless",
             UndefinedMetricWarning,
         )
-        fpr = np.repeat(np.nan, fps.shape)
+        fpr = xp.full(fps.shape, xp.nan)
     else:
         fpr = fps / fps[-1]
 
@@ -1202,7 +1226,7 @@ def roc_curve(
             "No positive samples in y_true, true positive value should be meaningless",
             UndefinedMetricWarning,
         )
-        tpr = np.repeat(np.nan, tps.shape)
+        tpr = xp.full(tps.shape, xp.nan)
     else:
         tpr = tps / tps[-1]
 
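As a reading aid for the hunks above: the curve construction is a stable descending sort of the scores, the indices where the sorted score changes (one per distinct threshold), and cumulative sums of positives and negatives at those indices. A NumPy-only sketch of the same idea (illustrative names, not the private implementation itself):

    import numpy as np

    def binary_clf_curve_sketch(y_true, y_score):
        # Stable sort by decreasing score, mirroring
        # xp.argsort(y_score, stable=True, descending=True) in the diff.
        order = np.argsort(y_score, kind="stable")[::-1]
        y_true = np.asarray(y_true, dtype=np.float64)[order]
        y_score = np.asarray(y_score)[order]

        # One index per distinct score value, plus the last position,
        # so each threshold contributes a single (fps, tps) point.
        distinct = np.nonzero(np.diff(y_score))[0]
        threshold_idxs = np.concatenate([distinct, [y_score.size - 1]])

        # Unweighted counterparts of the cumulative sums in the diff.
        tps = np.cumsum(y_true)[threshold_idxs]
        fps = 1 + threshold_idxs - tps
        return fps, tps, y_score[threshold_idxs]

    fps, tps, thr = binary_clf_curve_sketch([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
    # fps -> [0. 1. 1. 2.], tps -> [1. 1. 2. 2.], thr -> [0.8 0.4 0.35 0.1]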
sklearn/metrics/tests/test_common.py

Lines changed: 16 additions & 5 deletions
@@ -1928,11 +1928,19 @@ def check_array_api_metric(
     with config_context(array_api_dispatch=True):
         metric_xp = metric(a_xp, b_xp, **metric_kwargs)
 
-    assert_allclose(
-        _convert_to_numpy(xp.asarray(metric_xp), xp),
-        metric_np,
-        atol=_atol_for_type(dtype_name),
-    )
+    def _check_metric_matches(xp_val, np_val):
+        assert_allclose(
+            _convert_to_numpy(xp.asarray(xp_val), xp),
+            np_val,
+            atol=_atol_for_type(dtype_name),
+        )
+
+    # Handle cases where there are multiple return values, e.g. roc_curve:
+    if isinstance(metric_xp, tuple):
+        for metric_xp_val, metric_np_val in zip(metric_xp, metric_np):
+            _check_metric_matches(metric_xp_val, metric_np_val)
+    else:
+        _check_metric_matches(metric_xp, metric_np)
 
 
 def check_array_api_binary_classification_metric(
@@ -2269,6 +2277,9 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
         check_array_api_regression_metric_multioutput,
     ],
     sigmoid_kernel: [check_array_api_metric_pairwise],
+    roc_curve: [
+        check_array_api_binary_classification_metric,
+    ],
 }
 
 
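roc_curve is the first entry in this checker table whose return value is a tuple rather than a single array, which is why the helper above gained the element-wise branch. For reference, a plain NumPy call returns three arrays (the thresholds match the docstring excerpt in the _ranking.py diff; the fpr and tpr values are what this example yields):

    import numpy as np
    from sklearn.metrics import roc_curve

    y_true = np.array([1, 1, 2, 2])
    y_score = np.array([0.1, 0.4, 0.35, 0.8])
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=2)

    # fpr        -> array([0. , 0. , 0.5, 0.5, 1. ])
    # tpr        -> array([0. , 0.5, 0.5, 1. , 1. ])
    # thresholds -> array([ inf, 0.8 , 0.4 , 0.35, 0.1 ])
    # The array API check converts each element back to NumPy and compares it
    # against the corresponding NumPy result with assert_allclose.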
sklearn/utils/tests/test_validation.py

Lines changed: 94 additions & 24 deletions
@@ -35,7 +35,9 @@
     deprecated,
 )
 from sklearn.utils._array_api import (
+    _convert_to_numpy,
     _get_namespace_device_dtype_ids,
+    _is_numpy_namespace,
     yield_namespace_device_dtype_combinations,
 )
 from sklearn.utils._mocking import (
@@ -66,6 +68,7 @@
     _allclose_dense_sparse,
     _check_feature_names_in,
     _check_method_params,
+    _check_pos_label_consistency,
     _check_psd_eigenvalues,
     _check_response_method,
     _check_sample_weight,
@@ -1593,50 +1596,117 @@ def test_check_psd_eigenvalues_invalid(lambdas, err_type, err_msg):
         _check_psd_eigenvalues(lambdas)
 
 
-def test_check_sample_weight():
-    # check array order
-    sample_weight = np.ones(10)[::2]
-    assert not sample_weight.flags["C_CONTIGUOUS"]
-    sample_weight = _check_sample_weight(sample_weight, X=np.ones((5, 1)))
-    assert sample_weight.flags["C_CONTIGUOUS"]
-
+def _check_sample_weight_common(xp):
+    # Common checks between numpy/array api tests
+    # for check_sample_weight
     # check None input
-    sample_weight = _check_sample_weight(None, X=np.ones((5, 2)))
-    assert_allclose(sample_weight, np.ones(5))
+    sample_weight = _check_sample_weight(None, X=xp.ones((5, 2)))
+    assert_allclose(_convert_to_numpy(sample_weight, xp), np.ones(5))
 
     # check numbers input
-    sample_weight = _check_sample_weight(2.0, X=np.ones((5, 2)))
-    assert_allclose(sample_weight, 2 * np.ones(5))
+    sample_weight = _check_sample_weight(2.0, X=xp.ones((5, 2)))
+    assert_allclose(_convert_to_numpy(sample_weight, xp), 2 * np.ones(5))
 
     # check wrong number of dimensions
     with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
-        _check_sample_weight(np.ones((2, 4)), X=np.ones((2, 2)))
+        _check_sample_weight(xp.ones((2, 4)), X=xp.ones((2, 2)))
 
     # check incorrect n_samples
-    msg = r"sample_weight.shape == \(4,\), expected \(2,\)!"
+    msg = re.escape(f"sample_weight.shape == {xp.ones(4).shape}, expected (2,)!")
     with pytest.raises(ValueError, match=msg):
-        _check_sample_weight(np.ones(4), X=np.ones((2, 2)))
+        _check_sample_weight(xp.ones(4), X=xp.ones((2, 2)))
 
     # float32 dtype is preserved
-    X = np.ones((5, 2))
-    sample_weight = np.ones(5, dtype=np.float32)
+    X = xp.ones((5, 2))
+    sample_weight = xp.ones(5, dtype=xp.float32)
     sample_weight = _check_sample_weight(sample_weight, X)
-    assert sample_weight.dtype == np.float32
-
-    # int dtype will be converted to float64 instead
-    X = np.ones((5, 2), dtype=int)
-    sample_weight = _check_sample_weight(None, X, dtype=X.dtype)
-    assert sample_weight.dtype == np.float64
+    assert sample_weight.dtype == xp.float32
 
     # check negative weight when ensure_non_negative=True
-    X = np.ones((5, 2))
-    sample_weight = np.ones(_num_samples(X))
+    X = xp.ones((5, 2))
+    sample_weight = xp.ones(_num_samples(X))
     sample_weight[-1] = -10
     err_msg = "Negative values in data passed to `sample_weight`"
     with pytest.raises(ValueError, match=err_msg):
         _check_sample_weight(sample_weight, X, ensure_non_negative=True)
 
 
+def test_check_sample_weight():
+    # check array order
+    sample_weight = np.ones(10)[::2]
+    assert not sample_weight.flags["C_CONTIGUOUS"]
+    sample_weight = _check_sample_weight(sample_weight, X=np.ones((5, 1)))
+    assert sample_weight.flags["C_CONTIGUOUS"]
+
+    _check_sample_weight_common(np)
+
+    # int dtype will be converted to float64 instead
+    X = np.ones((5, 2), dtype=int)
+    sample_weight = _check_sample_weight(None, X, dtype=X.dtype)
+    assert sample_weight.dtype == np.float64
+
+
+@pytest.mark.parametrize(
+    "array_namespace,device,dtype", yield_namespace_device_dtype_combinations()
+)
+def test_check_sample_weight_array_api(array_namespace, device, dtype):
+    xp = _array_api_for_tests(array_namespace, device)
+    with config_context(array_api_dispatch=True):
+        # check array order
+        sample_weight = xp.ones(10)[::2]
+        if _is_numpy_namespace(xp):
+            assert not sample_weight.flags["C_CONTIGUOUS"]
+        sample_weight = _check_sample_weight(sample_weight, X=xp.ones((5, 1)))
+        if _is_numpy_namespace(xp):
+            assert sample_weight.flags["C_CONTIGUOUS"]
+
+        _check_sample_weight_common(xp)
+
+
+@pytest.mark.parametrize("y_true", [[0], [0, 1], [-1, 1], [1, 1, 1], [-1, -1, -1]])
+def test_check_pos_label_consistency(y_true):
+    assert _check_pos_label_consistency(None, y_true) == 1
+
+
+@pytest.mark.parametrize(
+    "array_namespace,device,dtype",
+    yield_namespace_device_dtype_combinations(),
+    ids=_get_namespace_device_dtype_ids,
+)
+@pytest.mark.parametrize("y_true", [[0], [0, 1], [-1, 1], [1, 1, 1], [-1, -1, -1]])
+def test_check_pos_label_consistency_array_api(array_namespace, device, dtype, y_true):
+    xp = _array_api_for_tests(array_namespace, device)
+    with config_context(array_api_dispatch=True):
+        arr = xp.asarray(y_true, device=device)
+        assert _check_pos_label_consistency(None, arr) == 1
+
+
+@pytest.mark.parametrize("y_true", [[2, 3, 4], [-10], [0, -1]])
+def test_check_pos_label_consistency_invalid(y_true):
+    with pytest.raises(ValueError, match="y_true takes value in"):
+        _check_pos_label_consistency(None, y_true)
+    # Make sure we only raise if pos_label is None
+    assert _check_pos_label_consistency("a", y_true) == "a"
+
+
+@pytest.mark.parametrize(
+    "array_namespace,device,dtype",
+    yield_namespace_device_dtype_combinations(),
+    ids=_get_namespace_device_dtype_ids,
+)
+@pytest.mark.parametrize("y_true", [[2, 3, 4], [-10], [0, -1]])
+def test_check_pos_label_consistency_invalid_array_api(
+    array_namespace, device, dtype, y_true
+):
+    xp = _array_api_for_tests(array_namespace, device)
+    with config_context(array_api_dispatch=True):
+        arr = xp.asarray(y_true, device=device)
+        with pytest.raises(ValueError, match="y_true takes value in"):
+            _check_pos_label_consistency(None, arr)
+        # Make sure we only raise if pos_label is None
+        assert _check_pos_label_consistency("a", arr) == "a"
+
+
 @pytest.mark.parametrize("toarray", [np.array, sp.csr_matrix, sp.csc_matrix])
 def test_allclose_dense_sparse_equals(toarray):
     base = np.arange(9).reshape(3, 3)

sklearn/utils/validation.py

Lines changed: 19 additions & 10 deletions
@@ -20,6 +20,7 @@
 from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning
 from ..utils._array_api import (
     _asarray_with_order,
+    _convert_to_numpy,
     _is_numpy_namespace,
     _max_precision_float_dtype,
     get_namespace,
@@ -2174,7 +2175,9 @@ def _check_sample_weight(
     sample_weight : ndarray of shape (n_samples,)
         Validated sample weight. It is guaranteed to be "C" contiguous.
     """
-    xp, _, device = get_namespace_and_device(sample_weight, X)
+    xp, _, device = get_namespace_and_device(
+        sample_weight, X, remove_types=(int, float)
+    )
 
     n_samples = _num_samples(X)
 
@@ -2186,9 +2189,9 @@ def _check_sample_weight(
         dtype = max_float_type
 
     if sample_weight is None:
-        sample_weight = xp.ones(n_samples, dtype=dtype)
+        sample_weight = xp.ones(n_samples, dtype=dtype, device=device)
     elif isinstance(sample_weight, numbers.Number):
-        sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
+        sample_weight = xp.full(n_samples, sample_weight, dtype=dtype, device=device)
     else:
         if dtype is None:
             dtype = float_dtypes
@@ -2650,14 +2653,20 @@ def _check_pos_label_consistency(pos_label, y_true):
     # when elements in the two arrays are not comparable.
     if pos_label is None:
         # Compute classes only if pos_label is not specified:
-        classes = np.unique(y_true)
-        if classes.dtype.kind in "OUS" or not (
-            np.array_equal(classes, [0, 1])
-            or np.array_equal(classes, [-1, 1])
-            or np.array_equal(classes, [0])
-            or np.array_equal(classes, [-1])
-            or np.array_equal(classes, [1])
+        xp, _, device = get_namespace_and_device(y_true)
+        classes = xp.unique_values(y_true)
+        if (
+            (_is_numpy_namespace(xp) and classes.dtype.kind in "OUS")
+            or classes.shape[0] > 2
+            or not (
+                xp.all(classes == xp.asarray([0, 1], device=device))
+                or xp.all(classes == xp.asarray([-1, 1], device=device))
+                or xp.all(classes == xp.asarray([0], device=device))
+                or xp.all(classes == xp.asarray([-1], device=device))
+                or xp.all(classes == xp.asarray([1], device=device))
+            )
         ):
+            classes = _convert_to_numpy(classes, xp=xp)
             classes_repr = ", ".join([repr(c) for c in classes.tolist()])
             raise ValueError(
                 f"y_true takes value in {{{classes_repr}}} and pos_label is not "