Skip to content

Commit 71c5231

Browse files
committed
Add support for setting the api version to 2023.12
1 parent f49845a commit 71c5231

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

array_api_strict/_flags.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
supported_versions = (
2222
"2021.12",
2323
"2022.12",
24+
"2023.12",
2425
)
2526

2627
API_VERSION = default_version = "2022.12"
@@ -67,6 +68,9 @@ def set_array_api_strict_flags(
6768
Note that 2021.12 is supported, but currently gives the same thing as
6869
2022.12 (except that the fft extension will be disabled).
6970
71+
2023.12 support is preliminary. Some features in 2023.12 may still be
72+
missing, and it hasn't been fully tested.
73+
7074
- `data_dependent_shapes`: Whether data-dependent shapes are enabled in
7175
array-api-strict.
7276
@@ -123,6 +127,8 @@ def set_array_api_strict_flags(
123127
raise ValueError(f"Unsupported standard version {api_version!r}")
124128
if api_version == "2021.12":
125129
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
130+
if api_version == "2023.12":
131+
warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.")
126132
API_VERSION = api_version
127133
array_api_strict.__array_api_version__ = API_VERSION
128134

array_api_strict/tests/test_array_object.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,12 @@ def test_array_namespace():
410410
assert a.__array_namespace__(api_version="2022.12") is array_api_strict
411411
assert array_api_strict.__array_api_version__ == "2022.12"
412412

413+
assert a.__array_namespace__(api_version="2023.12") is array_api_strict
414+
assert array_api_strict.__array_api_version__ == "2023.12"
415+
413416
with pytest.warns(UserWarning):
414417
assert a.__array_namespace__(api_version="2021.12") is array_api_strict
415418
assert array_api_strict.__array_api_version__ == "2021.12"
416419

417420
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
418-
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2023.12"))
421+
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))

array_api_strict/tests/test_flags.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,19 @@ def test_flags():
5454
'data_dependent_shapes': True,
5555
'enabled_extensions': ('linalg',),
5656
}
57+
reset_array_api_strict_flags()
58+
59+
# 2023.12 should issue a warning
60+
with pytest.warns(UserWarning) as record:
61+
set_array_api_strict_flags(api_version='2023.12')
62+
assert len(record) == 1
63+
assert '2023.12' in str(record[0].message)
64+
flags = get_array_api_strict_flags()
65+
assert flags == {
66+
'api_version': '2023.12',
67+
'data_dependent_shapes': True,
68+
'enabled_extensions': ('linalg', 'fft'),
69+
}
5770

5871
# Test setting flags with invalid values
5972
pytest.raises(ValueError, lambda:

0 commit comments

Comments
 (0)