Skip to content

Commit d171a3c

Browse files
authored
Preemptively fix incompatibilities with an upcoming array-api-strict release (scikit-learn#31517)
1 parent 9f86681 commit d171a3c

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

sklearn/metrics/tests/test_common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,10 +1898,11 @@ def check_array_api_metric(
18981898
np.asarray(a_xp)
18991899
np.asarray(b_xp)
19001900
numpy_as_array_works = True
1901-
except (TypeError, RuntimeError):
1901+
except (TypeError, RuntimeError, ValueError):
19021902
# PyTorch with CUDA device and CuPy raise TypeError consistently.
1903-
# array-api-strict chose to raise RuntimeError instead. Exception type
1904-
# may need to be updated in the future for other libraries.
1903+
# array-api-strict chose to raise RuntimeError instead. NumPy raises
1904+
# a ValueError if the `__array__` dunder does not return an array.
1905+
# Exception type may need to be updated in the future for other libraries.
19051906
numpy_as_array_works = False
19061907

19071908
if numpy_as_array_works:

sklearn/utils/_array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None):
638638
# If weights are 1D, add singleton dimensions for broadcasting
639639
shape = [1] * a.ndim
640640
shape[axis] = a.shape[axis]
641-
weights = xp.reshape(weights, shape)
641+
weights = xp.reshape(weights, tuple(shape))
642642

643643
if xp.isdtype(a.dtype, "complex floating"):
644644
raise NotImplementedError(

sklearn/utils/estimator_checks.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,10 +1130,11 @@ def check_array_api_input(
11301130
# now since array-api-strict seems a bit too strict ...
11311131
numpy_asarray_works = xp.__name__ != "array_api_strict"
11321132

1133-
except (TypeError, RuntimeError):
1133+
except (TypeError, RuntimeError, ValueError):
11341134
# PyTorch with CUDA device and CuPy raise TypeError consistently.
1135-
# array-api-strict chose to raise RuntimeError instead. Exception type
1136-
# may need to be updated in the future for other libraries.
1135+
# array-api-strict chose to raise RuntimeError instead. NumPy emits
1136+
# a ValueError if `__array__` dunder does not return an array.
1137+
# Exception type may need to be updated in the future for other libraries.
11371138
numpy_asarray_works = False
11381139

11391140
if numpy_asarray_works:

0 commit comments

Comments
 (0)