Skip to content

Commit 5aec488

Browse files
committed
Reinstate __array__ on Python 3.10/3.11
1 parent 51c2924 commit 5aec488

File tree

5 files changed

+106
-41
lines changed

5 files changed

+106
-41
lines changed

.github/workflows/array-api-tests.yml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,16 @@ jobs:
1111
runs-on: ubuntu-latest
1212
strategy:
1313
matrix:
14-
python-version: ['3.12', '3.13']
15-
numpy-version: ['1.26', '2.2', 'dev']
14+
python-version: ['3.10', '3.11', '3.12', '3.13']
15+
numpy-version: ['1.26', '2.3', 'dev']
1616
exclude:
17+
- python-version: '3.10'
18+
numpy-version: '2.3'
19+
- python-version: '3.10'
20+
numpy-version: 'dev'
1721
- python-version: '3.13'
1822
numpy-version: '1.26'
19-
23+
fail-fast: false
2024
steps:
2125
- name: Checkout array-api-strict
2226
uses: actions/checkout@v4

.github/workflows/tests.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@ jobs:
66
strategy:
77
matrix:
88
python-version: ['3.10', '3.11', '3.12', '3.13']
9-
numpy-version: ['1.26', 'dev']
9+
numpy-version: ['1.26', '2.3', 'dev']
1010
exclude:
11+
- python-version: '3.10'
12+
numpy-version: '2.3'
13+
- python-version: '3.10'
14+
numpy-version: 'dev'
1115
- python-version: '3.13'
1216
numpy-version: '1.26'
13-
fail-fast: true
17+
fail-fast: false
1418
steps:
1519
- uses: actions/checkout@v4
1620
- uses: actions/setup-python@v5

array_api_strict/_array_object.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from __future__ import annotations
1717

1818
import operator
19+
import sys
1920
from collections.abc import Iterator
2021
from enum import IntEnum
2122
from types import EllipsisType, ModuleType
@@ -67,8 +68,6 @@ def __hash__(self) -> int:
6768
CPU_DEVICE = Device()
6869
ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"))
6970

70-
_default = object()
71-
7271

7372
class Array:
7473
"""
@@ -149,29 +148,40 @@ def __repr__(self) -> str:
149148

150149
__str__ = __repr__
151150

152-
# `__array__` was implemented historically for compatibility, and removing it has
153-
# caused issues for some libraries (see
154-
# https://github.com/data-apis/array-api-strict/issues/67).
155-
156-
# Instead of `__array__` we now implement the buffer protocol.
157-
# Note that it makes array-apis-strict requiring python>=3.12
158151
def __buffer__(self, flags):
159152
if self._device != CPU_DEVICE:
160-
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
153+
raise RuntimeError(
154+
# NumPy swallows this exception and falls back to __array__.
155+
f"Can't extract host buffer from array on the '{self._device}' device."
156+
)
161157
return self._array.__buffer__(flags)
162158

163-
# We do not define __release_buffer__, per the discussion at
164-
# https://github.com/data-apis/array-api-strict/pull/115#pullrequestreview-2917178729
165-
166-
def __array__(self, *args, **kwds):
167-
# a stub for python < 3.12; otherwise numpy silently produces object arrays
168-
import sys
169-
minor, major = sys.version_info.minor, sys.version_info.major
170-
if major < 3 or minor < 12:
159+
# `__array__` is not part of the Array API. Ideally we want to support
160+
# `xp.asarray(Array)` exclusively through the __buffer__ protocol; however this is
161+
# only possible on Python >=3.12. Additionally, when __buffer__ raises (e.g. because
162+
# the array is not on the CPU device, NumPy will try to fall back on __array__ but,
163+
# if that doesn't exist, create a scalar numpy array of objects which contains the
164+
# array_api_strict.Array. So we can't get rid of __array__ entirely.
165+
def __array__(
166+
self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
167+
) -> npt.NDArray[Any]:
168+
if self._device != CPU_DEVICE:
169+
# We arrive here from np.asarray() on Python >=3.12 when __buffer__ raises.
170+
raise RuntimeError(
171+
f"Can't convert array on the '{self._device}' device to a "
172+
"NumPy array."
173+
)
174+
if sys.version_info >= (3, 12):
171175
raise TypeError(
172-
"Interoperation with NumPy requires python >= 3.12. Please upgrade."
176+
"The __array__ method is not supported by the Array API. "
177+
"Please use the __buffer__ interface instead."
173178
)
174179

180+
# copy keyword is new in 2.0
181+
if np.__version__[0] < '2':
182+
return np.asarray(self._array, dtype=dtype)
183+
return np.asarray(self._array, dtype=dtype, copy=copy)
184+
175185
# These are various helper functions to make the array behavior match the
176186
# spec in places where it either deviates from or is more strict than
177187
# NumPy behavior

array_api_strict/tests/test_array_object.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -559,36 +559,54 @@ def test_array_properties():
559559
assert b.mT.shape == (3, 2)
560560

561561

562-
@pytest.mark.xfail(sys.version_info.major*100 + sys.version_info.minor < 312,
563-
reason="array conversion relies on buffer protocol, and "
564-
"requires python >= 3.12"
565-
)
566562
def test_array_conversion():
567563
# Check that arrays on the CPU device can be converted to NumPy
568564
# but arrays on other devices can't. Note this is testing the logic in
569-
# __array__, which is only used in asarray when converting lists of
570-
# arrays.
565+
# __array__ on Python 3.10~3.11 and in __buffer__ on Python >=3.12.
571566
a = ones((2, 3))
572-
np.asarray(a)
567+
na = np.asarray(a)
568+
assert na.shape == (2, 3)
569+
assert na.dtype == np.float64
570+
na[0, 0] = 10
571+
assert a[0, 0] == 10 # return view when possible
572+
573+
a = arange(5, dtype=uint8)
574+
na = np.asarray(a)
575+
assert na.dtype == np.uint8
573576

574577
for device in ("device1", "device2"):
575578
a = ones((2, 3), device=array_api_strict.Device(device))
576-
with pytest.raises((RuntimeError, ValueError)):
579+
with pytest.raises(RuntimeError, match=device):
577580
np.asarray(a)
578581

579-
# __buffer__ should work for now for conversion to numpy
580-
a = ones((2, 3))
581-
na = np.array(a)
582-
assert na.shape == (2, 3)
583-
assert na.dtype == np.float64
584582

585-
@pytest.mark.skipif(not sys.version_info.major*100 + sys.version_info.minor < 312,
586-
reason="conversion to numpy errors out unless python >= 3.12"
583+
@pytest.mark.skipif(np.__version__ < "2", reason="np.asarray has no copy kwarg")
584+
def test_array_conversion_copy():
585+
a = arange(5)
586+
na = np.asarray(a, copy=False)
587+
na[0] = 10
588+
assert a[0] == 10
589+
590+
a = arange(5)
591+
na = np.asarray(a, copy=True)
592+
na[0] = 10
593+
assert a[0] == 0
594+
595+
a = arange(5)
596+
with pytest.raises(ValueError):
597+
np.asarray(a, dtype=np.uint8, copy=False)
598+
599+
600+
@pytest.mark.skipif(
601+
sys.version_info < (3, 12), reason="Python <3.12 has no __buffer__ interface"
587602
)
588-
def test_array_conversion_2():
603+
def test_no_array_interface():
604+
"""When the __buffer__ interface is available, the __array__ interface is not."""
589605
a = ones((2, 3))
590-
with pytest.raises(TypeError):
591-
np.array(a)
606+
with pytest.raises(TypeError, match="not supported"):
607+
# Because NumPy prefers __buffer__ when available, we can't trigger this
608+
# exception from np.asarray().
609+
a.__array__()
592610

593611

594612
def test_allow_newaxis():

docs/changelog.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,34 @@
11
# Changelog
22

3+
## 2.4.1 (unreleased)
4+
5+
### Major Changes
6+
7+
- The array object defines `__array__` again when `__buffer__` is not available.
8+
9+
- Support for Python versions 3.10 and 3.11 has been reinstated.
10+
11+
12+
### Minor Changes
13+
14+
- Fix bug in `np.asarray(a)`, where `a` is a `array_api_strict.Array` on a device
15+
that is not CPU_DEVICE, which caused NumPy to return a scalar object array wrapping
16+
around the whole Array instead of raising.
17+
18+
- Arithmetic operations no longer accept NumPy arrays.
19+
20+
- Disallow `__setitem__` for invalid dtype combinations (e.g. setting a float value
21+
into an integer array)
22+
23+
24+
### Contributors
25+
26+
The following users contributed to this release:
27+
28+
Evgeni Burovski,
29+
Guido Imperiale
30+
31+
332
## 2.4.0 (2025-06-16)
433

534
### Major Changes

0 commit comments

Comments
 (0)