26
26
27
27
API_VERSION = default_version = "2022.12"
28
28
29
+ BOOLEAN_INDEXING = True
30
+
29
31
DATA_DEPENDENT_SHAPES = True
30
32
31
33
all_extensions = (
47
49
def set_array_api_strict_flags (
48
50
* ,
49
51
api_version = None ,
52
+ boolean_indexing = None ,
50
53
data_dependent_shapes = None ,
51
54
enabled_extensions = None ,
52
55
):
@@ -71,6 +74,12 @@ def set_array_api_strict_flags(
71
74
2023.12 support is preliminary. Some features in 2023.12 may still be
72
75
missing, and it hasn't been fully tested.
73
76
77
+
78
+ - `boolean_indexing`: Whether indexing by a boolean array is supported.
79
+ Note that although boolean array indexing does result in data-dependent
80
+ shapes, this flag is independent of the `data_dependent_shapes` flag
81
+ (see below).
82
+
74
83
- `data_dependent_shapes`: Whether data-dependent shapes are enabled in
75
84
array-api-strict.
76
85
@@ -83,10 +92,12 @@ def set_array_api_strict_flags(
83
92
84
93
- `unique_all()`, `unique_counts()`, `unique_inverse()`, and `unique_values()`.
85
94
- `nonzero()`
86
- - Boolean array indexing
87
95
- `repeat()` when the `repeats` argument is an array (requires 2023.12
88
96
version of the standard)
89
97
98
+ Note that while boolean indexing is also data-dependent, it is
99
+ controlled by a separate `boolean_indexing` flag (see above).
100
+
90
101
See
91
102
https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html
92
103
for more details.
@@ -106,8 +117,8 @@ def set_array_api_strict_flags(
106
117
>>> # Set the standard version to 2021.12
107
118
>>> set_array_api_strict_flags(api_version="2021.12")
108
119
109
- >>> # Disable data-dependent shapes
110
- >>> set_array_api_strict_flags(data_dependent_shapes=False)
120
+ >>> # Disable data-dependent shapes and boolean indexing
121
+ >>> set_array_api_strict_flags(data_dependent_shapes=False, boolean_indexing=False )
111
122
112
123
>>> # Enable only the linalg extension (disable the fft extension)
113
124
>>> set_array_api_strict_flags(enabled_extensions=["linalg"])
@@ -120,7 +131,7 @@ def set_array_api_strict_flags(
120
131
ArrayAPIStrictFlags: A context manager to temporarily set the flags.
121
132
122
133
"""
123
- global API_VERSION , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
134
+ global API_VERSION , BOOLEAN_INDEXING , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
124
135
125
136
if api_version is not None :
126
137
if api_version not in supported_versions :
@@ -132,6 +143,9 @@ def set_array_api_strict_flags(
132
143
API_VERSION = api_version
133
144
array_api_strict .__array_api_version__ = API_VERSION
134
145
146
+ if boolean_indexing is not None :
147
+ BOOLEAN_INDEXING = boolean_indexing
148
+
135
149
if data_dependent_shapes is not None :
136
150
DATA_DEPENDENT_SHAPES = data_dependent_shapes
137
151
@@ -175,7 +189,11 @@ def get_array_api_strict_flags():
175
189
>>> from array_api_strict import get_array_api_strict_flags
176
190
>>> flags = get_array_api_strict_flags()
177
191
>>> flags
178
- {'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')}
192
+ {'api_version': '2022.12',
193
+ 'boolean_indexing': True,
194
+ 'data_dependent_shapes': True,
195
+ 'enabled_extensions': ('linalg', 'fft')
196
+ }
179
197
180
198
See Also
181
199
--------
@@ -187,6 +205,7 @@ def get_array_api_strict_flags():
187
205
"""
188
206
return {
189
207
"api_version" : API_VERSION ,
208
+ "boolean_indexing" : BOOLEAN_INDEXING ,
190
209
"data_dependent_shapes" : DATA_DEPENDENT_SHAPES ,
191
210
"enabled_extensions" : ENABLED_EXTENSIONS ,
192
211
}
@@ -221,9 +240,10 @@ def reset_array_api_strict_flags():
221
240
ArrayAPIStrictFlags: A context manager to temporarily set the flags.
222
241
223
242
"""
224
- global API_VERSION , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
243
+ global API_VERSION , BOOLEAN_INDEXING , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
225
244
API_VERSION = default_version
226
245
array_api_strict .__array_api_version__ = API_VERSION
246
+ BOOLEAN_INDEXING = True
227
247
DATA_DEPENDENT_SHAPES = True
228
248
ENABLED_EXTENSIONS = default_extensions
229
249
@@ -248,10 +268,11 @@ class ArrayAPIStrictFlags:
248
268
reset_array_api_strict_flags: Reset the flags to their default values.
249
269
250
270
"""
251
- def __init__ (self , * , api_version = None , data_dependent_shapes = None ,
252
- enabled_extensions = None ):
271
+ def __init__ (self , * , api_version = None , boolean_indexing = None ,
272
+ data_dependent_shapes = None , enabled_extensions = None ):
253
273
self .kwargs = {
254
274
"api_version" : api_version ,
275
+ "boolean_indexing" : boolean_indexing ,
255
276
"data_dependent_shapes" : data_dependent_shapes ,
256
277
"enabled_extensions" : enabled_extensions ,
257
278
}
@@ -271,6 +292,11 @@ def set_flags_from_environment():
271
292
api_version = os .environ ["ARRAY_API_STRICT_API_VERSION" ]
272
293
)
273
294
295
+ if "ARRAY_API_STRICT_BOOLEAN_INDEXING" in os .environ :
296
+ set_array_api_strict_flags (
297
+ boolean_indexing = os .environ ["ARRAY_API_STRICT_BOOLEAN_INDEXING" ].lower () == "true"
298
+ )
299
+
274
300
if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os .environ :
275
301
set_array_api_strict_flags (
276
302
data_dependent_shapes = os .environ ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" ].lower () == "true"
0 commit comments