Skip to content

Commit 6a20e91

Browse files
committed
Add functionality for the data_dependent_shapes flag
1 parent f34576c commit 6a20e91

File tree

4 files changed

+28
-7
lines changed

4 files changed

+28
-7
lines changed

array_api_strict/_array_object.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
_result_type,
3333
_dtype_categories,
3434
)
35+
from ._flags import get_array_api_strict_flags
3536

3637
from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex
3738
import types
@@ -427,13 +428,17 @@ def _validate_index(self, key):
427428
"the Array API)"
428429
)
429430
elif isinstance(i, Array):
430-
if i.dtype in _boolean_dtypes and len(_key) != 1:
431-
assert isinstance(key, tuple) # sanity check
432-
raise IndexError(
433-
f"Single-axes index {i} is a boolean array and "
434-
f"{len(key)=}, but masking is only specified in the "
435-
"Array API when the array is the sole index."
436-
)
431+
if i.dtype in _boolean_dtypes:
432+
if len(_key) != 1:
433+
assert isinstance(key, tuple) # sanity check
434+
raise IndexError(
435+
f"Single-axes index {i} is a boolean array and "
436+
f"{len(key)=}, but masking is only specified in the "
437+
"Array API when the array is the sole index."
438+
)
439+
if not get_array_api_strict_flags()['data_dependent_shapes']:
440+
raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")
441+
437442
elif i.dtype in _integer_dtypes and i.ndim != 0:
438443
raise IndexError(
439444
f"Single-axes index {i} is a non-zero-dimensional "

array_api_strict/_flags.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,11 @@ def set_flags_from_environment():
265265
)
266266

267267
set_flags_from_environment()
268+
269+
def requires_data_dependent_shapes(func):
270+
@functools.wraps(func)
271+
def wrapper(*args, **kwargs):
272+
if not DATA_DEPENDENT_SHAPES:
273+
raise RuntimeError(f"The function {func.__name__} requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")
274+
return func(*args, **kwargs)
275+
return wrapper

array_api_strict/_searching_functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ._array_object import Array
44
from ._dtypes import _result_type, _real_numeric_dtypes
5+
from ._flags import requires_data_dependent_shapes
56

67
from typing import Optional, Tuple
78

@@ -30,6 +31,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
3031
return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims)))
3132

3233

34+
@requires_data_dependent_shapes
3335
def nonzero(x: Array, /) -> Tuple[Array, ...]:
3436
"""
3537
Array API compatible wrapper for :py:func:`np.nonzero <numpy.nonzero>`.

array_api_strict/_set_functions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from ._array_object import Array
44

5+
from ._flags import requires_data_dependent_shapes
6+
57
from typing import NamedTuple
68

79
import numpy as np
@@ -35,6 +37,7 @@ class UniqueInverseResult(NamedTuple):
3537
inverse_indices: Array
3638

3739

40+
@requires_data_dependent_shapes
3841
def unique_all(x: Array, /) -> UniqueAllResult:
3942
"""
4043
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
@@ -59,6 +62,7 @@ def unique_all(x: Array, /) -> UniqueAllResult:
5962
)
6063

6164

65+
@requires_data_dependent_shapes
6266
def unique_counts(x: Array, /) -> UniqueCountsResult:
6367
res = np.unique(
6468
x._array,
@@ -71,6 +75,7 @@ def unique_counts(x: Array, /) -> UniqueCountsResult:
7175
return UniqueCountsResult(*[Array._new(i) for i in res])
7276

7377

78+
@requires_data_dependent_shapes
7479
def unique_inverse(x: Array, /) -> UniqueInverseResult:
7580
"""
7681
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
@@ -90,6 +95,7 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult:
9095
return UniqueInverseResult(Array._new(values), Array._new(inverse_indices))
9196

9297

98+
@requires_data_dependent_shapes
9399
def unique_values(x: Array, /) -> Array:
94100
"""
95101
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.

0 commit comments

Comments
 (0)