Skip to content

Commit 30baeb7

Browse files
committed
Move warning about 2021.12 to set_array_api_strict_flags()
1 parent 0ba1267 commit 30baeb7

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

array_api_strict/_array_object.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import operator
1919
from enum import IntEnum
20-
import warnings
2120

2221
from ._creation_functions import asarray
2322
from ._dtypes import (
@@ -502,8 +501,6 @@ def __array_namespace__(
502501
503502
"""
504503
set_array_api_strict_flags(api_version=api_version)
505-
if api_version == "2021.12":
506-
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
507504
import array_api_strict
508505
return array_api_strict
509506

array_api_strict/_flags.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import functools
1515
import os
16+
import warnings
1617

1718
import array_api_strict
1819

@@ -62,6 +63,9 @@ def set_array_api_strict_flags(
6263
versions are: ``{supported_versions}``. The default version number is
6364
``{default_version!r}``.
6465
66+
Note that 2021.12 is supported, but currently gives the same thing as
67+
2022.12 (except that the fft extension will be disabled).
68+
6569
- `data_dependent_shapes`: Whether data-dependent shapes are enabled in
6670
array-api-strict.
6771
@@ -118,6 +122,8 @@ def set_array_api_strict_flags(
118122
if api_version is not None:
119123
if api_version not in supported_versions:
120124
raise ValueError(f"Unsupported standard version {api_version!r}")
125+
if api_version == "2021.12":
126+
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
121127
API_VERSION = api_version
122128
array_api_strict.__array_api_version__ = API_VERSION
123129

array_api_strict/tests/test_flags.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,12 @@ def test_flags():
3232
'data_dependent_shapes': False,
3333
'enabled_extensions': ('fft',),
3434
}
35-
# Make sure setting the version to 2021.12 disables fft
36-
set_array_api_strict_flags(api_version='2021.12')
35+
# Make sure setting the version to 2021.12 disables fft and issues a
36+
# warning.
37+
with pytest.warns(UserWarning) as record:
38+
set_array_api_strict_flags(api_version='2021.12')
39+
assert len(record) == 1
40+
assert '2021.12' in str(record[0].message)
3741
flags = get_array_api_strict_flags()
3842
assert flags == {
3943
'api_version': '2021.12',
@@ -51,10 +55,11 @@ def test_flags():
5155
enabled_extensions=('linalg', 'fft')))
5256

5357
# Test resetting flags
54-
set_array_api_strict_flags(
55-
api_version='2021.12',
56-
data_dependent_shapes=False,
57-
enabled_extensions=())
58+
with pytest.warns(UserWarning):
59+
set_array_api_strict_flags(
60+
api_version='2021.12',
61+
data_dependent_shapes=False,
62+
enabled_extensions=())
5863
reset_array_api_strict_flags()
5964
flags = get_array_api_strict_flags()
6065
assert flags == {

0 commit comments

Comments
 (0)