
Commit 368a200

ogrisel and betatim authored
TST enable non-CPU device testing via array-api-strict (scikit-learn#30090)
Co-authored-by: Tim Head <betatim@gmail.com>
1 parent 88283ee commit 368a200

File tree

8 files changed: +96, -47 lines changed

sklearn/decomposition/_pca.py

Lines changed: 2 additions & 4 deletions

@@ -3,14 +3,13 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause

-from math import log, sqrt
+from math import lgamma, log, sqrt
 from numbers import Integral, Real

 import numpy as np
 from scipy import linalg
 from scipy.sparse import issparse
 from scipy.sparse.linalg import svds
-from scipy.special import gammaln

 from ..base import _fit_context
 from ..utils import check_random_state

@@ -71,8 +70,7 @@ def _assess_dimension(spectrum, rank, n_samples):
     pu = -rank * log(2.0)
     for i in range(1, rank + 1):
         pu += (
-            gammaln((n_features - i + 1) / 2.0)
-            - log(xp.pi) * (n_features - i + 1) / 2.0
+            lgamma((n_features - i + 1) / 2.0) - log(xp.pi) * (n_features - i + 1) / 2.0
         )

     pl = xp.sum(xp.log(spectrum[:rank]))
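
The swap from scipy.special.gammaln to math.lgamma works because the argument here is always a Python scalar, and the two functions agree on positive scalar inputs. A minimal sketch of the equivalence, with made-up values:

import math

n_features, i = 10, 3  # hypothetical values for illustration only
term = (
    math.lgamma((n_features - i + 1) / 2.0)
    - math.log(math.pi) * (n_features - i + 1) / 2.0
)
# scipy.special.gammaln((n_features - i + 1) / 2.0) returns the same number for
# positive scalars, but pulls in SciPy and only accepts NumPy-compatible inputs.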

sklearn/discriminant_analysis.py

Lines changed: 1 addition & 1 deletion

@@ -596,7 +596,7 @@ def _solve_svd(self, X, y):
         std = xp.std(Xc, axis=0)
         # avoid division by zero in normalization
         std[std == 0] = 1.0
-        fac = xp.asarray(1.0 / (n_samples - n_classes), dtype=X.dtype)
+        fac = xp.asarray(1.0 / (n_samples - n_classes), dtype=X.dtype, device=device(X))

         # 2) Within variance scaling
         X = xp.sqrt(fac) * (Xc / std)
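
The only change is threading the device of X into the scalar constant. A hedged sketch of why that matters under array-api-strict's fake non-CPU device (it assumes array-api-strict is installed and uses the array's own .device attribute where the diff uses scikit-learn's device() helper):

import array_api_strict as xp

X = xp.asarray([[1.0, 2.0]], device=xp.Device("device1"))
# Without device=, the 0-d constant would be allocated on the default
# CPU_DEVICE and combining it with X would raise a mixed-device error.
fac = xp.asarray(0.5, dtype=X.dtype, device=X.device)
out = xp.sqrt(fac) * X  # stays on device1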

sklearn/metrics/_regression.py

Lines changed: 17 additions & 11 deletions

@@ -14,7 +14,6 @@
 from numbers import Real

 import numpy as np
-from scipy.special import xlogy

 from ..exceptions import UndefinedMetricWarning
 from ..utils._array_api import (

@@ -24,6 +23,9 @@
     get_namespace_and_device,
     size,
 )
+from ..utils._array_api import (
+    _xlogy as xlogy,
+)
 from ..utils._param_validation import Interval, StrOptions, validate_params
 from ..utils.stats import _weighted_percentile
 from ..utils.validation import (

@@ -479,14 +481,16 @@ def mean_absolute_percentage_error(
     >>> mean_absolute_percentage_error(y_true, y_pred)
     112589990684262.48
     """
-    xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)
+    xp, _, device_ = get_namespace_and_device(
+        y_true, y_pred, sample_weight, multioutput
+    )
     _, y_true, y_pred, sample_weight, multioutput = (
         _check_reg_targets_with_floating_dtype(
             y_true, y_pred, sample_weight, multioutput, xp=xp
         )
     )
     check_consistent_length(y_true, y_pred, sample_weight)
-    epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=y_true.dtype)
+    epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=y_true.dtype, device=device_)
     y_true_abs = xp.abs(y_true)
     mape = xp.abs(y_pred - y_true) / xp.maximum(y_true_abs, epsilon)
     output_errors = _average(mape, weights=sample_weight, axis=0)

@@ -1347,16 +1351,18 @@ def max_error(y_true, y_pred):

 def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
     """Mean Tweedie deviance regression loss."""
-    xp, _ = get_namespace(y_true, y_pred)
+    xp, _, device_ = get_namespace_and_device(y_true, y_pred)
     p = power
-    zero = xp.asarray(0, dtype=y_true.dtype)
     if p < 0:
         # 'Extreme stable', y any real number, y_pred > 0
         dev = 2 * (
-            xp.pow(xp.where(y_true > 0, y_true, zero), xp.asarray(2 - p))
+            xp.pow(
+                xp.where(y_true > 0, y_true, 0.0),
+                2 - p,
+            )
             / ((1 - p) * (2 - p))
-            - y_true * xp.pow(y_pred, xp.asarray(1 - p)) / (1 - p)
-            + xp.pow(y_pred, xp.asarray(2 - p)) / (2 - p)
+            - y_true * xp.pow(y_pred, 1 - p) / (1 - p)
+            + xp.pow(y_pred, 2 - p) / (2 - p)
         )
     elif p == 0:
         # Normal distribution, y and y_pred any real number

@@ -1369,9 +1375,9 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
         dev = 2 * (xp.log(y_pred / y_true) + y_true / y_pred - 1)
     else:
         dev = 2 * (
-            xp.pow(y_true, xp.asarray(2 - p)) / ((1 - p) * (2 - p))
-            - y_true * xp.pow(y_pred, xp.asarray(1 - p)) / (1 - p)
-            + xp.pow(y_pred, xp.asarray(2 - p)) / (2 - p)
+            xp.pow(y_true, 2 - p) / ((1 - p) * (2 - p))
+            - y_true * xp.pow(y_pred, 1 - p) / (1 - p)
+            + xp.pow(y_pred, 2 - p) / (2 - p)
         )
     return float(_average(dev, weights=sample_weight))
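
Two patterns recur in this file: scalars that enter array arithmetic are now created with an explicit device=, and exponents are passed to xp.pow as plain Python numbers instead of being wrapped in xp.asarray, which would land on the default device. A hedged sketch, assuming array-api-strict with the 2024.12 API version enabled (which accepts Python scalars in element-wise functions):

import array_api_strict as xp

y_pred = xp.asarray([1.0, 2.0, 4.0], device=xp.Device("device1"))
p = 1.5
# Python-scalar exponent: no extra 0-d array on the wrong device is created.
term = xp.pow(y_pred, 2 - p) / (2 - p)
eps = xp.asarray(xp.finfo(xp.float64).eps, dtype=y_pred.dtype, device=y_pred.device)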

sklearn/metrics/pairwise.py

Lines changed: 9 additions & 12 deletions

@@ -4,6 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause

 import itertools
+import math
 import warnings
 from functools import partial
 from numbers import Integral, Real

@@ -596,12 +597,8 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
     distances = xp.empty((n_samples_X, n_samples_Y), dtype=xp.float32, device=device_)

     if batch_size is None:
-        x_density = (
-            X.nnz / xp.prod(X.shape) if issparse(X) else xp.asarray(1, device=device_)
-        )
-        y_density = (
-            Y.nnz / xp.prod(Y.shape) if issparse(Y) else xp.asarray(1, device=device_)
-        )
+        x_density = X.nnz / np.prod(X.shape) if issparse(X) else 1
+        y_density = Y.nnz / np.prod(Y.shape) if issparse(Y) else 1

         # Allow 10% more memory than X, Y and the distance matrix take (at
         # least 10MiB)

@@ -621,13 +618,13 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
         # Hence x² + (xd+yd)kx = M, where x=batch_size, k=n_features, M=maxmem
         # xd=x_density and yd=y_density
         tmp = (x_density + y_density) * n_features
-        batch_size = (-tmp + xp.sqrt(tmp**2 + 4 * maxmem)) / 2
+        batch_size = (-tmp + math.sqrt(tmp**2 + 4 * maxmem)) / 2
         batch_size = max(int(batch_size), 1)

     x_batches = gen_batches(n_samples_X, batch_size)
     xp_max_float = _max_precision_float_dtype(xp=xp, device=device_)
     for i, x_slice in enumerate(x_batches):
-        X_chunk = xp.astype(X[x_slice], xp_max_float)
+        X_chunk = xp.astype(X[x_slice, :], xp_max_float)
         if XX is None:
             XX_chunk = row_norms(X_chunk, squared=True)[:, None]
         else:

@@ -642,7 +639,7 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
                d = distances[y_slice, x_slice].T

            else:
-                Y_chunk = xp.astype(Y[y_slice], xp_max_float)
+                Y_chunk = xp.astype(Y[y_slice, :], xp_max_float)
                if YY is None:
                    YY_chunk = row_norms(Y_chunk, squared=True)[None, :]
                else:

@@ -1814,7 +1811,7 @@ def additive_chi2_kernel(X, Y=None):
     array([[-1., -2.],
            [-2., -1.]])
     """
-    xp, _ = get_namespace(X, Y)
+    xp, _, device_ = get_namespace_and_device(X, Y)
     X, Y = check_pairwise_arrays(X, Y, accept_sparse=False)
     if xp.any(X < 0):
         raise ValueError("X contains negative values.")

@@ -1831,8 +1828,8 @@ def additive_chi2_kernel(X, Y=None):
     yb = Y[None, :, :]
     nom = -((xb - yb) ** 2)
     denom = xb + yb
-    nom = xp.where(denom == 0, xp.asarray(0, dtype=dtype), nom)
-    denom = xp.where(denom == 0, xp.asarray(1, dtype=dtype), denom)
+    nom = xp.where(denom == 0, xp.asarray(0, dtype=dtype, device=device_), nom)
+    denom = xp.where(denom == 0, xp.asarray(1, dtype=dtype, device=device_), denom)
     return xp.sum(nom / denom, axis=2)
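
The batch-size computation is a host-side decision, so it now runs on plain Python numbers with math.sqrt instead of allocating namespace arrays; only the per-chunk data stays in xp. A small sketch of the arithmetic with made-up sizes:

import math

x_density = y_density = 1   # dense X and Y
n_features = 100
maxmem = 10 * 2**17         # illustrative memory budget
tmp = (x_density + y_density) * n_features
batch_size = max(int((-tmp + math.sqrt(tmp**2 + 4 * maxmem)) / 2), 1)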

sklearn/metrics/tests/test_common.py

Lines changed: 3 additions & 3 deletions

@@ -1838,10 +1838,10 @@ def check_array_api_metric(
         np.asarray(a_xp)
         np.asarray(b_xp)
         numpy_as_array_works = True
-    except TypeError:
+    except (TypeError, RuntimeError):
         # PyTorch with CUDA device and CuPy raise TypeError consistently.
-        # Exception type may need to be updated in the future for other
-        # libraries.
+        # array-api-strict chose to raise RuntimeError instead. Exception type
+        # may need to be updated in the future for other libraries.
         numpy_as_array_works = False

     if numpy_as_array_works:
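
The broadened except clause reflects that converting an off-device array to NumPy fails with a library-specific exception: PyTorch on CUDA and CuPy raise TypeError, while array-api-strict raises RuntimeError. A hedged illustration with array-api-strict's fake non-CPU device:

import numpy as np
import array_api_strict as xp

a_xp = xp.asarray([1.0, 2.0], device=xp.Device("device1"))
try:
    np.asarray(a_xp)  # conversion is refused off the default device
    numpy_as_array_works = True
except (TypeError, RuntimeError):
    numpy_as_array_works = False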

sklearn/preprocessing/_data.py

Lines changed: 6 additions & 0 deletions

@@ -492,6 +492,12 @@ def partial_fit(self, X, y=None):
             ensure_all_finite="allow-nan",
         )

+        device_ = device(X)
+        feature_range = (
+            xp.asarray(feature_range[0], dtype=X.dtype, device=device_),
+            xp.asarray(feature_range[1], dtype=X.dtype, device=device_),
+        )
+
         data_min = _array_api._nanmin(X, axis=0, xp=xp)
         data_max = _array_api._nanmax(X, axis=0, xp=xp)
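
MinMaxScaler.partial_fit now materializes feature_range as two 0-d arrays with X's dtype and device before they enter any arithmetic with the running statistics. A hedged sketch of the pattern (it uses the array's .device attribute where the diff uses scikit-learn's device() helper, and the scale formula is simplified):

import array_api_strict as xp

X = xp.asarray([[0.0, 10.0], [5.0, 20.0]], device=xp.Device("device1"))
feature_range = (
    xp.asarray(0.0, dtype=X.dtype, device=X.device),
    xp.asarray(1.0, dtype=X.dtype, device=X.device),
)
data_range = xp.max(X, axis=0) - xp.min(X, axis=0)
scale = (feature_range[1] - feature_range[0]) / data_range  # all on device1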

sklearn/utils/_array_api.py

Lines changed: 55 additions & 13 deletions

@@ -82,6 +82,19 @@ def yield_namespace_device_dtype_combinations(include_numpy_namespaces=True):
         ):
             yield array_namespace, device, dtype
         yield array_namespace, "mps", "float32"
+
+    elif array_namespace == "array_api_strict":
+        try:
+            import array_api_strict  # noqa
+
+            yield array_namespace, array_api_strict.Device("CPU_DEVICE"), "float64"
+            yield array_namespace, array_api_strict.Device("device1"), "float32"
+        except ImportError:
+            # Those combinations will typically be skipped by pytest if
+            # array_api_strict is not installed but we still need to see them in
+            # the test output.
+            yield array_namespace, "CPU_DEVICE", "float64"
+            yield array_namespace, "device1", "float32"
     else:
         yield array_namespace, None, None

@@ -582,12 +595,14 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
     if namespace.__name__ == "array_api_strict" and hasattr(
         namespace, "set_array_api_strict_flags"
     ):
-        namespace.set_array_api_strict_flags(api_version="2023.12")
+        namespace.set_array_api_strict_flags(api_version="2024.12")

     return namespace, is_array_api_compliant


-def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,)):
+def get_namespace_and_device(
+    *array_list, remove_none=True, remove_types=(str,), xp=None
+):
     """Combination into one single function of `get_namespace` and `device`.

     Parameters

@@ -598,6 +613,10 @@ def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,)):
         Whether to ignore None objects passed in arrays.
     remove_types : tuple or list, default=(str,)
         Types to ignore in the arrays.
+    xp : module, default=None
+        Precomputed array namespace module. When passed, typically from a caller
+        that has already performed inspection of its own inputs, skips array
+        namespace inspection.

     Returns
     -------

@@ -610,16 +629,20 @@ def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,)):
     device : device
         `device` object (see the "Device Support" section of the array API spec).
     """
+    skip_remove_kwargs = dict(remove_none=False, remove_types=[])
+
     array_list = _remove_non_arrays(
         *array_list,
         remove_none=remove_none,
         remove_types=remove_types,
     )
+    arrays_device = device(*array_list, **skip_remove_kwargs)

-    skip_remove_kwargs = dict(remove_none=False, remove_types=[])
+    if xp is None:
+        xp, is_array_api = get_namespace(*array_list, **skip_remove_kwargs)
+    else:
+        xp, is_array_api = xp, True

-    xp, is_array_api = get_namespace(*array_list, **skip_remove_kwargs)
-    arrays_device = device(*array_list, **skip_remove_kwargs)
     if is_array_api:
         return xp, is_array_api, arrays_device
     else:

@@ -769,49 +792,66 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None):
         return sum_ / scale


+def _xlogy(x, y, xp=None):
+    # TODO: Remove this once https://github.com/scipy/scipy/issues/21736 is fixed
+    xp, _, device_ = get_namespace_and_device(x, y, xp=xp)
+
+    with numpy.errstate(divide="ignore", invalid="ignore"):
+        temp = x * xp.log(y)
+    return xp.where(x == 0.0, xp.asarray(0.0, dtype=temp.dtype, device=device_), temp)
+
+
 def _nanmin(X, axis=None, xp=None):
     # TODO: refactor once nan-aware reductions are standardized:
     # https://github.com/data-apis/array-api/issues/621
-    xp, _ = get_namespace(X, xp=xp)
+    xp, _, device_ = get_namespace_and_device(X, xp=xp)
     if _is_numpy_namespace(xp):
         return xp.asarray(numpy.nanmin(X, axis=axis))

     else:
         mask = xp.isnan(X)
-        X = xp.min(xp.where(mask, xp.asarray(+xp.inf, device=device(X)), X), axis=axis)
+        X = xp.min(
+            xp.where(mask, xp.asarray(+xp.inf, dtype=X.dtype, device=device_), X),
+            axis=axis,
+        )
         # Replace Infs from all NaN slices with NaN again
         mask = xp.all(mask, axis=axis)
         if xp.any(mask):
-            X = xp.where(mask, xp.asarray(xp.nan), X)
+            X = xp.where(mask, xp.asarray(xp.nan, dtype=X.dtype, device=device_), X)
         return X


 def _nanmax(X, axis=None, xp=None):
     # TODO: refactor once nan-aware reductions are standardized:
     # https://github.com/data-apis/array-api/issues/621
-    xp, _ = get_namespace(X, xp=xp)
+    xp, _, device_ = get_namespace_and_device(X, xp=xp)
     if _is_numpy_namespace(xp):
         return xp.asarray(numpy.nanmax(X, axis=axis))

     else:
         mask = xp.isnan(X)
-        X = xp.max(xp.where(mask, xp.asarray(-xp.inf, device=device(X)), X), axis=axis)
+        X = xp.max(
+            xp.where(mask, xp.asarray(-xp.inf, dtype=X.dtype, device=device_), X),
+            axis=axis,
+        )
         # Replace Infs from all NaN slices with NaN again
         mask = xp.all(mask, axis=axis)
         if xp.any(mask):
-            X = xp.where(mask, xp.asarray(xp.nan), X)
+            X = xp.where(mask, xp.asarray(xp.nan, dtype=X.dtype, device=device_), X)
         return X


 def _nanmean(X, axis=None, xp=None):
     # TODO: refactor once nan-aware reductions are standardized:
     # https://github.com/data-apis/array-api/issues/621
-    xp, _ = get_namespace(X, xp=xp)
+    xp, _, device_ = get_namespace_and_device(X, xp=xp)
     if _is_numpy_namespace(xp):
         return xp.asarray(numpy.nanmean(X, axis=axis))
     else:
         mask = xp.isnan(X)
-        total = xp.sum(xp.where(mask, xp.asarray(0.0, device=device(X)), X), axis=axis)
+        total = xp.sum(
+            xp.where(mask, xp.asarray(0.0, dtype=X.dtype, device=device_), X), axis=axis
+        )
         count = xp.sum(xp.astype(xp.logical_not(mask), X.dtype), axis=axis)
         return total / count

@@ -868,6 +908,8 @@ def _convert_to_numpy(array, xp):
         return array.cpu().numpy()
     elif xp_name in {"array_api_compat.cupy", "cupy"}:  # pragma: nocover
         return array.get()
+    elif xp_name in {"array_api_strict"}:
+        return numpy.asarray(xp.asarray(array, device=xp.Device("CPU_DEVICE")))

     return numpy.asarray(array)
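
The new _xlogy helper stands in for scipy.special.xlogy on namespaces SciPy cannot handle; its defining property is that the result is 0 wherever x == 0, even though log(y) is -inf or nan there. A small sketch of those semantics (the errstate suppression assumes a NumPy-backed namespace, which array-api-strict is):

import numpy as np
import array_api_strict as xp

x = xp.asarray([0.0, 2.0])
y = xp.asarray([0.0, 3.0])
with np.errstate(divide="ignore", invalid="ignore"):
    temp = x * xp.log(y)  # [nan, 2 * log(3)]
res = xp.where(x == 0.0, xp.asarray(0.0, dtype=temp.dtype, device=x.device), temp)
# res == [0.0, 2 * log(3)], matching scipy.special.xlogy on NumPy inputs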

sklearn/utils/estimator_checks.py

Lines changed: 3 additions & 3 deletions

@@ -1125,10 +1125,10 @@ def check_array_api_input(
         # now since array-api-strict seems a bit too strict ...
         numpy_asarray_works = xp.__name__ != "array_api_strict"

-    except TypeError:
+    except (TypeError, RuntimeError):
         # PyTorch with CUDA device and CuPy raise TypeError consistently.
-        # Exception type may need to be updated in the future for other
-        # libraries.
+        # array-api-strict chose to raise RuntimeError instead. Exception type
+        # may need to be updated in the future for other libraries.
         numpy_asarray_works = False

     if numpy_asarray_works:
