Skip to content

Commit a25a98f

Browse files
Merge pull request #69 from cako/master
Fixes NumPy interface bug which did not allow for list/ndarray `axes` or `s` parameters
2 parents 78564b7 + cd8b39a commit a25a98f

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

mkl_fft/_numpy_fft.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def frwd_sc_1d(n, s):
7676

7777

7878
def frwd_sc_nd(s, axes, x_shape):
79-
ss = s if s else x_shape
80-
if axes:
79+
ss = s if s is not None else x_shape
80+
if axes is not None:
8181
nn = prod([ss[ai] for ai in axes])
8282
else:
8383
nn = prod(ss)
@@ -203,7 +203,7 @@ def fft(a, n=None, axis=-1, norm=None):
203203
mkl_fft.fft,
204204
(x,),
205205
{'n':n, 'axis': axis})
206-
elif norm is "forward":
206+
elif norm == "forward":
207207
output = trycall(
208208
mkl_fft.fft,
209209
(x,),

mkl_fft/_pydfti.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,7 @@ def iter_complementary(x, axes, func, kwargs, result):
932932
nd = x.ndim
933933
r = list(range(nd))
934934
sl = [slice(None, None, None)] * nd
935-
if not isinstance(axes, tuple):
935+
if not np.iterable(axes):
936936
axes = (axes,)
937937
for ai in axes:
938938
r[ai] = None

mkl_fft/tests/test_fftnd.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,18 @@ def test_matrix5(self):
114114
rtol=r_tol, atol=a_tol
115115
)
116116

117+
def test_matrix6(self):
118+
"""fftn with tuple, list and ndarray axes and s"""
119+
for ar in [self.md, self.mz, self.mf, self.mc]:
120+
d = ar.copy()
121+
for norm in ["forward", "backward", "ortho"]:
122+
for container in [tuple, list, np.array]:
123+
axes = container(range(d.ndim))
124+
s = container(d.shape)
125+
kwargs = dict(s=s, axes=axes, norm=norm)
126+
r_tol, a_tol = _get_rtol_atol(d)
127+
t = mkl_fft._numpy_fft.fftn(mkl_fft._numpy_fft.ifftn(d, **kwargs), **kwargs)
128+
assert_allclose(d, t, rtol=r_tol, atol=a_tol, err_msg = "failed test for dtype {}, max abs diff: {}".format(d.dtype, np.max(np.abs(d-t))))
117129

118130

119131

0 commit comments

Comments
 (0)