Skip to content

Commit 2cd4c3f

Browse files
committed
Set up basic structure for array-api-strict flags
Flags are global variables that set array-api-strict in a specific mode. Currently support flags change the support array API standard version, enable or disable data-dependent shapes, and enable or disable optional extensions. This commit only sets up the structure for setting and getting these flags.
1 parent 984901f commit 2cd4c3f

File tree

2 files changed

+267
-0
lines changed

2 files changed

+267
-0
lines changed

array_api_strict/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,17 @@
284284

285285
__all__ += ["all", "any"]
286286

287+
# Helper functions that are not part of the standard
288+
289+
from ._flags import (
290+
set_array_api_strict_flags,
291+
get_array_api_strict_flags,
292+
reset_array_api_strict_flags,
293+
ArrayApiStrictFlags,
294+
)
295+
296+
__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayApiStrictFlags']
297+
287298
from . import _version
288299
__version__ = _version.get_versions()['version']
289300
del _version

array_api_strict/_flags.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
"""
2+
This file defines flags for that allow array-api-strict to be used in
3+
different "modes". These modes include
4+
5+
- Changing to different supported versions of the standard.
6+
- Enabling or disabling different optional behaviors (such as data-dependent
7+
shapes).
8+
- Enabling or disabling different optional extensions.
9+
10+
Nothing in this file is part of the standard itself. A typical array API
11+
library will only support one particular configuration of these flags.
12+
"""
13+
14+
import os
15+
16+
supported_versions = [
17+
"2021.12",
18+
"2022.12",
19+
]
20+
21+
STANDARD_VERSION = "2022.12"
22+
23+
DATA_DEPENDENT_SHAPES = True
24+
25+
all_extensions = [
26+
"linalg",
27+
"fft",
28+
]
29+
30+
extension_versions = {
31+
"linalg": "2021.12",
32+
"fft": "2022.12",
33+
}
34+
35+
ENABLED_EXTENSIONS = [
36+
"linalg",
37+
"fft",
38+
]
39+
40+
def set_array_api_strict_flags(
41+
*,
42+
standard_version=None,
43+
data_dependent_shapes=None,
44+
enabled_extensions=None,
45+
):
46+
"""
47+
Set the array-api-strict flags to the specified values.
48+
49+
Flags are global variables that enable or disable array-api-strict
50+
behaviors.
51+
52+
.. note::
53+
54+
This function is **not** part of the array API standard. It only exists
55+
in array-api-strict.
56+
57+
- `standard_version`: The version of the standard to use. Supported
58+
versions are: ``{supported_versions}``. The default version number is
59+
``{default_version!r}``.
60+
61+
- `data_dependent_shapes`: Whether data-dependent shapes are enabled in
62+
array-api-strict. This flag is enabled by default. Array libraries that
63+
use computation graphs may not be able to support functions whose output
64+
shapes depend on the input data.
65+
66+
This flag is enabled by default. Array libraries that use computation graphs may not be able to support
67+
functions whose output shapes depend on the input data.
68+
69+
The functions that make use of data-dependent shapes, and are therefore
70+
disabled by setting this flag to False are
71+
72+
- `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`.
73+
- `nonzero`
74+
- Boolean array indexing
75+
- `repeat` when the `repeats` argument is an array (requires 2023.12
76+
version of the standard)
77+
78+
See
79+
https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html
80+
for more details.
81+
82+
- `enabled_extensions`: A list of extensions that are enabled in
83+
array-api-strict. The default is ``{default_extensions}``. Note that
84+
some extensions require a minimum version of the standard.
85+
86+
The default values of the flags can also be changed by setting environment
87+
variables:
88+
89+
- ``ARRAY_API_STRICT_STANDARD_VERSION``: A string representing the version number.
90+
- ``ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES``: "True" or "False".
91+
- ``ARRAY_API_STRICT_ENABLED_EXTENSIONS``: A comma separated list of
92+
extensions to enable.
93+
94+
Examples
95+
--------
96+
97+
>>> from array_api_strict import set_array_api_strict_flags
98+
>>> # Set the standard version to 2021.12
99+
>>> set_array_api_strict_flags(standard_version="2021.12")
100+
>>> # Disable data-dependent shapes
101+
>>> set_array_api_strict_flags(data_dependent_shapes=False)
102+
>>> # Enable only the linalg extension (disable the fft extension)
103+
>>> set_array_api_strict_flags(enabled_extensions=["linalg"])
104+
105+
See Also
106+
--------
107+
108+
get_array_api_strict_flags
109+
reset_array_api_strict_flags
110+
ArrayApiStrictFlags: A context manager to temporarily set the flags.
111+
112+
"""
113+
global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
114+
115+
if standard_version is not None:
116+
if standard_version not in supported_versions:
117+
raise ValueError(f"Unsupported standard version {standard_version}")
118+
STANDARD_VERSION = standard_version
119+
120+
if data_dependent_shapes is not None:
121+
DATA_DEPENDENT_SHAPES = data_dependent_shapes
122+
123+
if enabled_extensions is not None:
124+
for extension in enabled_extensions:
125+
if extension not in all_extensions:
126+
raise ValueError(f"Unsupported extension {extension}")
127+
if extension_versions[extension] > STANDARD_VERSION:
128+
raise ValueError(
129+
f"Extension {extension} requires standard version "
130+
f"{extension_versions[extension]} or later"
131+
)
132+
ENABLED_EXTENSIONS = enabled_extensions
133+
134+
# We have to do this separately or it won't get added as the docstring
135+
set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format(
136+
supported_versions=supported_versions,
137+
default_version=STANDARD_VERSION,
138+
default_extensions=ENABLED_EXTENSIONS,
139+
)
140+
141+
def get_array_api_strict_flags():
142+
"""
143+
Get the current array-api-strict flags.
144+
145+
.. note::
146+
147+
This function is **not** part of the array API standard. It only exists
148+
in array-api-strict.
149+
150+
Returns
151+
-------
152+
dict
153+
A dictionary containing the current array-api-strict flags.
154+
155+
Examples
156+
--------
157+
158+
>>> from array_api_strict import get_array_api_strict_flags
159+
>>> flags = get_array_api_strict_flags()
160+
>>> flags
161+
{'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ['linalg', 'fft']}
162+
163+
See Also
164+
--------
165+
166+
set_array_api_strict_flags
167+
reset_array_api_strict_flags
168+
ArrayApiStrictFlags: A context manager to temporarily set the flags.
169+
170+
"""
171+
return {
172+
"standard_version": STANDARD_VERSION,
173+
"data_dependent_shapes": DATA_DEPENDENT_SHAPES,
174+
"enabled_extensions": ENABLED_EXTENSIONS,
175+
}
176+
177+
178+
def reset_array_api_strict_flags():
179+
"""
180+
Reset the array-api-strict flags to their default values.
181+
182+
.. note::
183+
184+
This function is **not** part of the array API standard. It only exists
185+
in array-api-strict.
186+
187+
Examples
188+
--------
189+
190+
>>> from array_api_strict import reset_array_api_strict_flags
191+
>>> reset_array_api_strict_flags()
192+
193+
See Also
194+
--------
195+
196+
set_array_api_strict_flags
197+
get_array_api_strict_flags
198+
ArrayApiStrictFlags: A context manager to temporarily set the flags.
199+
200+
"""
201+
global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
202+
STANDARD_VERSION = "2022.12"
203+
DATA_DEPENDENT_SHAPES = True
204+
ENABLED_EXTENSIONS = ["linalg", "fft"]
205+
206+
207+
class ArrayApiStrictFlags:
208+
"""
209+
A context manager to temporarily set the array-api-strict flags.
210+
211+
.. note::
212+
213+
This class is **not** part of the array API standard. It only exists
214+
in array-api-strict.
215+
216+
See :func:`~.array_api_strict.set_array_api_strict_flags` for a
217+
description of the available flags.
218+
219+
See Also
220+
--------
221+
222+
set_array_api_strict_flags
223+
get_array_api_strict_flags
224+
reset_array_api_strict_flags
225+
226+
"""
227+
def __init__(self, *, standard_version=None, data_dependent_shapes=None,
228+
enabled_extensions=None):
229+
self.kwargs = {
230+
"standard_version": standard_version,
231+
"data_dependent_shapes": data_dependent_shapes,
232+
"enabled_extensions": enabled_extensions,
233+
}
234+
self.old_flags = get_array_api_strict_flags()
235+
236+
def __enter__(self):
237+
set_array_api_strict_flags(**self.kwargs)
238+
239+
def __exit__(self, exc_type, exc_value, traceback):
240+
set_array_api_strict_flags(**self.old_flags)
241+
242+
# Set the flags from the environment variables
243+
if "ARRAY_API_STRICT_STANDARD_VERSION" in os.environ:
244+
set_array_api_strict_flags(
245+
standard_version=os.environ["ARRAY_API_STRICT_STANDARD_VERSION"]
246+
)
247+
248+
if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ:
249+
set_array_api_strict_flags(
250+
data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true"
251+
)
252+
253+
if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ:
254+
set_array_api_strict_flags(
255+
enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",")
256+
)

0 commit comments

Comments
 (0)