19
19
"2022.12" ,
20
20
)
21
21
22
- STANDARD_VERSION = default_version = "2022.12"
22
+ API_VERSION = default_version = "2022.12"
23
23
24
24
DATA_DEPENDENT_SHAPES = True
25
25
42
42
43
43
def set_array_api_strict_flags (
44
44
* ,
45
- standard_version = None ,
45
+ api_version = None ,
46
46
data_dependent_shapes = None ,
47
47
enabled_extensions = None ,
48
48
):
@@ -57,7 +57,7 @@ def set_array_api_strict_flags(
57
57
This function is **not** part of the array API standard. It only exists
58
58
in array-api-strict.
59
59
60
- - `standard_version `: The version of the standard to use. Supported
60
+ - `api_version `: The version of the standard to use. Supported
61
61
versions are: ``{supported_versions}``. The default version number is
62
62
``{default_version!r}``.
63
63
@@ -88,7 +88,7 @@ def set_array_api_strict_flags(
88
88
The default values of the flags can also be changed by setting environment
89
89
variables:
90
90
91
- - ``ARRAY_API_STRICT_STANDARD_VERSION ``: A string representing the version number.
91
+ - ``ARRAY_API_STRICT_API_VERSION ``: A string representing the version number.
92
92
- ``ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES``: "True" or "False".
93
93
- ``ARRAY_API_STRICT_ENABLED_EXTENSIONS``: A comma separated list of
94
94
extensions to enable.
@@ -98,7 +98,7 @@ def set_array_api_strict_flags(
98
98
99
99
>>> from array_api_strict import set_array_api_strict_flags
100
100
>>> # Set the standard version to 2021.12
101
- >>> set_array_api_strict_flags(standard_version ="2021.12")
101
+ >>> set_array_api_strict_flags(api_version ="2021.12")
102
102
>>> # Disable data-dependent shapes
103
103
>>> set_array_api_strict_flags(data_dependent_shapes=False)
104
104
>>> # Enable only the linalg extension (disable the fft extension)
@@ -112,12 +112,12 @@ def set_array_api_strict_flags(
112
112
ArrayApiStrictFlags: A context manager to temporarily set the flags.
113
113
114
114
"""
115
- global STANDARD_VERSION , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
115
+ global API_VERSION , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
116
116
117
- if standard_version is not None :
118
- if standard_version not in supported_versions :
119
- raise ValueError (f"Unsupported standard version { standard_version !r} " )
120
- STANDARD_VERSION = standard_version
117
+ if api_version is not None :
118
+ if api_version not in supported_versions :
119
+ raise ValueError (f"Unsupported standard version { api_version !r} " )
120
+ API_VERSION = api_version
121
121
122
122
if data_dependent_shapes is not None :
123
123
DATA_DEPENDENT_SHAPES = data_dependent_shapes
@@ -126,14 +126,14 @@ def set_array_api_strict_flags(
126
126
for extension in enabled_extensions :
127
127
if extension not in all_extensions :
128
128
raise ValueError (f"Unsupported extension { extension } " )
129
- if extension_versions [extension ] > STANDARD_VERSION :
129
+ if extension_versions [extension ] > API_VERSION :
130
130
raise ValueError (
131
131
f"Extension { extension } requires standard version "
132
132
f"{ extension_versions [extension ]} or later"
133
133
)
134
134
ENABLED_EXTENSIONS = tuple (enabled_extensions )
135
135
else :
136
- ENABLED_EXTENSIONS = tuple ([ext for ext in all_extensions if extension_versions [ext ] <= STANDARD_VERSION ])
136
+ ENABLED_EXTENSIONS = tuple ([ext for ext in all_extensions if extension_versions [ext ] <= API_VERSION ])
137
137
138
138
# We have to do this separately or it won't get added as the docstring
139
139
set_array_api_strict_flags .__doc__ = set_array_api_strict_flags .__doc__ .format (
@@ -162,7 +162,7 @@ def get_array_api_strict_flags():
162
162
>>> from array_api_strict import get_array_api_strict_flags
163
163
>>> flags = get_array_api_strict_flags()
164
164
>>> flags
165
- {'standard_version ': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')}
165
+ {'api_version ': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')}
166
166
167
167
See Also
168
168
--------
@@ -173,7 +173,7 @@ def get_array_api_strict_flags():
173
173
174
174
"""
175
175
return {
176
- "standard_version " : STANDARD_VERSION ,
176
+ "api_version " : API_VERSION ,
177
177
"data_dependent_shapes" : DATA_DEPENDENT_SHAPES ,
178
178
"enabled_extensions" : ENABLED_EXTENSIONS ,
179
179
}
@@ -204,8 +204,8 @@ def reset_array_api_strict_flags():
204
204
ArrayApiStrictFlags: A context manager to temporarily set the flags.
205
205
206
206
"""
207
- global STANDARD_VERSION , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
208
- STANDARD_VERSION = default_version
207
+ global API_VERSION , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
208
+ API_VERSION = default_version
209
209
DATA_DEPENDENT_SHAPES = True
210
210
ENABLED_EXTENSIONS = default_extensions
211
211
@@ -230,10 +230,10 @@ class ArrayApiStrictFlags:
230
230
reset_array_api_strict_flags
231
231
232
232
"""
233
- def __init__ (self , * , standard_version = None , data_dependent_shapes = None ,
233
+ def __init__ (self , * , api_version = None , data_dependent_shapes = None ,
234
234
enabled_extensions = None ):
235
235
self .kwargs = {
236
- "standard_version " : standard_version ,
236
+ "api_version " : api_version ,
237
237
"data_dependent_shapes" : data_dependent_shapes ,
238
238
"enabled_extensions" : enabled_extensions ,
239
239
}
@@ -248,9 +248,9 @@ def __exit__(self, exc_type, exc_value, traceback):
248
248
# Private functions
249
249
250
250
def set_flags_from_environment ():
251
- if "ARRAY_API_STRICT_STANDARD_VERSION " in os .environ :
251
+ if "ARRAY_API_STRICT_API_VERSION " in os .environ :
252
252
set_array_api_strict_flags (
253
- standard_version = os .environ ["ARRAY_API_STRICT_STANDARD_VERSION " ]
253
+ api_version = os .environ ["ARRAY_API_STRICT_API_VERSION " ]
254
254
)
255
255
256
256
if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os .environ :
0 commit comments