|
4 | 4 | from .. import (asarray, unique_all, unique_counts, unique_inverse,
|
5 | 5 | unique_values, nonzero)
|
6 | 6 |
|
| 7 | +import array_api_strict as xp |
| 8 | + |
7 | 9 | import pytest
|
8 | 10 |
|
9 | 11 | @pytest.fixture(autouse=True)
|
@@ -76,3 +78,92 @@ def test_data_dependent_shapes():
|
76 | 78 | pytest.raises(RuntimeError, lambda: unique_values(a))
|
77 | 79 | pytest.raises(RuntimeError, lambda: nonzero(a))
|
78 | 80 | pytest.raises(RuntimeError, lambda: a[mask])
|
| 81 | + |
| 82 | +linalg_examples = { |
| 83 | + 'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)), |
| 84 | + 'cross': lambda: xp.linalg.cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])), |
| 85 | + 'det': lambda: xp.linalg.det(xp.eye(3)), |
| 86 | + 'diagonal': lambda: xp.linalg.diagonal(xp.eye(3)), |
| 87 | + 'eigh': lambda: xp.linalg.eigh(xp.eye(3)), |
| 88 | + 'eigvalsh': lambda: xp.linalg.eigvalsh(xp.eye(3)), |
| 89 | + 'inv': lambda: xp.linalg.inv(xp.eye(3)), |
| 90 | + 'matmul': lambda: xp.linalg.matmul(xp.eye(3), xp.eye(3)), |
| 91 | + 'matrix_norm': lambda: xp.linalg.matrix_norm(xp.eye(3)), |
| 92 | + 'matrix_power': lambda: xp.linalg.matrix_power(xp.eye(3), 2), |
| 93 | + 'matrix_rank': lambda: xp.linalg.matrix_rank(xp.eye(3)), |
| 94 | + 'matrix_transpose': lambda: xp.linalg.matrix_transpose(xp.eye(3)), |
| 95 | + 'outer': lambda: xp.linalg.outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), |
| 96 | + 'pinv': lambda: xp.linalg.pinv(xp.eye(3)), |
| 97 | + 'qr': lambda: xp.linalg.qr(xp.eye(3)), |
| 98 | + 'slogdet': lambda: xp.linalg.slogdet(xp.eye(3)), |
| 99 | + 'solve': lambda: xp.linalg.solve(xp.eye(3), xp.eye(3)), |
| 100 | + 'svd': lambda: xp.linalg.svd(xp.eye(3)), |
| 101 | + 'svdvals': lambda: xp.linalg.svdvals(xp.eye(3)), |
| 102 | + 'tensordot': lambda: xp.linalg.tensordot(xp.eye(3), xp.eye(3)), |
| 103 | + 'trace': lambda: xp.linalg.trace(xp.eye(3)), |
| 104 | + 'vecdot': lambda: xp.linalg.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), |
| 105 | + 'vector_norm': lambda: xp.linalg.vector_norm(xp.asarray([1., 2., 3.])), |
| 106 | +} |
| 107 | + |
| 108 | +assert set(linalg_examples) == set(xp.linalg.__all__) |
| 109 | + |
| 110 | +linalg_main_namespace_examples = { |
| 111 | + 'matmul': lambda: xp.matmul(xp.eye(3), xp.eye(3)), |
| 112 | + 'matrix_transpose': lambda: xp.matrix_transpose(xp.eye(3)), |
| 113 | + 'tensordot': lambda: xp.tensordot(xp.eye(3), xp.eye(3)), |
| 114 | + 'vecdot': lambda: xp.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), |
| 115 | +} |
| 116 | + |
| 117 | +assert set(linalg_main_namespace_examples) == set(xp.__all__) & set(xp.linalg.__all__) |
| 118 | + |
| 119 | +@pytest.mark.parametrize('func_name', linalg_examples.keys()) |
| 120 | +def test_linalg(func_name): |
| 121 | + func = linalg_examples[func_name] |
| 122 | + if func_name in linalg_main_namespace_examples: |
| 123 | + main_namespace_func = linalg_main_namespace_examples[func_name] |
| 124 | + else: |
| 125 | + main_namespace_func = lambda: None |
| 126 | + |
| 127 | + # First make sure the example actually works |
| 128 | + func() |
| 129 | + main_namespace_func() |
| 130 | + |
| 131 | + set_array_api_strict_flags(enabled_extensions=()) |
| 132 | + pytest.raises(RuntimeError, func) |
| 133 | + main_namespace_func() |
| 134 | + |
| 135 | + set_array_api_strict_flags(enabled_extensions=('linalg',)) |
| 136 | + func() |
| 137 | + main_namespace_func() |
| 138 | + |
| 139 | +fft_examples = { |
| 140 | + 'fft': lambda: xp.fft.fft(xp.asarray([0j, 1j, 0j, 0j])), |
| 141 | + 'ifft': lambda: xp.fft.ifft(xp.asarray([0j, 1j, 0j, 0j])), |
| 142 | + 'fftn': lambda: xp.fft.fftn(xp.asarray([[0j, 1j], [0j, 0j]])), |
| 143 | + 'ifftn': lambda: xp.fft.ifftn(xp.asarray([[0j, 1j], [0j, 0j]])), |
| 144 | + 'rfft': lambda: xp.fft.rfft(xp.asarray([0., 1., 0., 0.])), |
| 145 | + 'irfft': lambda: xp.fft.irfft(xp.asarray([0j, 1j, 0j, 0j])), |
| 146 | + 'rfftn': lambda: xp.fft.rfftn(xp.asarray([[0., 1.], [0., 0.]])), |
| 147 | + 'irfftn': lambda: xp.fft.irfftn(xp.asarray([[0j, 1j], [0j, 0j]])), |
| 148 | + 'hfft': lambda: xp.fft.hfft(xp.asarray([0j, 1j, 0j, 0j])), |
| 149 | + 'ihfft': lambda: xp.fft.ihfft(xp.asarray([0., 1., 0., 0.])), |
| 150 | + 'fftfreq': lambda: xp.fft.fftfreq(4), |
| 151 | + 'rfftfreq': lambda: xp.fft.rfftfreq(4), |
| 152 | + 'fftshift': lambda: xp.fft.fftshift(xp.asarray([0j, 1j, 0j, 0j])), |
| 153 | + 'ifftshift': lambda: xp.fft.ifftshift(xp.asarray([0j, 1j, 0j, 0j])), |
| 154 | +} |
| 155 | + |
| 156 | +assert set(fft_examples) == set(xp.fft.__all__) |
| 157 | + |
| 158 | +@pytest.mark.parametrize('func_name', fft_examples.keys()) |
| 159 | +def test_fft(func_name): |
| 160 | + func = fft_examples[func_name] |
| 161 | + |
| 162 | + # First make sure the example actually works |
| 163 | + func() |
| 164 | + |
| 165 | + set_array_api_strict_flags(enabled_extensions=()) |
| 166 | + pytest.raises(RuntimeError, func) |
| 167 | + |
| 168 | + set_array_api_strict_flags(enabled_extensions=('fft',)) |
| 169 | + func() |
0 commit comments