29
29
_dtype_categories ,
30
30
)
31
31
32
- from typing import TYPE_CHECKING , Optional , Tuple , Union , Any
32
+ from typing import TYPE_CHECKING , Optional , Tuple , Union , Any , SupportsIndex
33
33
import types
34
34
35
35
if TYPE_CHECKING :
@@ -243,8 +243,7 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
243
243
244
244
# Note: A large fraction of allowed indices are disallowed here (see the
245
245
# docstring below)
246
- @staticmethod
247
- def _validate_index (key , shape ):
246
+ def _validate_index (self , key ):
248
247
"""
249
248
Validate an index according to the array API.
250
249
@@ -257,8 +256,7 @@ def _validate_index(key, shape):
257
256
https://data-apis.org/array-api/latest/API_specification/indexing.html
258
257
for the full list of required indexing behavior
259
258
260
- This function either raises IndexError if the index ``key`` is
261
- invalid, or a new key to be used in place of ``key`` in indexing. It
259
+ This function raises IndexError if the index ``key`` is invalid. It
262
260
only raises ``IndexError`` on indices that are not already rejected by
263
261
NumPy, as NumPy will already raise the appropriate error on such
264
262
indices. ``shape`` may be None, in which case, only cases that are
@@ -269,7 +267,7 @@ def _validate_index(key, shape):
269
267
270
268
- Indices to not include an implicit ellipsis at the end. That is,
271
269
every axis of an array must be explicitly indexed or an ellipsis
272
- included.
270
+ included. This behaviour is sometimes referred to as flat indexing.
273
271
274
272
- The start and stop of a slice may not be out of bounds. In
275
273
particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
@@ -292,100 +290,122 @@ def _validate_index(key, shape):
292
290
``Array._new`` constructor, not this function.
293
291
294
292
"""
295
- if isinstance (key , slice ):
296
- if shape is None :
297
- return key
298
- if shape == ():
299
- return key
300
- if len (shape ) > 1 :
293
+ _key = key if isinstance (key , tuple ) else (key ,)
294
+ for i in _key :
295
+ if isinstance (i , bool ) or not (
296
+ isinstance (i , SupportsIndex ) # i.e. ints
297
+ or isinstance (i , slice )
298
+ or i == Ellipsis
299
+ or i is None
300
+ or isinstance (i , Array )
301
+ or isinstance (i , np .ndarray )
302
+ ):
301
303
raise IndexError (
302
- "Multidimensional arrays must include an index for every axis or use an ellipsis"
304
+ f"Single-axes index { i } has { type (i )= } , but only "
305
+ "integers, slices (:), ellipsis (...), newaxis (None), "
306
+ "zero-dimensional integer arrays and boolean arrays "
307
+ "are specified in the Array API."
303
308
)
304
- size = shape [0 ]
305
- # Ensure invalid slice entries are passed through.
306
- if key .start is not None :
307
- try :
308
- operator .index (key .start )
309
- except TypeError :
310
- return key
311
- if not (- size <= key .start <= size ):
312
- raise IndexError (
313
- "Slices with out-of-bounds start are not allowed in the array API namespace"
314
- )
315
- if key .stop is not None :
316
- try :
317
- operator .index (key .stop )
318
- except TypeError :
319
- return key
320
- step = 1 if key .step is None else key .step
321
- if (step > 0 and not (- size <= key .stop <= size )
322
- or step < 0 and not (- size - 1 <= key .stop <= max (0 , size - 1 ))):
323
- raise IndexError ("Slices with out-of-bounds stop are not allowed in the array API namespace" )
324
- return key
325
-
326
- elif isinstance (key , tuple ):
327
- key = tuple (Array ._validate_index (idx , None ) for idx in key )
328
-
329
- for idx in key :
330
- if (
331
- isinstance (idx , np .ndarray )
332
- and idx .dtype in _boolean_dtypes
333
- or isinstance (idx , (bool , np .bool_ ))
334
- ):
335
- if len (key ) == 1 :
336
- return key
337
- raise IndexError (
338
- "Boolean array indices combined with other indices are not allowed in the array API namespace"
339
- )
340
- if isinstance (idx , tuple ):
341
- raise IndexError (
342
- "Nested tuple indices are not allowed in the array API namespace"
343
- )
344
-
345
- if shape is None :
346
- return key
347
- n_ellipsis = key .count (...)
348
- if n_ellipsis > 1 :
349
- return key
350
- ellipsis_i = key .index (...) if n_ellipsis else len (key )
351
309
352
- for idx , size in list (zip (key [:ellipsis_i ], shape )) + list (
353
- zip (key [:ellipsis_i :- 1 ], shape [:ellipsis_i :- 1 ])
354
- ):
355
- Array ._validate_index (idx , (size ,))
356
- if n_ellipsis == 0 and len (key ) < len (shape ):
310
+ nonexpanding_key = []
311
+ single_axes = []
312
+ n_ellipsis = 0
313
+ key_has_mask = False
314
+ for i in _key :
315
+ if i is not None :
316
+ nonexpanding_key .append (i )
317
+ if isinstance (i , Array ) or isinstance (i , np .ndarray ):
318
+ if i .dtype in _boolean_dtypes :
319
+ key_has_mask = True
320
+ single_axes .append (i )
321
+ else :
322
+ # i must not be an array here, to avoid elementwise equals
323
+ if i == Ellipsis :
324
+ n_ellipsis += 1
325
+ else :
326
+ single_axes .append (i )
327
+
328
+ n_single_axes = len (single_axes )
329
+ if n_ellipsis > 1 :
330
+ return # handled by ndarray
331
+ elif n_ellipsis == 0 :
332
+ # Note boolean masks must be the sole index, which we check for
333
+ # later on.
334
+ if not key_has_mask and n_single_axes < self .ndim :
357
335
raise IndexError (
358
- "Multidimensional arrays must include an index for every axis or use an ellipsis"
336
+ f"{ self .ndim = } , but the multi-axes index only specifies "
337
+ f"{ n_single_axes } dimensions. If this was intentional, "
338
+ "add a trailing ellipsis (...) which expands into as many "
339
+ "slices (:) as necessary - this is what np.ndarray arrays "
340
+ "implicitly do, but such flat indexing behaviour is not "
341
+ "specified in the Array API."
359
342
)
360
- return key
361
- elif isinstance (key , bool ):
362
- return key
363
- elif isinstance (key , Array ):
364
- if key .dtype in _integer_dtypes :
365
- if key .ndim != 0 :
343
+
344
+ if n_ellipsis == 0 :
345
+ indexed_shape = self .shape
346
+ else :
347
+ ellipsis_start = None
348
+ for pos , i in enumerate (nonexpanding_key ):
349
+ if not (isinstance (i , Array ) or isinstance (i , np .ndarray )):
350
+ if i == Ellipsis :
351
+ ellipsis_start = pos
352
+ break
353
+ assert ellipsis_start is not None # sanity check
354
+ ellipsis_end = self .ndim - (n_single_axes - ellipsis_start )
355
+ indexed_shape = (
356
+ self .shape [:ellipsis_start ] + self .shape [ellipsis_end :]
357
+ )
358
+ for i , side in zip (single_axes , indexed_shape ):
359
+ if isinstance (i , slice ):
360
+ if side == 0 :
361
+ f_range = "0 (or None)"
362
+ else :
363
+ f_range = f"between -{ side } and { side - 1 } (or None)"
364
+ if i .start is not None :
365
+ try :
366
+ start = operator .index (i .start )
367
+ except TypeError :
368
+ pass # handled by ndarray
369
+ else :
370
+ if not (- side <= start <= side ):
371
+ raise IndexError (
372
+ f"Slice { i } contains { start = } , but should be "
373
+ f"{ f_range } for an axis of size { side } "
374
+ "(out-of-bounds starts are not specified in "
375
+ "the Array API)"
376
+ )
377
+ if i .stop is not None :
378
+ try :
379
+ stop = operator .index (i .stop )
380
+ except TypeError :
381
+ pass # handled by ndarray
382
+ else :
383
+ if not (- side <= stop <= side ):
384
+ raise IndexError (
385
+ f"Slice { i } contains { stop = } , but should be "
386
+ f"{ f_range } for an axis of size { side } "
387
+ "(out-of-bounds stops are not specified in "
388
+ "the Array API)"
389
+ )
390
+ elif isinstance (i , Array ):
391
+ if i .dtype in _boolean_dtypes and len (_key ) != 1 :
392
+ assert isinstance (key , tuple ) # sanity check
366
393
raise IndexError (
367
- "Non-zero dimensional integer array indices are not allowed in the array API namespace"
394
+ f"Single-axes index { i } is a boolean array and "
395
+ f"{ len (key )= } , but masking is only specified in the "
396
+ "Array API when the array is the sole index."
368
397
)
369
- return key ._array
370
- elif key is Ellipsis :
371
- return key
372
- elif key is None :
373
- raise IndexError (
374
- "newaxis indices are not allowed in the array API namespace"
375
- )
376
- try :
377
- key = operator .index (key )
378
- if shape is not None and len (shape ) > 1 :
398
+ elif i .dtype in _integer_dtypes and i .ndim != 0 :
399
+ raise IndexError (
400
+ f"Single-axes index { i } is a non-zero-dimensional "
401
+ "integer array, but advanced integer indexing is not "
402
+ "specified in the Array API."
403
+ )
404
+ elif isinstance (i , tuple ):
379
405
raise IndexError (
380
- "Multidimensional arrays must include an index for every axis or use an ellipsis"
406
+ f"Single-axes index { i } is a tuple, but nested tuple "
407
+ "indices are not specified in the Array API."
381
408
)
382
- return key
383
- except TypeError :
384
- # Note: This also omits boolean arrays that are not already in
385
- # Array() form, like a list of booleans.
386
- raise IndexError (
387
- "Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace"
388
- )
389
409
390
410
# Everything below this line is required by the spec.
391
411
@@ -511,7 +531,10 @@ def __getitem__(
511
531
"""
512
532
# Note: Only indices required by the spec are allowed. See the
513
533
# docstring of _validate_index
514
- key = self ._validate_index (key , self .shape )
534
+ self ._validate_index (key )
535
+ if isinstance (key , Array ):
536
+ # Indexing self._array with array_api arrays can be erroneous
537
+ key = key ._array
515
538
res = self ._array .__getitem__ (key )
516
539
return self ._new (res )
517
540
@@ -698,7 +721,10 @@ def __setitem__(
698
721
"""
699
722
# Note: Only indices required by the spec are allowed. See the
700
723
# docstring of _validate_index
701
- key = self ._validate_index (key , self .shape )
724
+ self ._validate_index (key )
725
+ if isinstance (key , Array ):
726
+ # Indexing self._array with array_api arrays can be erroneous
727
+ key = key ._array
702
728
self ._array .__setitem__ (key , asarray (value )._array )
703
729
704
730
def __sub__ (self : Array , other : Union [int , float , Array ], / ) -> Array :
0 commit comments