Skip to content

Commit 78d368d

Browse files
committed
Add tests for disabling linalg and fft extensions
1 parent 689a776 commit 78d368d

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

array_api_strict/tests/test_flags.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from .. import (asarray, unique_all, unique_counts, unique_inverse,
55
unique_values, nonzero)
66

7+
import array_api_strict as xp
8+
79
import pytest
810

911
@pytest.fixture(autouse=True)
@@ -76,3 +78,92 @@ def test_data_dependent_shapes():
7678
pytest.raises(RuntimeError, lambda: unique_values(a))
7779
pytest.raises(RuntimeError, lambda: nonzero(a))
7880
pytest.raises(RuntimeError, lambda: a[mask])
81+
82+
linalg_examples = {
83+
'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)),
84+
'cross': lambda: xp.linalg.cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])),
85+
'det': lambda: xp.linalg.det(xp.eye(3)),
86+
'diagonal': lambda: xp.linalg.diagonal(xp.eye(3)),
87+
'eigh': lambda: xp.linalg.eigh(xp.eye(3)),
88+
'eigvalsh': lambda: xp.linalg.eigvalsh(xp.eye(3)),
89+
'inv': lambda: xp.linalg.inv(xp.eye(3)),
90+
'matmul': lambda: xp.linalg.matmul(xp.eye(3), xp.eye(3)),
91+
'matrix_norm': lambda: xp.linalg.matrix_norm(xp.eye(3)),
92+
'matrix_power': lambda: xp.linalg.matrix_power(xp.eye(3), 2),
93+
'matrix_rank': lambda: xp.linalg.matrix_rank(xp.eye(3)),
94+
'matrix_transpose': lambda: xp.linalg.matrix_transpose(xp.eye(3)),
95+
'outer': lambda: xp.linalg.outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])),
96+
'pinv': lambda: xp.linalg.pinv(xp.eye(3)),
97+
'qr': lambda: xp.linalg.qr(xp.eye(3)),
98+
'slogdet': lambda: xp.linalg.slogdet(xp.eye(3)),
99+
'solve': lambda: xp.linalg.solve(xp.eye(3), xp.eye(3)),
100+
'svd': lambda: xp.linalg.svd(xp.eye(3)),
101+
'svdvals': lambda: xp.linalg.svdvals(xp.eye(3)),
102+
'tensordot': lambda: xp.linalg.tensordot(xp.eye(3), xp.eye(3)),
103+
'trace': lambda: xp.linalg.trace(xp.eye(3)),
104+
'vecdot': lambda: xp.linalg.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])),
105+
'vector_norm': lambda: xp.linalg.vector_norm(xp.asarray([1., 2., 3.])),
106+
}
107+
108+
assert set(linalg_examples) == set(xp.linalg.__all__)
109+
110+
linalg_main_namespace_examples = {
111+
'matmul': lambda: xp.matmul(xp.eye(3), xp.eye(3)),
112+
'matrix_transpose': lambda: xp.matrix_transpose(xp.eye(3)),
113+
'tensordot': lambda: xp.tensordot(xp.eye(3), xp.eye(3)),
114+
'vecdot': lambda: xp.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])),
115+
}
116+
117+
assert set(linalg_main_namespace_examples) == set(xp.__all__) & set(xp.linalg.__all__)
118+
119+
@pytest.mark.parametrize('func_name', linalg_examples.keys())
120+
def test_linalg(func_name):
121+
func = linalg_examples[func_name]
122+
if func_name in linalg_main_namespace_examples:
123+
main_namespace_func = linalg_main_namespace_examples[func_name]
124+
else:
125+
main_namespace_func = lambda: None
126+
127+
# First make sure the example actually works
128+
func()
129+
main_namespace_func()
130+
131+
set_array_api_strict_flags(enabled_extensions=())
132+
pytest.raises(RuntimeError, func)
133+
main_namespace_func()
134+
135+
set_array_api_strict_flags(enabled_extensions=('linalg',))
136+
func()
137+
main_namespace_func()
138+
139+
fft_examples = {
140+
'fft': lambda: xp.fft.fft(xp.asarray([0j, 1j, 0j, 0j])),
141+
'ifft': lambda: xp.fft.ifft(xp.asarray([0j, 1j, 0j, 0j])),
142+
'fftn': lambda: xp.fft.fftn(xp.asarray([[0j, 1j], [0j, 0j]])),
143+
'ifftn': lambda: xp.fft.ifftn(xp.asarray([[0j, 1j], [0j, 0j]])),
144+
'rfft': lambda: xp.fft.rfft(xp.asarray([0., 1., 0., 0.])),
145+
'irfft': lambda: xp.fft.irfft(xp.asarray([0j, 1j, 0j, 0j])),
146+
'rfftn': lambda: xp.fft.rfftn(xp.asarray([[0., 1.], [0., 0.]])),
147+
'irfftn': lambda: xp.fft.irfftn(xp.asarray([[0j, 1j], [0j, 0j]])),
148+
'hfft': lambda: xp.fft.hfft(xp.asarray([0j, 1j, 0j, 0j])),
149+
'ihfft': lambda: xp.fft.ihfft(xp.asarray([0., 1., 0., 0.])),
150+
'fftfreq': lambda: xp.fft.fftfreq(4),
151+
'rfftfreq': lambda: xp.fft.rfftfreq(4),
152+
'fftshift': lambda: xp.fft.fftshift(xp.asarray([0j, 1j, 0j, 0j])),
153+
'ifftshift': lambda: xp.fft.ifftshift(xp.asarray([0j, 1j, 0j, 0j])),
154+
}
155+
156+
assert set(fft_examples) == set(xp.fft.__all__)
157+
158+
@pytest.mark.parametrize('func_name', fft_examples.keys())
159+
def test_fft(func_name):
160+
func = fft_examples[func_name]
161+
162+
# First make sure the example actually works
163+
func()
164+
165+
set_array_api_strict_flags(enabled_extensions=())
166+
pytest.raises(RuntimeError, func)
167+
168+
set_array_api_strict_flags(enabled_extensions=('fft',))
169+
func()

0 commit comments

Comments
 (0)