File tree Expand file tree Collapse file tree 3 files changed +23
-1
lines changed Expand file tree Collapse file tree 3 files changed +23
-1
lines changed Original file line number Diff line number Diff line change 21
21
supported_versions = (
22
22
"2021.12" ,
23
23
"2022.12" ,
24
+ "2023.12" ,
24
25
)
25
26
26
27
API_VERSION = default_version = "2022.12"
@@ -67,6 +68,9 @@ def set_array_api_strict_flags(
67
68
Note that 2021.12 is supported, but currently gives the same thing as
68
69
2022.12 (except that the fft extension will be disabled).
69
70
71
+ 2023.12 support is preliminary. Some features in 2023.12 may still be
72
+ missing, and it hasn't been fully tested.
73
+
70
74
- `data_dependent_shapes`: Whether data-dependent shapes are enabled in
71
75
array-api-strict.
72
76
@@ -123,6 +127,8 @@ def set_array_api_strict_flags(
123
127
raise ValueError (f"Unsupported standard version { api_version !r} " )
124
128
if api_version == "2021.12" :
125
129
warnings .warn ("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12" )
130
+ if api_version == "2023.12" :
131
+ warnings .warn ("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested." )
126
132
API_VERSION = api_version
127
133
array_api_strict .__array_api_version__ = API_VERSION
128
134
Original file line number Diff line number Diff line change @@ -410,9 +410,12 @@ def test_array_namespace():
410
410
assert a .__array_namespace__ (api_version = "2022.12" ) is array_api_strict
411
411
assert array_api_strict .__array_api_version__ == "2022.12"
412
412
413
+ assert a .__array_namespace__ (api_version = "2023.12" ) is array_api_strict
414
+ assert array_api_strict .__array_api_version__ == "2023.12"
415
+
413
416
with pytest .warns (UserWarning ):
414
417
assert a .__array_namespace__ (api_version = "2021.12" ) is array_api_strict
415
418
assert array_api_strict .__array_api_version__ == "2021.12"
416
419
417
420
pytest .raises (ValueError , lambda : a .__array_namespace__ (api_version = "2021.11" ))
418
- pytest .raises (ValueError , lambda : a .__array_namespace__ (api_version = "2023 .12" ))
421
+ pytest .raises (ValueError , lambda : a .__array_namespace__ (api_version = "2024 .12" ))
Original file line number Diff line number Diff line change @@ -54,6 +54,19 @@ def test_flags():
54
54
'data_dependent_shapes' : True ,
55
55
'enabled_extensions' : ('linalg' ,),
56
56
}
57
+ reset_array_api_strict_flags ()
58
+
59
+ # 2023.12 should issue a warning
60
+ with pytest .warns (UserWarning ) as record :
61
+ set_array_api_strict_flags (api_version = '2023.12' )
62
+ assert len (record ) == 1
63
+ assert '2023.12' in str (record [0 ].message )
64
+ flags = get_array_api_strict_flags ()
65
+ assert flags == {
66
+ 'api_version' : '2023.12' ,
67
+ 'data_dependent_shapes' : True ,
68
+ 'enabled_extensions' : ('linalg' , 'fft' ),
69
+ }
57
70
58
71
# Test setting flags with invalid values
59
72
pytest .raises (ValueError , lambda :
You can’t perform that action at this time.
0 commit comments