Skip to content

Commit f247130

Browse files
committed
Merge branch 'main' into 2023.12
2 parents a30536b + e5bebbe commit f247130

File tree

4 files changed

+57
-10
lines changed

4 files changed

+57
-10
lines changed

array_api_strict/_array_object.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,8 @@ def _validate_index(self, key):
436436
f"{len(key)=}, but masking is only specified in the "
437437
"Array API when the array is the sole index."
438438
)
439-
if not get_array_api_strict_flags()['data_dependent_shapes']:
440-
raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")
439+
if not get_array_api_strict_flags()['boolean_indexing']:
440+
raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the boolean_indexing flag has been disabled for array-api-strict")
441441

442442
elif i.dtype in _integer_dtypes and i.ndim != 0:
443443
raise IndexError(

array_api_strict/_flags.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
API_VERSION = default_version = "2022.12"
2828

29+
BOOLEAN_INDEXING = True
30+
2931
DATA_DEPENDENT_SHAPES = True
3032

3133
all_extensions = (
@@ -47,6 +49,7 @@
4749
def set_array_api_strict_flags(
4850
*,
4951
api_version=None,
52+
boolean_indexing=None,
5053
data_dependent_shapes=None,
5154
enabled_extensions=None,
5255
):
@@ -71,6 +74,12 @@ def set_array_api_strict_flags(
7174
2023.12 support is preliminary. Some features in 2023.12 may still be
7275
missing, and it hasn't been fully tested.
7376
77+
78+
- `boolean_indexing`: Whether indexing by a boolean array is supported.
79+
Note that although boolean array indexing does result in data-dependent
80+
shapes, this flag is independent of the `data_dependent_shapes` flag
81+
(see below).
82+
7483
- `data_dependent_shapes`: Whether data-dependent shapes are enabled in
7584
array-api-strict.
7685
@@ -83,10 +92,12 @@ def set_array_api_strict_flags(
8392
8493
- `unique_all()`, `unique_counts()`, `unique_inverse()`, and `unique_values()`.
8594
- `nonzero()`
86-
- Boolean array indexing
8795
- `repeat()` when the `repeats` argument is an array (requires 2023.12
8896
version of the standard)
8997
98+
Note that while boolean indexing is also data-dependent, it is
99+
controlled by a separate `boolean_indexing` flag (see above).
100+
90101
See
91102
https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html
92103
for more details.
@@ -106,8 +117,8 @@ def set_array_api_strict_flags(
106117
>>> # Set the standard version to 2021.12
107118
>>> set_array_api_strict_flags(api_version="2021.12")
108119
109-
>>> # Disable data-dependent shapes
110-
>>> set_array_api_strict_flags(data_dependent_shapes=False)
120+
>>> # Disable data-dependent shapes and boolean indexing
121+
>>> set_array_api_strict_flags(data_dependent_shapes=False, boolean_indexing=False)
111122
112123
>>> # Enable only the linalg extension (disable the fft extension)
113124
>>> set_array_api_strict_flags(enabled_extensions=["linalg"])
@@ -120,7 +131,7 @@ def set_array_api_strict_flags(
120131
ArrayAPIStrictFlags: A context manager to temporarily set the flags.
121132
122133
"""
123-
global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
134+
global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
124135

125136
if api_version is not None:
126137
if api_version not in supported_versions:
@@ -132,6 +143,9 @@ def set_array_api_strict_flags(
132143
API_VERSION = api_version
133144
array_api_strict.__array_api_version__ = API_VERSION
134145

146+
if boolean_indexing is not None:
147+
BOOLEAN_INDEXING = boolean_indexing
148+
135149
if data_dependent_shapes is not None:
136150
DATA_DEPENDENT_SHAPES = data_dependent_shapes
137151

@@ -175,7 +189,11 @@ def get_array_api_strict_flags():
175189
>>> from array_api_strict import get_array_api_strict_flags
176190
>>> flags = get_array_api_strict_flags()
177191
>>> flags
178-
{'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')}
192+
{'api_version': '2022.12',
193+
'boolean_indexing': True,
194+
'data_dependent_shapes': True,
195+
'enabled_extensions': ('linalg', 'fft')
196+
}
179197
180198
See Also
181199
--------
@@ -187,6 +205,7 @@ def get_array_api_strict_flags():
187205
"""
188206
return {
189207
"api_version": API_VERSION,
208+
"boolean_indexing": BOOLEAN_INDEXING,
190209
"data_dependent_shapes": DATA_DEPENDENT_SHAPES,
191210
"enabled_extensions": ENABLED_EXTENSIONS,
192211
}
@@ -221,9 +240,10 @@ def reset_array_api_strict_flags():
221240
ArrayAPIStrictFlags: A context manager to temporarily set the flags.
222241
223242
"""
224-
global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
243+
global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
225244
API_VERSION = default_version
226245
array_api_strict.__array_api_version__ = API_VERSION
246+
BOOLEAN_INDEXING = True
227247
DATA_DEPENDENT_SHAPES = True
228248
ENABLED_EXTENSIONS = default_extensions
229249

@@ -248,10 +268,11 @@ class ArrayAPIStrictFlags:
248268
reset_array_api_strict_flags: Reset the flags to their default values.
249269
250270
"""
251-
def __init__(self, *, api_version=None, data_dependent_shapes=None,
252-
enabled_extensions=None):
271+
def __init__(self, *, api_version=None, boolean_indexing=None,
272+
data_dependent_shapes=None, enabled_extensions=None):
253273
self.kwargs = {
254274
"api_version": api_version,
275+
"boolean_indexing": boolean_indexing,
255276
"data_dependent_shapes": data_dependent_shapes,
256277
"enabled_extensions": enabled_extensions,
257278
}
@@ -271,6 +292,11 @@ def set_flags_from_environment():
271292
api_version=os.environ["ARRAY_API_STRICT_API_VERSION"]
272293
)
273294

295+
if "ARRAY_API_STRICT_BOOLEAN_INDEXING" in os.environ:
296+
set_array_api_strict_flags(
297+
boolean_indexing=os.environ["ARRAY_API_STRICT_BOOLEAN_INDEXING"].lower() == "true"
298+
)
299+
274300
if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ:
275301
set_array_api_strict_flags(
276302
data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true"

array_api_strict/tests/test_flags.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def test_flags():
1313
flags = get_array_api_strict_flags()
1414
assert flags == {
1515
'api_version': '2022.12',
16+
'boolean_indexing': True,
1617
'data_dependent_shapes': True,
1718
'enabled_extensions': ('linalg', 'fft'),
1819
}
@@ -22,13 +23,15 @@ def test_flags():
2223
flags = get_array_api_strict_flags()
2324
assert flags == {
2425
'api_version': '2022.12',
26+
'boolean_indexing': True,
2527
'data_dependent_shapes': False,
2628
'enabled_extensions': ('linalg', 'fft'),
2729
}
2830
set_array_api_strict_flags(enabled_extensions=('fft',))
2931
flags = get_array_api_strict_flags()
3032
assert flags == {
3133
'api_version': '2022.12',
34+
'boolean_indexing': True,
3235
'data_dependent_shapes': False,
3336
'enabled_extensions': ('fft',),
3437
}
@@ -41,6 +44,7 @@ def test_flags():
4144
flags = get_array_api_strict_flags()
4245
assert flags == {
4346
'api_version': '2021.12',
47+
'boolean_indexing': True,
4448
'data_dependent_shapes': False,
4549
'enabled_extensions': (),
4650
}
@@ -82,12 +86,14 @@ def test_flags():
8286
with pytest.warns(UserWarning):
8387
set_array_api_strict_flags(
8488
api_version='2021.12',
89+
boolean_indexing=False,
8590
data_dependent_shapes=False,
8691
enabled_extensions=())
8792
reset_array_api_strict_flags()
8893
flags = get_array_api_strict_flags()
8994
assert flags == {
9095
'api_version': '2022.12',
96+
'boolean_indexing': True,
9197
'data_dependent_shapes': True,
9298
'enabled_extensions': ('linalg', 'fft'),
9399
}
@@ -126,6 +132,17 @@ def test_data_dependent_shapes():
126132
pytest.raises(RuntimeError, lambda: unique_inverse(a))
127133
pytest.raises(RuntimeError, lambda: unique_values(a))
128134
pytest.raises(RuntimeError, lambda: nonzero(a))
135+
a[mask] # No error (boolean indexing is a separate flag)
136+
137+
def test_boolean_indexing():
138+
a = asarray([0, 0, 1, 2, 2])
139+
mask = asarray([True, False, True, False, True])
140+
141+
# Should not error
142+
a[mask]
143+
144+
set_array_api_strict_flags(boolean_indexing=False)
145+
129146
pytest.raises(RuntimeError, lambda: a[mask])
130147
pytest.raises(RuntimeError, lambda: repeat(a, repeats))
131148
repeat(a, 2) # Should never error

docs/api.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ used by array-api-strict initially. They will not change the defaults used by
3030

3131
A string representing the version number.
3232

33+
.. envvar:: ARRAY_API_STRICT_BOOLEAN_INDEXING
34+
35+
"True" or "False" to enable or disable boolean indexing.
36+
3337
.. envvar:: ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES
3438

3539
"True" or "False" to enable or disable data dependent shapes.

0 commit comments

Comments
 (0)