Skip to content

Commit f49845a

Browse files
committed
Don't re-enable disabled extensions when setting the api version
1 parent e61b50d commit f49845a

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

array_api_strict/_flags.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def set_array_api_strict_flags(
140140
)
141141
ENABLED_EXTENSIONS = tuple(enabled_extensions)
142142
else:
143-
ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= API_VERSION])
143+
ENABLED_EXTENSIONS = tuple([ext for ext in ENABLED_EXTENSIONS if extension_versions[ext] <= API_VERSION])
144144

145145
# We have to do this separately or it won't get added as the docstring
146146
set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format(

array_api_strict/tests/test_flags.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ def test_flags():
4242
assert flags == {
4343
'api_version': '2021.12',
4444
'data_dependent_shapes': False,
45+
'enabled_extensions': (),
46+
}
47+
reset_array_api_strict_flags()
48+
49+
with pytest.warns(UserWarning):
50+
set_array_api_strict_flags(api_version='2021.12')
51+
flags = get_array_api_strict_flags()
52+
assert flags == {
53+
'api_version': '2021.12',
54+
'data_dependent_shapes': True,
4555
'enabled_extensions': ('linalg',),
4656
}
4757

0 commit comments

Comments
 (0)