Skip to content

Commit a922568

Browse files
authored
Fix tests for numpy 2 and array api compat (scikit-learn#29436)
1 parent a6a5397 commit a922568

File tree

5 files changed

+50
-29
lines changed

5 files changed

+50
-29
lines changed

build_tools/azure/pylatest_pip_openblas_pandas_environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,5 @@ dependencies:
2727
- numpydoc
2828
- lightgbm
2929
- scikit-image
30+
- array-api-compat
31+
- array-api-strict

build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Generated by conda-lock.
22
# platform: linux-64
3-
# input_hash: af52e4ce613b7668e1e28daaea07461722275d345395a5eaced4e07a16998179
3+
# input_hash: 11d97b96088b6b1eaf3b774050152e7899f0a6ab757350df2efd44b2de3a5f75
44
@EXPLICIT
55
https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.conda#c3473ff8bdb3d124ed5ff11ec380d6f9
66
https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2024.3.11-h06a4308_0.conda#08529eb3504712baabcbda266a19feb7
@@ -24,6 +24,7 @@ https://repo.anaconda.com/pkgs/main/linux-64/setuptools-69.5.1-py39h06a4308_0.co
2424
https://repo.anaconda.com/pkgs/main/linux-64/wheel-0.43.0-py39h06a4308_0.conda#40bb60408c7433d767fd8c65b35bc4a0
2525
https://repo.anaconda.com/pkgs/main/linux-64/pip-24.0-py39h06a4308_0.conda#7f8ce3af15cfecd12e4dda8c5cef5fb7
2626
# pip alabaster @ https://files.pythonhosted.org/packages/32/34/d4e1c02d3bee589efb5dfa17f88ea08bdb3e3eac12bc475462aec52ed223/alabaster-0.7.16-py3-none-any.whl#sha256=b46733c07dce03ae4e150330b975c75737fa60f0a7c591b6c8bf4928a28e2c92
27+
# pip array-api-compat @ https://files.pythonhosted.org/packages/05/ae/2f11031bb9f819f6efaaa66b720b37928fbb0087161fcbae3465ae374a18/array_api_compat-1.7.1-py3-none-any.whl#sha256=6974f51775972f39edbca39e08f1c2e43c51401c093a0fea5ac7159875095d8a
2728
# pip babel @ https://files.pythonhosted.org/packages/27/45/377f7e32a5c93d94cd56542349b34efab5ca3f9e2fd5a68c5e93169aa32d/Babel-2.15.0-py3-none-any.whl#sha256=08706bdad8d0a3413266ab61bd6c34d0c28d6e1e7badf40a2cebe67644e2e1fb
2829
# pip certifi @ https://files.pythonhosted.org/packages/1c/d5/c84e1a17bf61d4df64ca866a1c9a913874b4e9bdc131ec689a0ad013fb36/certifi-2024.7.4-py3-none-any.whl#sha256=c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90
2930
# pip charset-normalizer @ https://files.pythonhosted.org/packages/98/69/5d8751b4b670d623aa7a47bef061d69c279e9f922f6705147983aa76c3ce/charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796
@@ -63,6 +64,7 @@ https://repo.anaconda.com/pkgs/main/linux-64/pip-24.0-py39h06a4308_0.conda#7f8ce
6364
# pip tzdata @ https://files.pythonhosted.org/packages/65/58/f9c9e6be752e9fcb8b6a0ee9fb87e6e7a1f6bcab2cdc73f02bb7ba91ada0/tzdata-2024.1-py2.py3-none-any.whl#sha256=9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252
6465
# pip urllib3 @ https://files.pythonhosted.org/packages/ca/1c/89ffc63a9605b583d5df2be791a27bc1a42b7c32bab68d3c8f2f73a98cd4/urllib3-2.2.2-py3-none-any.whl#sha256=a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472
6566
# pip zipp @ https://files.pythonhosted.org/packages/20/38/f5c473fe9b90c8debdd29ea68d5add0289f1936d6f923b6b9cc0b931194c/zipp-3.19.2-py3-none-any.whl#sha256=f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c
67+
# pip array-api-strict @ https://files.pythonhosted.org/packages/08/06/aba69bce257fd1cda0d1db616c12728af0f46878a5cc1923fcbb94201947/array_api_strict-2.0.1-py3-none-any.whl#sha256=f74cbf0d0c182fcb45c5ee7f28f9c7b77e6281610dfbbdd63be60b1a5a7872b3
6668
# pip contourpy @ https://files.pythonhosted.org/packages/31/a2/2f12e3a6e45935ff694654b710961b03310b0e1ec997ee9f416d3c873f87/contourpy-1.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445
6769
# pip coverage @ https://files.pythonhosted.org/packages/c4/b4/0cbc18998613f8caaec793ad5878d2450382dfac80e65d352fb7cd9cc1dc/coverage-7.5.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=dbc5958cb471e5a5af41b0ddaea96a37e74ed289535e8deca404811f6cb0bc3d
6870
# pip imageio @ https://files.pythonhosted.org/packages/3d/84/f1647217231f6cc46883e5d26e870cc3e1520d458ecd52d6df750810d53c/imageio-2.34.2-py3-none-any.whl#sha256=a0bb27ec9d5bab36a9f4835e51b21d2cb099e1f78451441f94687ff3404b79f8

build_tools/update_environments_and_lock_files.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,14 @@ def remove_from(alist, to_remove):
225225
"pip_dependencies": (
226226
remove_from(common_dependencies, ["python", "blas", "pip"])
227227
+ docstring_test_dependencies
228+
# Test with some optional dependencies
228229
+ ["lightgbm", "scikit-image"]
230+
# Test array API on CPU without PyTorch
231+
+ ["array-api-compat", "array-api-strict"]
229232
),
230233
"package_constraints": {
234+
# XXX: we would like to use the latest version of Python but this makes
235+
# the CI much slower. We need to investigate why.
231236
"python": "3.9",
232237
},
233238
},

sklearn/utils/_array_api.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -672,16 +672,10 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None):
672672
f"weights {tuple(weights.shape)} differ."
673673
)
674674

675-
if weights.ndim != 1:
676-
raise TypeError(
677-
f"1D weights expected when a.shape={tuple(a.shape)} and "
678-
f"weights.shape={tuple(weights.shape)} differ."
679-
)
680-
681-
if size(weights) != a.shape[axis]:
675+
if tuple(weights.shape) != (a.shape[axis],):
682676
raise ValueError(
683-
f"Length of weights {size(weights)} not compatible with "
684-
f" a.shape={tuple(a.shape)} and {axis=}."
677+
f"Shape of weights weights.shape={tuple(weights.shape)} must be "
678+
f"consistent with a.shape={tuple(a.shape)} and {axis=}."
685679
)
686680

687681
# If weights are 1D, add singleton dimensions for broadcasting
@@ -839,9 +833,14 @@ def _estimator_with_converted_arrays(estimator, converter):
839833
return new_estimator
840834

841835

842-
def _atol_for_type(dtype):
836+
def _atol_for_type(dtype_or_dtype_name):
843837
"""Return the absolute tolerance for a given numpy dtype."""
844-
return numpy.finfo(dtype).eps * 100
838+
if dtype_or_dtype_name is None:
839+
# If no dtype is specified when running tests for a given namespace, we
840+
# expect the same floating precision level as NumPy's default floating
841+
# point dtype.
842+
dtype_or_dtype_name = numpy.float64
843+
return numpy.finfo(dtype_or_dtype_name).eps * 100
845844

846845

847846
def indexing_dtype(xp):

sklearn/utils/tests/test_array_api.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
assert_array_equal,
3535
skip_if_array_api_compat_not_configured,
3636
)
37-
from sklearn.utils.fixes import _IS_32BIT, CSR_CONTAINERS
37+
from sklearn.utils.fixes import _IS_32BIT, CSR_CONTAINERS, np_version, parse_version
3838

3939

4040
@pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]])
@@ -67,7 +67,12 @@ def test_get_namespace_ndarray_with_dispatch():
6767
with config_context(array_api_dispatch=True):
6868
xp_out, is_array_api_compliant = get_namespace(X_np)
6969
assert is_array_api_compliant
70-
assert xp_out is array_api_compat.numpy
70+
if np_version >= parse_version("2.0.0"):
71+
# NumPy 2.0+ is an array API compliant library.
72+
assert xp_out is numpy
73+
else:
74+
# Older NumPy versions require the compatibility layer.
75+
assert xp_out is array_api_compat.numpy
7176

7277

7378
@skip_if_array_api_compat_not_configured
@@ -135,7 +140,7 @@ def test_asarray_with_order_ignored():
135140

136141

137142
@pytest.mark.parametrize(
138-
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
143+
"array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
139144
)
140145
@pytest.mark.parametrize(
141146
"weights, axis, normalize, expected",
@@ -167,19 +172,22 @@ def test_asarray_with_order_ignored():
167172
],
168173
)
169174
def test_average(
170-
array_namespace, device, dtype_name, weights, axis, normalize, expected
175+
array_namespace, device_, dtype_name, weights, axis, normalize, expected
171176
):
172-
xp = _array_api_for_tests(array_namespace, device)
177+
xp = _array_api_for_tests(array_namespace, device_)
173178
array_in = numpy.asarray([[1, 2, 3], [4, 5, 6]], dtype=dtype_name)
174-
array_in = xp.asarray(array_in, device=device)
179+
array_in = xp.asarray(array_in, device=device_)
175180
if weights is not None:
176181
weights = numpy.asarray(weights, dtype=dtype_name)
177-
weights = xp.asarray(weights, device=device)
182+
weights = xp.asarray(weights, device=device_)
178183

179184
with config_context(array_api_dispatch=True):
180185
result = _average(array_in, axis=axis, weights=weights, normalize=normalize)
181186

182-
assert getattr(array_in, "device", None) == getattr(result, "device", None)
187+
if np_version < parse_version("2.0.0") or np_version >= parse_version("2.1.0"):
188+
# NumPy 2.0 has a problem with the device attribute of scalar arrays:
189+
# https://github.com/numpy/numpy/issues/26850
190+
assert device(array_in) == device(result)
183191

184192
result = _convert_to_numpy(result, xp)
185193
assert_allclose(result, expected, atol=_atol_for_type(dtype_name))
@@ -226,14 +234,15 @@ def test_average_raises_with_wrong_dtype(array_namespace, device, dtype_name):
226234
(
227235
0,
228236
[[1, 2]],
229-
TypeError,
230-
"1D weights expected",
237+
# NumPy 2 raises ValueError, NumPy 1 raises TypeError
238+
(ValueError, TypeError),
239+
"weights", # the message is different for NumPy 1 and 2...
231240
),
232241
(
233242
0,
234243
[1, 2, 3, 4],
235244
ValueError,
236-
"Length of weights",
245+
"weights",
237246
),
238247
(0, [-1, 1], ZeroDivisionError, "Weights sum to zero, can't be normalized"),
239248
),
@@ -580,18 +589,18 @@ def test_get_namespace_and_device():
580589

581590

582591
@pytest.mark.parametrize(
583-
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
592+
"array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
584593
)
585594
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
586595
@pytest.mark.parametrize("axis", [0, 1, None, -1, -2])
587596
@pytest.mark.parametrize("sample_weight_type", [None, "int", "float"])
588597
def test_count_nonzero(
589-
array_namespace, device, dtype_name, csr_container, axis, sample_weight_type
598+
array_namespace, device_, dtype_name, csr_container, axis, sample_weight_type
590599
):
591600

592601
from sklearn.utils.sparsefuncs import count_nonzero as sparse_count_nonzero
593602

594-
xp = _array_api_for_tests(array_namespace, device)
603+
xp = _array_api_for_tests(array_namespace, device_)
595604
array = numpy.array([[0, 3, 0], [2, -1, 0], [0, 0, 0], [9, 8, 7], [4, 0, 5]])
596605
if sample_weight_type == "int":
597606
sample_weight = numpy.asarray([1, 2, 2, 3, 1])
@@ -602,12 +611,16 @@ def test_count_nonzero(
602611
expected = sparse_count_nonzero(
603612
csr_container(array), axis=axis, sample_weight=sample_weight
604613
)
605-
array_xp = xp.asarray(array, device=device)
614+
array_xp = xp.asarray(array, device=device_)
606615

607616
with config_context(array_api_dispatch=True):
608617
result = _count_nonzero(
609-
array_xp, xp=xp, device=device, axis=axis, sample_weight=sample_weight
618+
array_xp, xp=xp, device=device_, axis=axis, sample_weight=sample_weight
610619
)
611620

612621
assert_allclose(_convert_to_numpy(result, xp=xp), expected)
613-
assert getattr(array_xp, "device", None) == getattr(result, "device", None)
622+
623+
if np_version < parse_version("2.0.0") or np_version >= parse_version("2.1.0"):
624+
# NumPy 2.0 has a problem with the device attribute of scalar arrays:
625+
# https://github.com/numpy/numpy/issues/26850
626+
assert device(array_xp) == device(result)

0 commit comments

Comments (0)