File tree Expand file tree Collapse file tree 3 files changed +17
-9
lines changed Expand file tree Collapse file tree 3 files changed +17
-9
lines changed Original file line number Diff line number Diff line change 17
17
18
18
import operator
19
19
from enum import IntEnum
20
- import warnings
21
20
22
21
from ._creation_functions import asarray
23
22
from ._dtypes import (
@@ -502,8 +501,6 @@ def __array_namespace__(
502
501
503
502
"""
504
503
set_array_api_strict_flags (api_version = api_version )
505
- if api_version == "2021.12" :
506
- warnings .warn ("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12" )
507
504
import array_api_strict
508
505
return array_api_strict
509
506
Original file line number Diff line number Diff line change 13
13
14
14
import functools
15
15
import os
16
+ import warnings
16
17
17
18
import array_api_strict
18
19
@@ -62,6 +63,9 @@ def set_array_api_strict_flags(
62
63
versions are: ``{supported_versions}``. The default version number is
63
64
``{default_version!r}``.
64
65
66
+ Note that 2021.12 is supported, but currently gives the same thing as
67
+ 2022.12 (except that the fft extension will be disabled).
68
+
65
69
- `data_dependent_shapes`: Whether data-dependent shapes are enabled in
66
70
array-api-strict.
67
71
@@ -118,6 +122,8 @@ def set_array_api_strict_flags(
118
122
if api_version is not None :
119
123
if api_version not in supported_versions :
120
124
raise ValueError (f"Unsupported standard version { api_version !r} " )
125
+ if api_version == "2021.12" :
126
+ warnings .warn ("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12" )
121
127
API_VERSION = api_version
122
128
array_api_strict .__array_api_version__ = API_VERSION
123
129
Original file line number Diff line number Diff line change @@ -32,8 +32,12 @@ def test_flags():
32
32
'data_dependent_shapes' : False ,
33
33
'enabled_extensions' : ('fft' ,),
34
34
}
35
- # Make sure setting the version to 2021.12 disables fft
36
- set_array_api_strict_flags (api_version = '2021.12' )
35
+ # Make sure setting the version to 2021.12 disables fft and issues a
36
+ # warning.
37
+ with pytest .warns (UserWarning ) as record :
38
+ set_array_api_strict_flags (api_version = '2021.12' )
39
+ assert len (record ) == 1
40
+ assert '2021.12' in str (record [0 ].message )
37
41
flags = get_array_api_strict_flags ()
38
42
assert flags == {
39
43
'api_version' : '2021.12' ,
@@ -51,10 +55,11 @@ def test_flags():
51
55
enabled_extensions = ('linalg' , 'fft' )))
52
56
53
57
# Test resetting flags
54
- set_array_api_strict_flags (
55
- api_version = '2021.12' ,
56
- data_dependent_shapes = False ,
57
- enabled_extensions = ())
58
+ with pytest .warns (UserWarning ):
59
+ set_array_api_strict_flags (
60
+ api_version = '2021.12' ,
61
+ data_dependent_shapes = False ,
62
+ enabled_extensions = ())
58
63
reset_array_api_strict_flags ()
59
64
flags = get_array_api_strict_flags ()
60
65
assert flags == {
You can’t perform that action at this time.
0 commit comments