11
11
library will only support one particular configuration of these flags.
12
12
"""
13
13
14
+ import functools
14
15
import os
15
16
16
- supported_versions = [
17
+ supported_versions = (
17
18
"2021.12" ,
18
19
"2022.12" ,
19
- ]
20
+ )
20
21
21
- STANDARD_VERSION = "2022.12"
22
+ STANDARD_VERSION = default_version = "2022.12"
22
23
23
24
DATA_DEPENDENT_SHAPES = True
24
25
25
- all_extensions = [
26
+ all_extensions = (
26
27
"linalg" ,
27
28
"fft" ,
28
- ]
29
+ )
29
30
30
31
extension_versions = {
31
32
"linalg" : "2021.12" ,
32
33
"fft" : "2022.12" ,
33
34
}
34
35
35
- ENABLED_EXTENSIONS = [
36
+ ENABLED_EXTENSIONS = default_extensions = (
36
37
"linalg" ,
37
38
"fft" ,
38
- ]
39
+ )
40
+
41
+ # Public functions
39
42
40
43
def set_array_api_strict_flags (
41
44
* ,
@@ -136,8 +139,8 @@ def set_array_api_strict_flags(
136
139
# We have to do this separately or it won't get added as the docstring
137
140
set_array_api_strict_flags .__doc__ = set_array_api_strict_flags .__doc__ .format (
138
141
supported_versions = supported_versions ,
139
- default_version = STANDARD_VERSION ,
140
- default_extensions = ENABLED_EXTENSIONS ,
142
+ default_version = default_version ,
143
+ default_extensions = default_extensions ,
141
144
)
142
145
143
146
def get_array_api_strict_flags ():
@@ -160,7 +163,7 @@ def get_array_api_strict_flags():
160
163
>>> from array_api_strict import get_array_api_strict_flags
161
164
>>> flags = get_array_api_strict_flags()
162
165
>>> flags
163
- {'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': [ 'linalg', 'fft'] }
166
+ {'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ( 'linalg', 'fft') }
164
167
165
168
See Also
166
169
--------
@@ -181,6 +184,8 @@ def reset_array_api_strict_flags():
181
184
"""
182
185
Reset the array-api-strict flags to their default values.
183
186
187
+ This will also reset any flags that were set by environment variables.
188
+
184
189
.. note::
185
190
186
191
This function is **not** part of the array API standard. It only exists
@@ -201,9 +206,9 @@ def reset_array_api_strict_flags():
201
206
202
207
"""
203
208
global STANDARD_VERSION , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
204
- STANDARD_VERSION = "2022.12"
209
+ STANDARD_VERSION = default_version
205
210
DATA_DEPENDENT_SHAPES = True
206
- ENABLED_EXTENSIONS = [ "linalg" , "fft" ]
211
+ ENABLED_EXTENSIONS = default_extensions
207
212
208
213
209
214
class ArrayApiStrictFlags :
@@ -241,18 +246,22 @@ def __enter__(self):
241
246
def __exit__ (self , exc_type , exc_value , traceback ):
242
247
set_array_api_strict_flags (** self .old_flags )
243
248
244
- # Set the flags from the environment variables
245
- if "ARRAY_API_STRICT_STANDARD_VERSION" in os .environ :
246
- set_array_api_strict_flags (
247
- standard_version = os .environ ["ARRAY_API_STRICT_STANDARD_VERSION" ]
248
- )
249
-
250
- if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os .environ :
251
- set_array_api_strict_flags (
252
- data_dependent_shapes = os .environ ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" ].lower () == "true"
253
- )
254
-
255
- if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os .environ :
256
- set_array_api_strict_flags (
257
- enabled_extensions = os .environ ["ARRAY_API_STRICT_ENABLED_EXTENSIONS" ].split ("," )
258
- )
249
+ # Private functions
250
+
251
+ def set_flags_from_environment ():
252
+ if "ARRAY_API_STRICT_STANDARD_VERSION" in os .environ :
253
+ set_array_api_strict_flags (
254
+ standard_version = os .environ ["ARRAY_API_STRICT_STANDARD_VERSION" ]
255
+ )
256
+
257
+ if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os .environ :
258
+ set_array_api_strict_flags (
259
+ data_dependent_shapes = os .environ ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" ].lower () == "true"
260
+ )
261
+
262
+ if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os .environ :
263
+ set_array_api_strict_flags (
264
+ enabled_extensions = os .environ ["ARRAY_API_STRICT_ENABLED_EXTENSIONS" ].split ("," )
265
+ )
266
+
267
+ set_flags_from_environment ()
0 commit comments