Skip to content

Commit 319799e

Browse files
committed
Rename the "standard_version" flag to "api_version"
This matches the name used in __array_namespace__
1 parent befd28c commit 319799e

File tree

3 files changed

+28
-28
lines changed

3 files changed

+28
-28
lines changed

array_api_strict/_array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def __array_namespace__(
501501
{class}`array_api_strict.ArrayApiStrictFlags` context manager.
502502
503503
"""
504-
set_array_api_strict_flags(standard_version=api_version)
504+
set_array_api_strict_flags(api_version=api_version)
505505
if api_version == "2021.12":
506506
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
507507
import array_api_strict

array_api_strict/_flags.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"2022.12",
2020
)
2121

22-
STANDARD_VERSION = default_version = "2022.12"
22+
API_VERSION = default_version = "2022.12"
2323

2424
DATA_DEPENDENT_SHAPES = True
2525

@@ -42,7 +42,7 @@
4242

4343
def set_array_api_strict_flags(
4444
*,
45-
standard_version=None,
45+
api_version=None,
4646
data_dependent_shapes=None,
4747
enabled_extensions=None,
4848
):
@@ -57,7 +57,7 @@ def set_array_api_strict_flags(
5757
This function is **not** part of the array API standard. It only exists
5858
in array-api-strict.
5959
60-
- `standard_version`: The version of the standard to use. Supported
60+
- `api_version`: The version of the standard to use. Supported
6161
versions are: ``{supported_versions}``. The default version number is
6262
``{default_version!r}``.
6363
@@ -88,7 +88,7 @@ def set_array_api_strict_flags(
8888
The default values of the flags can also be changed by setting environment
8989
variables:
9090
91-
- ``ARRAY_API_STRICT_STANDARD_VERSION``: A string representing the version number.
91+
- ``ARRAY_API_STRICT_API_VERSION``: A string representing the version number.
9292
- ``ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES``: "True" or "False".
9393
- ``ARRAY_API_STRICT_ENABLED_EXTENSIONS``: A comma separated list of
9494
extensions to enable.
@@ -98,7 +98,7 @@ def set_array_api_strict_flags(
9898
9999
>>> from array_api_strict import set_array_api_strict_flags
100100
>>> # Set the standard version to 2021.12
101-
>>> set_array_api_strict_flags(standard_version="2021.12")
101+
>>> set_array_api_strict_flags(api_version="2021.12")
102102
>>> # Disable data-dependent shapes
103103
>>> set_array_api_strict_flags(data_dependent_shapes=False)
104104
>>> # Enable only the linalg extension (disable the fft extension)
@@ -112,12 +112,12 @@ def set_array_api_strict_flags(
112112
ArrayApiStrictFlags: A context manager to temporarily set the flags.
113113
114114
"""
115-
global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
115+
global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
116116

117-
if standard_version is not None:
118-
if standard_version not in supported_versions:
119-
raise ValueError(f"Unsupported standard version {standard_version!r}")
120-
STANDARD_VERSION = standard_version
117+
if api_version is not None:
118+
if api_version not in supported_versions:
119+
raise ValueError(f"Unsupported standard version {api_version!r}")
120+
API_VERSION = api_version
121121

122122
if data_dependent_shapes is not None:
123123
DATA_DEPENDENT_SHAPES = data_dependent_shapes
@@ -126,14 +126,14 @@ def set_array_api_strict_flags(
126126
for extension in enabled_extensions:
127127
if extension not in all_extensions:
128128
raise ValueError(f"Unsupported extension {extension}")
129-
if extension_versions[extension] > STANDARD_VERSION:
129+
if extension_versions[extension] > API_VERSION:
130130
raise ValueError(
131131
f"Extension {extension} requires standard version "
132132
f"{extension_versions[extension]} or later"
133133
)
134134
ENABLED_EXTENSIONS = tuple(enabled_extensions)
135135
else:
136-
ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= STANDARD_VERSION])
136+
ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= API_VERSION])
137137

138138
# We have to do this separately or it won't get added as the docstring
139139
set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format(
@@ -162,7 +162,7 @@ def get_array_api_strict_flags():
162162
>>> from array_api_strict import get_array_api_strict_flags
163163
>>> flags = get_array_api_strict_flags()
164164
>>> flags
165-
{'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')}
165+
{'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')}
166166
167167
See Also
168168
--------
@@ -173,7 +173,7 @@ def get_array_api_strict_flags():
173173
174174
"""
175175
return {
176-
"standard_version": STANDARD_VERSION,
176+
"api_version": API_VERSION,
177177
"data_dependent_shapes": DATA_DEPENDENT_SHAPES,
178178
"enabled_extensions": ENABLED_EXTENSIONS,
179179
}
@@ -204,8 +204,8 @@ def reset_array_api_strict_flags():
204204
ArrayApiStrictFlags: A context manager to temporarily set the flags.
205205
206206
"""
207-
global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
208-
STANDARD_VERSION = default_version
207+
global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
208+
API_VERSION = default_version
209209
DATA_DEPENDENT_SHAPES = True
210210
ENABLED_EXTENSIONS = default_extensions
211211

@@ -230,10 +230,10 @@ class ArrayApiStrictFlags:
230230
reset_array_api_strict_flags
231231
232232
"""
233-
def __init__(self, *, standard_version=None, data_dependent_shapes=None,
233+
def __init__(self, *, api_version=None, data_dependent_shapes=None,
234234
enabled_extensions=None):
235235
self.kwargs = {
236-
"standard_version": standard_version,
236+
"api_version": api_version,
237237
"data_dependent_shapes": data_dependent_shapes,
238238
"enabled_extensions": enabled_extensions,
239239
}
@@ -248,9 +248,9 @@ def __exit__(self, exc_type, exc_value, traceback):
248248
# Private functions
249249

250250
def set_flags_from_environment():
251-
if "ARRAY_API_STRICT_STANDARD_VERSION" in os.environ:
251+
if "ARRAY_API_STRICT_API_VERSION" in os.environ:
252252
set_array_api_strict_flags(
253-
standard_version=os.environ["ARRAY_API_STRICT_STANDARD_VERSION"]
253+
api_version=os.environ["ARRAY_API_STRICT_API_VERSION"]
254254
)
255255

256256
if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ:

array_api_strict/tests/test_flags.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_flags():
1818
# Test defaults
1919
flags = get_array_api_strict_flags()
2020
assert flags == {
21-
'standard_version': '2022.12',
21+
'api_version': '2022.12',
2222
'data_dependent_shapes': True,
2323
'enabled_extensions': ('linalg', 'fft'),
2424
}
@@ -27,33 +27,33 @@ def test_flags():
2727
set_array_api_strict_flags(data_dependent_shapes=False)
2828
flags = get_array_api_strict_flags()
2929
assert flags == {
30-
'standard_version': '2022.12',
30+
'api_version': '2022.12',
3131
'data_dependent_shapes': False,
3232
'enabled_extensions': ('linalg', 'fft'),
3333
}
3434
set_array_api_strict_flags(enabled_extensions=('fft',))
3535
flags = get_array_api_strict_flags()
3636
assert flags == {
37-
'standard_version': '2022.12',
37+
'api_version': '2022.12',
3838
'data_dependent_shapes': False,
3939
'enabled_extensions': ('fft',),
4040
}
4141
# Make sure setting the version to 2021.12 disables fft
42-
set_array_api_strict_flags(standard_version='2021.12')
42+
set_array_api_strict_flags(api_version='2021.12')
4343
flags = get_array_api_strict_flags()
4444
assert flags == {
45-
'standard_version': '2021.12',
45+
'api_version': '2021.12',
4646
'data_dependent_shapes': False,
4747
'enabled_extensions': ('linalg',),
4848
}
4949

5050
# Test setting flags with invalid values
5151
pytest.raises(ValueError, lambda:
52-
set_array_api_strict_flags(standard_version='2020.12'))
52+
set_array_api_strict_flags(api_version='2020.12'))
5353
pytest.raises(ValueError, lambda: set_array_api_strict_flags(
5454
enabled_extensions=('linalg', 'fft', 'invalid')))
5555
pytest.raises(ValueError, lambda: set_array_api_strict_flags(
56-
standard_version='2021.12',
56+
api_version='2021.12',
5757
enabled_extensions=('linalg', 'fft')))
5858

5959

0 commit comments

Comments
 (0)