Skip to content

Commit 4705b9f

Browse files
committed
Add tests for flags
1 parent 6a20e91 commit 4705b9f

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

array_api_strict/tests/test_flags.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags,
2+
reset_array_api_strict_flags)
3+
4+
from .. import (asarray, unique_all, unique_counts, unique_inverse,
5+
unique_values, nonzero)
6+
7+
import pytest
8+
9+
@pytest.fixture(autouse=True)
10+
def reset_flags():
11+
reset_array_api_strict_flags()
12+
yield
13+
reset_array_api_strict_flags()
14+
15+
def test_flags():
16+
# Test defaults
17+
flags = get_array_api_strict_flags()
18+
assert flags == {
19+
'standard_version': '2022.12',
20+
'data_dependent_shapes': True,
21+
'enabled_extensions': ('linalg', 'fft'),
22+
}
23+
24+
# Test setting flags
25+
set_array_api_strict_flags(data_dependent_shapes=False)
26+
flags = get_array_api_strict_flags()
27+
assert flags == {
28+
'standard_version': '2022.12',
29+
'data_dependent_shapes': False,
30+
'enabled_extensions': ('linalg', 'fft'),
31+
}
32+
set_array_api_strict_flags(enabled_extensions=('fft',))
33+
flags = get_array_api_strict_flags()
34+
assert flags == {
35+
'standard_version': '2022.12',
36+
'data_dependent_shapes': False,
37+
'enabled_extensions': ('fft',),
38+
}
39+
# Make sure setting the version to 2021.12 disables fft
40+
set_array_api_strict_flags(standard_version='2021.12')
41+
flags = get_array_api_strict_flags()
42+
assert flags == {
43+
'standard_version': '2021.12',
44+
'data_dependent_shapes': False,
45+
'enabled_extensions': ('linalg',),
46+
}
47+
48+
# Test setting flags with invalid values
49+
pytest.raises(ValueError, lambda:
50+
set_array_api_strict_flags(standard_version='2020.12'))
51+
pytest.raises(ValueError, lambda: set_array_api_strict_flags(
52+
enabled_extensions=('linalg', 'fft', 'invalid')))
53+
pytest.raises(ValueError, lambda: set_array_api_strict_flags(
54+
standard_version='2021.12',
55+
enabled_extensions=('linalg', 'fft')))
56+
57+
58+
def test_data_dependent_shapes():
59+
a = asarray([0, 0, 1, 2, 2])
60+
mask = asarray([True, False, True, False, True])
61+
62+
# Should not error
63+
unique_all(a)
64+
unique_counts(a)
65+
unique_inverse(a)
66+
unique_values(a)
67+
nonzero(a)
68+
a[mask]
69+
# TODO: add repeat when it is implemented
70+
71+
set_array_api_strict_flags(data_dependent_shapes=False)
72+
73+
pytest.raises(RuntimeError, lambda: unique_all(a))
74+
pytest.raises(RuntimeError, lambda: unique_counts(a))
75+
pytest.raises(RuntimeError, lambda: unique_inverse(a))
76+
pytest.raises(RuntimeError, lambda: unique_values(a))
77+
pytest.raises(RuntimeError, lambda: nonzero(a))
78+
pytest.raises(RuntimeError, lambda: a[mask])

0 commit comments

Comments
 (0)