Skip to content

Commit 606fdb3

Browse files
authored
API: Allow newaxis indexing for array_api arrays (#21377)
* TST: Add test checking if newaxis indexing works for `array_api`
  Also removes previous check against newaxis indexing, which is now outdated.
* TST, BUG: Allow `None` in `array_api` indexing
  Introduces test for validating flat indexing when `None` is present.
* MAINT,DOC,TST: Rework of `_validate_index()` in `numpy.array_api`
  `_validate_index()` is now called as `self._validate_index(key)`, and does not return a key. This rework removes the recursive pattern used. Tests are introduced to cover some edge cases. Additionally, its internal docstring reflects new behaviour, and extends the flat indexing note.
* MAINT: `advance` -> `advanced` (integer indexing)
  Co-authored-by: Aaron Meurer <asmeurer@gmail.com>
* BUG: array_api arrays use internal arrays from array_api array keys
  When an array_api array is passed as the key for get/setitem, we access the key's internal np.ndarray array to be used as the key for the internal get/setitem operation. This behaviour was initially removed when `_validate_index()` was reworked.
* MAINT: Better flat indexing error message for `array_api` arrays
  Also better semantics for its prior ellipsis count condition.
  Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
* MAINT: `array_api` arrays don't special case multi-ellipsis errors
  This gets handled by NumPy-proper.

Co-authored-by: Aaron Meurer <asmeurer@gmail.com>
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
Original NumPy Commit: befef7b26773eddd2b656a3ab87f504e6cc173db
1 parent d4fcd8c commit 606fdb3

File tree

2 files changed

+176
-99
lines changed

2 files changed

+176
-99
lines changed

array_api_strict/_array_object.py

Lines changed: 119 additions & 93 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
_dtype_categories,
3030
)
3131

32-
from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
32+
from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex
3333
import types
3434

3535
if TYPE_CHECKING:
@@ -243,8 +243,7 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
243243

244244
# Note: A large fraction of allowed indices are disallowed here (see the
245245
# docstring below)
246-
@staticmethod
247-
def _validate_index(key, shape):
246+
def _validate_index(self, key):
248247
"""
249248
Validate an index according to the array API.
250249
@@ -257,8 +256,7 @@ def _validate_index(key, shape):
257256
https://data-apis.org/array-api/latest/API_specification/indexing.html
258257
for the full list of required indexing behavior
259258
260-
This function either raises IndexError if the index ``key`` is
261-
invalid, or a new key to be used in place of ``key`` in indexing. It
259+
This function raises IndexError if the index ``key`` is invalid. It
262260
only raises ``IndexError`` on indices that are not already rejected by
263261
NumPy, as NumPy will already raise the appropriate error on such
264262
indices. ``shape`` may be None, in which case, only cases that are
@@ -269,7 +267,7 @@ def _validate_index(key, shape):
269267
270268
- Indices do not include an implicit ellipsis at the end. That is,
271269
every axis of an array must be explicitly indexed or an ellipsis
272-
included.
270+
included. This behaviour is sometimes referred to as flat indexing.
273271
274272
- The start and stop of a slice may not be out of bounds. In
275273
particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
@@ -292,100 +290,122 @@ def _validate_index(key, shape):
292290
``Array._new`` constructor, not this function.
293291
294292
"""
295-
if isinstance(key, slice):
296-
if shape is None:
297-
return key
298-
if shape == ():
299-
return key
300-
if len(shape) > 1:
293+
_key = key if isinstance(key, tuple) else (key,)
294+
for i in _key:
295+
if isinstance(i, bool) or not (
296+
isinstance(i, SupportsIndex) # i.e. ints
297+
or isinstance(i, slice)
298+
or i == Ellipsis
299+
or i is None
300+
or isinstance(i, Array)
301+
or isinstance(i, np.ndarray)
302+
):
301303
raise IndexError(
302-
"Multidimensional arrays must include an index for every axis or use an ellipsis"
304+
f"Single-axes index {i} has {type(i)=}, but only "
305+
"integers, slices (:), ellipsis (...), newaxis (None), "
306+
"zero-dimensional integer arrays and boolean arrays "
307+
"are specified in the Array API."
303308
)
304-
size = shape[0]
305-
# Ensure invalid slice entries are passed through.
306-
if key.start is not None:
307-
try:
308-
operator.index(key.start)
309-
except TypeError:
310-
return key
311-
if not (-size <= key.start <= size):
312-
raise IndexError(
313-
"Slices with out-of-bounds start are not allowed in the array API namespace"
314-
)
315-
if key.stop is not None:
316-
try:
317-
operator.index(key.stop)
318-
except TypeError:
319-
return key
320-
step = 1 if key.step is None else key.step
321-
if (step > 0 and not (-size <= key.stop <= size)
322-
or step < 0 and not (-size - 1 <= key.stop <= max(0, size - 1))):
323-
raise IndexError("Slices with out-of-bounds stop are not allowed in the array API namespace")
324-
return key
325-
326-
elif isinstance(key, tuple):
327-
key = tuple(Array._validate_index(idx, None) for idx in key)
328-
329-
for idx in key:
330-
if (
331-
isinstance(idx, np.ndarray)
332-
and idx.dtype in _boolean_dtypes
333-
or isinstance(idx, (bool, np.bool_))
334-
):
335-
if len(key) == 1:
336-
return key
337-
raise IndexError(
338-
"Boolean array indices combined with other indices are not allowed in the array API namespace"
339-
)
340-
if isinstance(idx, tuple):
341-
raise IndexError(
342-
"Nested tuple indices are not allowed in the array API namespace"
343-
)
344-
345-
if shape is None:
346-
return key
347-
n_ellipsis = key.count(...)
348-
if n_ellipsis > 1:
349-
return key
350-
ellipsis_i = key.index(...) if n_ellipsis else len(key)
351309

352-
for idx, size in list(zip(key[:ellipsis_i], shape)) + list(
353-
zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])
354-
):
355-
Array._validate_index(idx, (size,))
356-
if n_ellipsis == 0 and len(key) < len(shape):
310+
nonexpanding_key = []
311+
single_axes = []
312+
n_ellipsis = 0
313+
key_has_mask = False
314+
for i in _key:
315+
if i is not None:
316+
nonexpanding_key.append(i)
317+
if isinstance(i, Array) or isinstance(i, np.ndarray):
318+
if i.dtype in _boolean_dtypes:
319+
key_has_mask = True
320+
single_axes.append(i)
321+
else:
322+
# i must not be an array here, to avoid elementwise equals
323+
if i == Ellipsis:
324+
n_ellipsis += 1
325+
else:
326+
single_axes.append(i)
327+
328+
n_single_axes = len(single_axes)
329+
if n_ellipsis > 1:
330+
return # handled by ndarray
331+
elif n_ellipsis == 0:
332+
# Note boolean masks must be the sole index, which we check for
333+
# later on.
334+
if not key_has_mask and n_single_axes < self.ndim:
357335
raise IndexError(
358-
"Multidimensional arrays must include an index for every axis or use an ellipsis"
336+
f"{self.ndim=}, but the multi-axes index only specifies "
337+
f"{n_single_axes} dimensions. If this was intentional, "
338+
"add a trailing ellipsis (...) which expands into as many "
339+
"slices (:) as necessary - this is what np.ndarray arrays "
340+
"implicitly do, but such flat indexing behaviour is not "
341+
"specified in the Array API."
359342
)
360-
return key
361-
elif isinstance(key, bool):
362-
return key
363-
elif isinstance(key, Array):
364-
if key.dtype in _integer_dtypes:
365-
if key.ndim != 0:
343+
344+
if n_ellipsis == 0:
345+
indexed_shape = self.shape
346+
else:
347+
ellipsis_start = None
348+
for pos, i in enumerate(nonexpanding_key):
349+
if not (isinstance(i, Array) or isinstance(i, np.ndarray)):
350+
if i == Ellipsis:
351+
ellipsis_start = pos
352+
break
353+
assert ellipsis_start is not None # sanity check
354+
ellipsis_end = self.ndim - (n_single_axes - ellipsis_start)
355+
indexed_shape = (
356+
self.shape[:ellipsis_start] + self.shape[ellipsis_end:]
357+
)
358+
for i, side in zip(single_axes, indexed_shape):
359+
if isinstance(i, slice):
360+
if side == 0:
361+
f_range = "0 (or None)"
362+
else:
363+
f_range = f"between -{side} and {side - 1} (or None)"
364+
if i.start is not None:
365+
try:
366+
start = operator.index(i.start)
367+
except TypeError:
368+
pass # handled by ndarray
369+
else:
370+
if not (-side <= start <= side):
371+
raise IndexError(
372+
f"Slice {i} contains {start=}, but should be "
373+
f"{f_range} for an axis of size {side} "
374+
"(out-of-bounds starts are not specified in "
375+
"the Array API)"
376+
)
377+
if i.stop is not None:
378+
try:
379+
stop = operator.index(i.stop)
380+
except TypeError:
381+
pass # handled by ndarray
382+
else:
383+
if not (-side <= stop <= side):
384+
raise IndexError(
385+
f"Slice {i} contains {stop=}, but should be "
386+
f"{f_range} for an axis of size {side} "
387+
"(out-of-bounds stops are not specified in "
388+
"the Array API)"
389+
)
390+
elif isinstance(i, Array):
391+
if i.dtype in _boolean_dtypes and len(_key) != 1:
392+
assert isinstance(key, tuple) # sanity check
366393
raise IndexError(
367-
"Non-zero dimensional integer array indices are not allowed in the array API namespace"
394+
f"Single-axes index {i} is a boolean array and "
395+
f"{len(key)=}, but masking is only specified in the "
396+
"Array API when the array is the sole index."
368397
)
369-
return key._array
370-
elif key is Ellipsis:
371-
return key
372-
elif key is None:
373-
raise IndexError(
374-
"newaxis indices are not allowed in the array API namespace"
375-
)
376-
try:
377-
key = operator.index(key)
378-
if shape is not None and len(shape) > 1:
398+
elif i.dtype in _integer_dtypes and i.ndim != 0:
399+
raise IndexError(
400+
f"Single-axes index {i} is a non-zero-dimensional "
401+
"integer array, but advanced integer indexing is not "
402+
"specified in the Array API."
403+
)
404+
elif isinstance(i, tuple):
379405
raise IndexError(
380-
"Multidimensional arrays must include an index for every axis or use an ellipsis"
406+
f"Single-axes index {i} is a tuple, but nested tuple "
407+
"indices are not specified in the Array API."
381408
)
382-
return key
383-
except TypeError:
384-
# Note: This also omits boolean arrays that are not already in
385-
# Array() form, like a list of booleans.
386-
raise IndexError(
387-
"Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace"
388-
)
389409

390410
# Everything below this line is required by the spec.
391411

@@ -511,7 +531,10 @@ def __getitem__(
511531
"""
512532
# Note: Only indices required by the spec are allowed. See the
513533
# docstring of _validate_index
514-
key = self._validate_index(key, self.shape)
534+
self._validate_index(key)
535+
if isinstance(key, Array):
536+
# Indexing self._array with array_api arrays can be erroneous
537+
key = key._array
515538
res = self._array.__getitem__(key)
516539
return self._new(res)
517540

@@ -698,7 +721,10 @@ def __setitem__(
698721
"""
699722
# Note: Only indices required by the spec are allowed. See the
700723
# docstring of _validate_index
701-
key = self._validate_index(key, self.shape)
724+
self._validate_index(key)
725+
if isinstance(key, Array):
726+
# Indexing self._array with array_api arrays can be erroneous
727+
key = key._array
702728
self._array.__setitem__(key, asarray(value)._array)
703729

704730
def __sub__(self: Array, other: Union[int, float, Array], /) -> Array:

array_api_strict/tests/test_array_object.py

Lines changed: 57 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,9 @@
22

33
from numpy.testing import assert_raises
44
import numpy as np
5+
import pytest
56

6-
from numpy. import ones, asarray, result_type, all, equal
7+
from numpy. import ones, asarray, reshape, result_type, all, equal
78
from numpy._array_object import Array
89
from numpy._dtypes import (
910
_all_dtypes,
@@ -17,6 +18,7 @@
1718
int32,
1819
int64,
1920
uint64,
21+
bool as bool_,
2022
)
2123

2224

@@ -70,11 +72,6 @@ def test_validate_index():
7072
assert_raises(IndexError, lambda: a[[0, 1]])
7173
assert_raises(IndexError, lambda: a[np.array([[0, 1]])])
7274

73-
# np.newaxis is not allowed
74-
assert_raises(IndexError, lambda: a[None])
75-
assert_raises(IndexError, lambda: a[None, ...])
76-
assert_raises(IndexError, lambda: a[..., None])
77-
7875
# Multiaxis indices must contain exactly as many indices as dimensions
7976
assert_raises(IndexError, lambda: a[()])
8077
assert_raises(IndexError, lambda: a[0,])
@@ -322,3 +319,57 @@ def test___array__():
322319
b = np.asarray(a, dtype=np.float64)
323320
assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64)))
324321
assert b.dtype == np.float64
322+
323+
def test_allow_newaxis():
324+
a = ones(5)
325+
indexed_a = a[None, :]
326+
assert indexed_a.shape == (1, 5)
327+
328+
def test_disallow_flat_indexing_with_newaxis():
329+
a = ones((3, 3, 3))
330+
with pytest.raises(IndexError):
331+
a[None, 0, 0]
332+
333+
def test_disallow_mask_with_newaxis():
334+
a = ones((3, 3, 3))
335+
with pytest.raises(IndexError):
336+
a[None, asarray(True)]
337+
338+
@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)])
339+
@pytest.mark.parametrize("index", ["string", False, True])
340+
def test_error_on_invalid_index(shape, index):
341+
a = ones(shape)
342+
with pytest.raises(IndexError):
343+
a[index]
344+
345+
def test_mask_0d_array_without_errors():
346+
a = ones(())
347+
a[asarray(True)]
348+
349+
@pytest.mark.parametrize(
350+
"i", [slice(5), slice(5, 0), asarray(True), asarray([0, 1])]
351+
)
352+
def test_error_on_invalid_index_with_ellipsis(i):
353+
a = ones((3, 3, 3))
354+
with pytest.raises(IndexError):
355+
a[..., i]
356+
with pytest.raises(IndexError):
357+
a[i, ...]
358+
359+
def test_array_keys_use_private_array():
360+
"""
361+
Indexing operations convert array keys before indexing the internal array
362+
363+
Fails when array_api array keys are not converted into NumPy-proper arrays
364+
in __getitem__(). This is achieved by passing array_api arrays with 0-sized
365+
dimensions, which NumPy-proper treats erroneously - not sure why!
366+
367+
TODO: Find and use appropriate __setitem__() case.
368+
"""
369+
a = ones((0, 0), dtype=bool_)
370+
assert a[a].shape == (0,)
371+
372+
a = ones((0,), dtype=bool_)
373+
key = ones((0, 0), dtype=bool_)
374+
with pytest.raises(IndexError):
375+
a[key]

0 commit comments

Comments
 (0)