Skip to content

Commit 2cae8e2

Browse files
Extended test_interfaces to cover numpy_fft too
1 parent 2df2cae commit 2cae8e2

File tree

1 file changed

+64
-8
lines changed

1 file changed

+64
-8
lines changed

mkl_fft/tests/test_interfaces.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,18 @@ def test_interfaces_has_scipy():
4141
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.complex64, np.complex128])
4242
def test_scipy_fft(norm, dtype):
4343
x = np.ones(511, dtype=dtype)
44-
w = mfi.scipy_fft.fft(x, norm=norm)
45-
xx = mfi.scipy_fft.ifft(w, norm=norm)
44+
w = mfi.scipy_fft.fft(x, norm=norm, workers=None, plan=None)
45+
xx = mfi.scipy_fft.ifft(w, norm=norm, workers=None, plan=None)
46+
tol = 64 * np.finfo(np.dtype(dtype)).eps
47+
assert np.allclose(x, xx, atol=tol, rtol=tol)
48+
49+
50+
@pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
51+
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.complex64, np.complex128])
52+
def test_numpy_fft(norm, dtype):
53+
x = np.ones(511, dtype=dtype)
54+
w = mfi.numpy_fft.fft(x, norm=norm)
55+
xx = mfi.numpy_fft.ifft(w, norm=norm)
4656
tol = 64 * np.finfo(np.dtype(dtype)).eps
4757
assert np.allclose(x, xx, atol=tol, rtol=tol)
4858

@@ -51,8 +61,18 @@ def test_scipy_fft(norm, dtype):
5161
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
5262
def test_scipy_rfft(norm, dtype):
5363
x = np.ones(511, dtype=dtype)
54-
w = mfi.scipy_fft.rfft(x, norm=norm)
55-
xx = mfi.scipy_fft.irfft(w, n=x.shape[0], norm=norm)
64+
w = mfi.scipy_fft.rfft(x, norm=norm, workers=None, plan=None)
65+
xx = mfi.scipy_fft.irfft(w, n=x.shape[0], norm=norm, workers=None, plan=None)
66+
tol = 64 * np.finfo(np.dtype(dtype)).eps
67+
assert np.allclose(x, xx, atol=tol, rtol=tol)
68+
69+
70+
@pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
71+
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
72+
def test_numpy_rfft(norm, dtype):
73+
x = np.ones(511, dtype=dtype)
74+
w = mfi.numpy_fft.rfft(x, norm=norm)
75+
xx = mfi.numpy_fft.irfft(w, n=x.shape[0], norm=norm)
5676
tol = 64 * np.finfo(np.dtype(dtype)).eps
5777
assert np.allclose(x, xx, atol=tol, rtol=tol)
5878

@@ -61,8 +81,18 @@ def test_scipy_rfft(norm, dtype):
6181
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.complex64, np.complex128])
6282
def test_scipy_fftn(norm, dtype):
6383
x = np.ones((37, 83), dtype=dtype)
64-
w = mfi.scipy_fft.fftn(x, norm=norm)
65-
xx = mfi.scipy_fft.ifftn(w, norm=norm)
84+
w = mfi.scipy_fft.fftn(x, norm=norm, workers=None, plan=None)
85+
xx = mfi.scipy_fft.ifftn(w, norm=norm, workers=None, plan=None)
86+
tol = 64 * np.finfo(np.dtype(dtype)).eps
87+
assert np.allclose(x, xx, atol=tol, rtol=tol)
88+
89+
90+
@pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
91+
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.complex64, np.complex128])
92+
def test_numpy_fftn(norm, dtype):
93+
x = np.ones((37, 83), dtype=dtype)
94+
w = mfi.numpy_fft.fftn(x, norm=norm)
95+
xx = mfi.numpy_fft.ifftn(w, norm=norm)
6696
tol = 64 * np.finfo(np.dtype(dtype)).eps
6797
assert np.allclose(x, xx, atol=tol, rtol=tol)
6898

@@ -71,7 +101,33 @@ def test_scipy_fftn(norm, dtype):
71101
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
72102
def test_scipy_rftn(norm, dtype):
73103
x = np.ones((37, 83), dtype=dtype)
74-
w = mfi.scipy_fft.rfftn(x, norm=norm)
75-
xx = mfi.scipy_fft.ifftn(w, s=x.shape, norm=norm)
104+
w = mfi.scipy_fft.rfftn(x, norm=norm, workers=None, plan=None)
105+
xx = mfi.scipy_fft.ifftn(w, s=x.shape, norm=norm, workers=None, plan=None)
76106
tol = 64 * np.finfo(np.dtype(dtype)).eps
77107
assert np.allclose(x, xx, atol=tol, rtol=tol)
108+
109+
110+
@pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
111+
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
112+
def test_numpy_rftn(norm, dtype):
113+
x = np.ones((37, 83), dtype=dtype)
114+
w = mfi.numpy_fft.rfftn(x, norm=norm)
115+
xx = mfi.numpy_fft.ifftn(w, s=x.shape, norm=norm)
116+
tol = 64 * np.finfo(np.dtype(dtype)).eps
117+
assert np.allclose(x, xx, atol=tol, rtol=tol)
118+
119+
120+
@pytest.mark.parametrize('dtype', [np.float16, np.float128, np.complex256])
121+
def test_scipy_no_support_for(dtype):
122+
x = np.ones(16, dtype=dtype)
123+
w = mfi.scipy_fft.fft(x)
124+
assert w is NotImplemented
125+
126+
127+
def test_scipy_fft_arg_validate():
128+
with pytest.raises(ValueError):
129+
mfi.scipy_fft.fft([1,2,3,4], norm=b"invalid")
130+
131+
with pytest.raises(ValueError):
132+
mfi.scipy_fft.fft([1,2,3,4], plan="magic")
133+

0 commit comments

Comments
 (0)