Skip to content

Commit c941c47

Browse files
BUG: Fixed issue/37
See #37 As noted by the reported, the logic of _remove_axes was not right. We always have to remove the last axis (since _cook_nd_args rearranges it that way), but upon remove of the axis, we have to decrease all the axis higher than the pivot by one, since we take part at the pivot.
1 parent ef665fb commit c941c47

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,9 @@ def _remove_axis(s, axes, axis_to_remove):
878878
a2r = lens + axis_to_remove if axis_to_remove < 0 else axis_to_remove
879879

880880
ss = s[:a2r] + s[a2r+1:]
881-
aa = axes_normalized[:a2r] + tuple(ai - 1 for ai in axes_normalized[a2r+1:])
881+
pivot = axes_normalized[a2r]
882+
aa = tuple(ai if ai < pivot else ai - 1 for ai in axes_normalized[:a2r]) + \
883+
tuple(ai if ai < pivot else ai - 1 for ai in axes_normalized[a2r+1:])
882884
return ss, aa
883885

884886

@@ -938,7 +940,7 @@ def rfftn_numpy(x, s=None, axes=None):
938940
ss[-1] = a.shape[la]
939941
a = _fix_dimensions(a, tuple(ss), axes)
940942
if len(set(axes)) == len(axes) and len(axes) == a.ndim and len(axes) > 2:
941-
ss, aa = _remove_axis(s, axes, la)
943+
ss, aa = _remove_axis(s, axes, -1)
942944
ind = [slice(None,None,1),] * len(s)
943945
for ii in range(a.shape[la]):
944946
ind[la] = ii

0 commit comments

Comments
 (0)