Skip to content

Commit 0ba1267

Browse files
committed
Set __array_api_version__ with the api_version flag
1 parent 0d758eb commit 0ba1267

File tree

5 files changed

+26
-3
lines changed

5 files changed

+26
-3
lines changed

array_api_strict/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
1717
"""
1818

19-
__array_api_version__ = "2022.12"
19+
# Warning: __array_api_version__ could change globally with
20+
# set_array_api_strict_flags(). This should always be accessed as an
21+
# attribute, like xp.__array_api_version__, or using
22+
# array_api_strict.get_array_api_strict_flags()['api_version'].
23+
from ._flags import API_VERSION as __array_api_version__
2024

2125
__all__ = ["__array_api_version__"]
2226

array_api_strict/_array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ def __array_namespace__(
497497
the API version for the array_api_strict module globally. This can
498498
also be achieved with the
499499
{func}`array_api_strict.set_array_api_strict_flags` function. If you
500-
want some way to only set the version locally, use the
500+
want to only set the version locally, use the
501501
{class}`array_api_strict.ArrayApiStrictFlags` context manager.
502502
503503
"""

array_api_strict/_flags.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import functools
1515
import os
1616

17+
import array_api_strict
18+
1719
supported_versions = (
1820
"2021.12",
1921
"2022.12",
@@ -37,7 +39,6 @@
3739
"linalg",
3840
"fft",
3941
)
40-
4142
# Public functions
4243

4344
def set_array_api_strict_flags(
@@ -118,6 +119,7 @@ def set_array_api_strict_flags(
118119
if api_version not in supported_versions:
119120
raise ValueError(f"Unsupported standard version {api_version!r}")
120121
API_VERSION = api_version
122+
array_api_strict.__array_api_version__ = API_VERSION
121123

122124
if data_dependent_shapes is not None:
123125
DATA_DEPENDENT_SHAPES = data_dependent_shapes
@@ -206,6 +208,7 @@ def reset_array_api_strict_flags():
206208
"""
207209
global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
208210
API_VERSION = default_version
211+
array_api_strict.__array_api_version__ = API_VERSION
209212
DATA_DEPENDENT_SHAPES = True
210213
ENABLED_EXTENSIONS = default_extensions
211214

array_api_strict/tests/test_array_object.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,9 +402,17 @@ def test_array_keys_use_private_array():
402402
def test_array_namespace():
403403
a = ones((3, 3))
404404
assert a.__array_namespace__() == array_api_strict
405+
assert array_api_strict.__array_api_version__ == "2022.12"
406+
405407
assert a.__array_namespace__(api_version=None) is array_api_strict
408+
assert array_api_strict.__array_api_version__ == "2022.12"
409+
406410
assert a.__array_namespace__(api_version="2022.12") is array_api_strict
411+
assert array_api_strict.__array_api_version__ == "2022.12"
412+
407413
with pytest.warns(UserWarning):
408414
assert a.__array_namespace__(api_version="2021.12") is array_api_strict
415+
assert array_api_strict.__array_api_version__ == "2021.12"
416+
409417
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
410418
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2023.12"))

array_api_strict/tests/test_flags.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ def test_flags():
6363
'enabled_extensions': ('linalg', 'fft'),
6464
}
6565

66+
def test_api_version():
67+
# Test defaults
68+
assert xp.__array_api_version__ == '2022.12'
69+
70+
# Test setting the version
71+
set_array_api_strict_flags(api_version='2021.12')
72+
assert xp.__array_api_version__ == '2021.12'
73+
6674
def test_data_dependent_shapes():
6775
a = asarray([0, 0, 1, 2, 2])
6876
mask = asarray([True, False, True, False, True])

0 commit comments

Comments
 (0)