Skip to content

Commit f34576c

Browse files
committed
Some small code cleanups to the flags file
1 parent d8c3745 commit f34576c

File tree

1 file changed

+36
-27
lines changed

1 file changed

+36
-27
lines changed

array_api_strict/_flags.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,34 @@
1111
library will only support one particular configuration of these flags.
1212
"""
1313

14+
import functools
1415
import os
1516

16-
supported_versions = [
17+
supported_versions = (
1718
"2021.12",
1819
"2022.12",
19-
]
20+
)
2021

21-
STANDARD_VERSION = "2022.12"
22+
STANDARD_VERSION = default_version = "2022.12"
2223

2324
DATA_DEPENDENT_SHAPES = True
2425

25-
all_extensions = [
26+
all_extensions = (
2627
"linalg",
2728
"fft",
28-
]
29+
)
2930

3031
extension_versions = {
3132
"linalg": "2021.12",
3233
"fft": "2022.12",
3334
}
3435

35-
ENABLED_EXTENSIONS = [
36+
ENABLED_EXTENSIONS = default_extensions = (
3637
"linalg",
3738
"fft",
38-
]
39+
)
40+
41+
# Public functions
3942

4043
def set_array_api_strict_flags(
4144
*,
@@ -136,8 +139,8 @@ def set_array_api_strict_flags(
136139
# We have to do this separately or it won't get added as the docstring
137140
set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format(
138141
supported_versions=supported_versions,
139-
default_version=STANDARD_VERSION,
140-
default_extensions=ENABLED_EXTENSIONS,
142+
default_version=default_version,
143+
default_extensions=default_extensions,
141144
)
142145

143146
def get_array_api_strict_flags():
@@ -160,7 +163,7 @@ def get_array_api_strict_flags():
160163
>>> from array_api_strict import get_array_api_strict_flags
161164
>>> flags = get_array_api_strict_flags()
162165
>>> flags
163-
{'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ['linalg', 'fft']}
166+
{'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')}
164167
165168
See Also
166169
--------
@@ -181,6 +184,8 @@ def reset_array_api_strict_flags():
181184
"""
182185
Reset the array-api-strict flags to their default values.
183186
187+
This will also reset any flags that were set by environment variables.
188+
184189
.. note::
185190
186191
This function is **not** part of the array API standard. It only exists
@@ -201,9 +206,9 @@ def reset_array_api_strict_flags():
201206
202207
"""
203208
global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
204-
STANDARD_VERSION = "2022.12"
209+
STANDARD_VERSION = default_version
205210
DATA_DEPENDENT_SHAPES = True
206-
ENABLED_EXTENSIONS = ["linalg", "fft"]
211+
ENABLED_EXTENSIONS = default_extensions
207212

208213

209214
class ArrayApiStrictFlags:
@@ -241,18 +246,22 @@ def __enter__(self):
241246
def __exit__(self, exc_type, exc_value, traceback):
242247
set_array_api_strict_flags(**self.old_flags)
243248

244-
# Set the flags from the environment variables
245-
if "ARRAY_API_STRICT_STANDARD_VERSION" in os.environ:
246-
set_array_api_strict_flags(
247-
standard_version=os.environ["ARRAY_API_STRICT_STANDARD_VERSION"]
248-
)
249-
250-
if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ:
251-
set_array_api_strict_flags(
252-
data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true"
253-
)
254-
255-
if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ:
256-
set_array_api_strict_flags(
257-
enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",")
258-
)
249+
# Private functions
250+
251+
def set_flags_from_environment():
252+
if "ARRAY_API_STRICT_STANDARD_VERSION" in os.environ:
253+
set_array_api_strict_flags(
254+
standard_version=os.environ["ARRAY_API_STRICT_STANDARD_VERSION"]
255+
)
256+
257+
if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ:
258+
set_array_api_strict_flags(
259+
data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true"
260+
)
261+
262+
if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ:
263+
set_array_api_strict_flags(
264+
enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",")
265+
)
266+
267+
set_flags_from_environment()

0 commit comments

Comments
 (0)