Skip to content

Commit 56831fa

Browse files
Prepare for upcoming changes in Intel(R) MKL
Intel(R) Math Kernel Library (MKL) has long had DFTI_PACKED_FORMAT DFTI_PACK deprecated, and that feature is going to be removed. This commit removes the use of real-to-real packed transform calls and replaces them with real-to-complex and complex-to-real transforms combined with pre- and post-processing. While less efficient, this only affects mkl_fft.rfft and mkl_fft.irfft, which correspond to the now-deprecated scipy.fftpack.rfft. The tests were updated to work with NumPy 1.19.
1 parent 05c22b3 commit 56831fa

File tree

3 files changed

+203
-383
lines changed

3 files changed

+203
-383
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 199 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ cdef void _capsule_destructor(object caps):
5656
PyMem_Free(_cache)
5757
if (status != 0):
5858
raise ValueError("Internal Error: Freeing DFTI Cache returned with error = {}".format(status))
59-
59+
6060

6161
def _tls_dfti_cache_capsule():
6262
cdef DftiCache *_cache_struct
@@ -72,7 +72,7 @@ def _tls_dfti_cache_capsule():
7272
capsule = getattr(_tls, 'capsule', None)
7373
if (not cpython.pycapsule.PyCapsule_IsValid(capsule, capsule_name)):
7474
raise ValueError("Internal Error: invalid capsule stored in TLS")
75-
return capsule
75+
return capsule
7676

7777

7878
cdef extern from "Python.h":
@@ -113,11 +113,6 @@ cdef extern from "src/mklfft.h":
113113
int float_mkl_rfft_in(cnp.ndarray, int, int, DftiCache*)
114114
int float_mkl_irfft_in(cnp.ndarray, int, int, DftiCache*)
115115

116-
int double_double_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
117-
int double_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
118-
int float_float_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
119-
int float_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
120-
121116
int cdouble_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
122117
int cfloat_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
123118

@@ -408,101 +403,239 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
408403

409404
def rfft(x, n=None, axis=-1, overwrite_x=False):
410405
"""Packed real-valued harmonics of FFT of a real sequence x"""
411-
return _rrfft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=+1)
406+
return _rr_fft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x)
412407

413408

414409
def irfft(x, n=None, axis=-1, overwrite_x=False):
415410
"""Inverse FFT of a real sequence, takes packed real-valued harmonics of FFT"""
416-
return _rrfft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=-1)
411+
return _rr_ifft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x)
412+
413+
414+
cdef object _rc_to_rr(cnp.ndarray rc_arr, int n, int axis, int xnd, int x_type):
415+
cdef object res
416+
cdef object sl, sl1, sl2
417+
418+
inp = <object>rc_arr
419+
420+
slice_ = [slice(None, None, None)] * xnd
421+
sl_0 = list(slice_)
422+
sl_0[axis] = 0
423+
424+
sl_1 = list(slice_)
425+
sl_1[axis] = 1
426+
if (inp.flags['C'] and inp.strides[axis] == inp.itemsize):
427+
res = inp
428+
res = res.view(dtype=np.single if (x_type == cnp.NPY_FLOAT) else np.double)
429+
res[tuple(sl_1)] = res[tuple(sl_0)]
430+
431+
slice_[axis] = slice(1, n + 1, None)
432+
433+
return res[tuple(slice_)]
434+
else:
435+
res_shape = list(inp.shape)
436+
res_shape[axis] = n
437+
res = np.empty(tuple(res_shape), dtype=np.single if (x_type == cnp.NPY_FLOAT) else np.double)
438+
439+
res[tuple(sl_0)] = inp[tuple(sl_0)].real
440+
sl_dst_real = list(slice_)
441+
sl_dst_real[axis] = slice(1, None, 2)
442+
sl_src_real = list(slice_)
443+
sl_src_real[axis] = slice(1, None, None)
444+
res[tuple(sl_dst_real)] = inp[tuple(sl_src_real)].real
445+
sl_dst_imag = list(slice_)
446+
sl_dst_imag[axis] = slice(2, None, 2)
447+
sl_src_imag = list(slice_)
448+
sl_src_imag[axis] = slice(1, inp.shape[axis] if (n & 1) else inp.shape[axis] - 1, None)
449+
res[tuple(sl_dst_imag)] = inp[tuple(sl_src_imag)].imag
450+
451+
return res[tuple(slice_)]
452+
453+
cdef object _rr_to_rc(cnp.ndarray rr_arr, int n, int axis, int xnd, int x_type):
454+
455+
inp = <object> rr_arr
456+
457+
rc_shape = list(inp.shape)
458+
rc_shape[axis] = (n // 2 + 1)
459+
rc_shape = tuple(rc_shape)
460+
461+
rc_dtype = np.cdouble if x_type == cnp.NPY_DOUBLE else np.csingle
462+
rc = np.empty(rc_shape, dtype=rc_dtype, order='C')
463+
464+
slice_ = [slice(None, None, None)] * xnd
465+
sl_src_real = list(slice_)
466+
sl_src_imag = list(slice_)
467+
sl_src_real[axis] = slice(1, n, 2)
468+
sl_src_imag[axis] = slice(2, n, 2)
469+
470+
sl_dest_real = list(slice_)
471+
sl_dest_real[axis] = slice(1, None, None)
472+
sl_dest_imag = list(slice_)
473+
sl_dest_imag[axis] = slice(1, (n+1)//2, None)
474+
475+
sl_0 = list(slice_)
476+
sl_0[axis] = 0
477+
478+
rc_real = rc.real
479+
rc_imag = rc.imag
480+
481+
rc_real[tuple(sl_dest_real)] = inp[tuple(sl_src_real)]
482+
rc_imag[tuple(sl_dest_imag)] = inp[tuple(sl_src_imag)]
483+
rc_real[tuple(sl_0)] = inp[tuple(sl_0)]
484+
rc_imag[tuple(sl_0)] = 0
485+
if (n & 1 == 0):
486+
sl_last = list(slice_)
487+
sl_last[axis] = -1
488+
rc_imag[tuple(sl_last)] = 0
489+
490+
return rc
491+
492+
493+
def _repack_rr_to_rc(x, n, axis):
494+
"""Debugging utility"""
495+
cdef cnp.ndarray x_arr
496+
cdef int n_ = n, axis_ = axis
497+
cdef x_type
498+
499+
x_arr = <cnp.ndarray> np.asarray(x)
500+
x_type = cnp.PyArray_TYPE(x_arr)
501+
return _rr_to_rc(x, n_, axis_, cnp.PyArray_NDIM(x_arr), x_type)
502+
503+
504+
def _repack_rc_to_rr(x, n, axis):
505+
"""Debugging utility"""
506+
cdef cnp.ndarray x_arr
507+
cdef int n_ = n, axis_ = axis
508+
cdef c_type, x_type
509+
510+
x_arr = <cnp.ndarray> np.asarray(x)
511+
c_type = cnp.PyArray_TYPE(x_arr)
512+
x_type = cnp.NPY_DOUBLE if c_type == cnp.NPY_CDOUBLE else cnp.NPY_FLOAT
513+
return _rc_to_rr(x, n_, axis_, cnp.PyArray_NDIM(x_arr), x_type)
417514

418515

419-
def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
516+
def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False):
420517
"""
421518
Uses MKL to perform real packed 1D FFT on the input array x along the given axis.
519+
520+
This is done by using rfft_numpy and post-processing the result.
521+
Thus overwrite_arg is effectively discarded.
522+
523+
Functionally equivalent to scipy.fftpack.rfft
422524
"""
423525
cdef cnp.ndarray x_arr "x_arrayObject"
424526
cdef cnp.ndarray f_arr "f_arrayObject"
425527
cdef int xnd, err, n_max = 0, in_place, dir_
426528
cdef long n_, axis_
427-
cdef int x_type, status
529+
cdef int HALF_HARMONICS = 0 # give only positive index harmonics
530+
cdef int x_type, status, f_type
428531
cdef char * c_error_msg = NULL
429532
cdef bytes py_error_msg
430533
cdef DftiCache *_cache
431534

432-
x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
535+
x_arr = __process_arguments(x, n, axis, overwrite_arg, <object>(+1),
433536
&axis_, &n_, &in_place, &xnd, &dir_, 1)
434537

435538
x_type = cnp.PyArray_TYPE(x_arr)
436539

437540
if x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_DOUBLE:
438-
# we can operate in place if requested.
439-
if in_place:
440-
if not cnp.PyArray_ISONESEGMENT(x_arr):
441-
in_place = 0 if internal_overlap(x_arr) else 1;
541+
in_place = 0
442542
elif x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE:
443543
raise TypeError("1st argument must be a real sequence")
444544
else:
445-
# we must cast the input and allocate the output,
446-
# so we cast to double and operate in place
447545
try:
448546
x_arr = <cnp.ndarray> cnp.PyArray_FROM_OTF(
449547
x_arr, cnp.NPY_DOUBLE, cnp.NPY_BEHAVED)
450548
except:
451549
raise TypeError("1st argument must be a real sequence")
452550
x_type = cnp.PyArray_TYPE(x_arr)
453-
in_place = 1
551+
in_place = 0
454552

455-
if in_place:
456-
_cache_capsule = _tls_dfti_cache_capsule()
457-
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
458-
if x_type is cnp.NPY_DOUBLE:
459-
if dir_ < 0:
460-
status = double_mkl_irfft_in(x_arr, n_, <int> axis_, _cache)
461-
else:
462-
status = double_mkl_rfft_in(x_arr, n_, <int> axis_, _cache)
463-
elif x_type is cnp.NPY_FLOAT:
464-
if dir_ < 0:
465-
status = float_mkl_irfft_in(x_arr, n_, <int> axis_, _cache)
466-
else:
467-
status = float_mkl_rfft_in(x_arr, n_, <int> axis_, _cache)
468-
else:
469-
status = 1
553+
f_type = cnp.NPY_CFLOAT if x_type is cnp.NPY_FLOAT else cnp.NPY_CDOUBLE
554+
f_arr = __allocate_result(x_arr, n_ // 2 + 1, axis_, f_type);
470555

471-
if status:
472-
c_error_msg = mkl_dfti_error(status)
473-
py_error_msg = c_error_msg
474-
raise ValueError("Internal error occurred: {}".format(py_error_msg))
556+
_cache_capsule = _tls_dfti_cache_capsule()
557+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
558+
if x_type is cnp.NPY_DOUBLE:
559+
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)
560+
else:
561+
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)
475562

476-
n_max = <long> cnp.PyArray_DIM(x_arr, axis_)
477-
if (n_ < n_max):
478-
ind = [slice(0, None, None), ] * xnd
479-
ind[axis_] = slice(0, n_, None)
480-
x_arr = x_arr[tuple(ind)]
563+
if (status):
564+
c_error_msg = mkl_dfti_error(status)
565+
py_error_msg = c_error_msg
566+
raise ValueError("Internal error occurred: {}".format(py_error_msg))
481567

482-
return x_arr
568+
# post-process and return
569+
return _rc_to_rr(f_arr, n_, axis_, xnd, x_type)
570+
571+
572+
def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False):
573+
"""
574+
Uses MKL to perform the inverse of the real packed 1D FFT on the input array x along the given axis.
575+
576+
This is done by pre-processing the packed input into a complex array and applying the complex-to-real inverse transform.
577+
Thus overwrite_arg is effectively discarded.
578+
579+
Functionally equivalent to scipy.fftpack.irfft
580+
"""
581+
cdef cnp.ndarray x_arr "x_arrayObject"
582+
cdef cnp.ndarray f_arr "f_arrayObject"
583+
cdef int xnd, err, n_max = 0, in_place, dir_, int_n
584+
cdef long n_, axis_
585+
cdef int x_type, rc_type, status
586+
cdef int direction = 1 # dummy, only used for the sake of arg-processing
587+
cdef char * c_error_msg = NULL
588+
cdef bytes py_error_msg
589+
cdef DftiCache *_cache
590+
591+
x_arr = __process_arguments(x, n, axis, overwrite_arg, <object>(-1),
592+
&axis_, &n_, &in_place, &xnd, &dir_, 1)
593+
594+
x_type = cnp.PyArray_TYPE(x_arr)
595+
596+
if x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_DOUBLE:
597+
pass
483598
else:
484-
f_arr = __allocate_result(x_arr, n_, axis_, x_type);
599+
# we must cast the input and allocate the output,
600+
# so we cast to double (the in_place flag set below is overridden later)
601+
try:
602+
x_arr = <cnp.ndarray> cnp.PyArray_FROM_OTF(
603+
x_arr, cnp.NPY_DOUBLE, cnp.NPY_BEHAVED)
604+
except:
605+
raise ValueError("First argument should be a real or a complex sequence of single or double precision")
606+
x_type = cnp.PyArray_TYPE(x_arr)
607+
in_place = 1
485608

486-
# call out-of-place FFT
609+
# need to convert this into complex array
610+
rc_obj = _rr_to_rc(x_arr, n_, axis_, xnd, x_type)
611+
rc_arr = <cnp.ndarray> rc_obj
612+
613+
rc_type = cnp.NPY_CFLOAT if x_type is cnp.NPY_FLOAT else cnp.NPY_CDOUBLE
614+
in_place = False
615+
if in_place:
616+
f_arr = x_arr
617+
else:
618+
f_arr = __allocate_result(x_arr, n_, axis_, x_type)
619+
620+
# call out-of-place FFT
621+
if rc_type is cnp.NPY_CFLOAT:
487622
_cache_capsule = _tls_dfti_cache_capsule()
488623
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
489-
if x_type is cnp.NPY_DOUBLE:
490-
if dir_ < 0:
491-
status = double_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
492-
else:
493-
status = double_double_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
494-
else:
495-
if dir_ < 0:
496-
status = float_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
497-
else:
498-
status = float_float_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
624+
status = cfloat_float_mkl_irfft_out(rc_arr, n_, <int> axis_, f_arr, _cache)
625+
elif rc_type is cnp.NPY_CDOUBLE:
626+
_cache_capsule = _tls_dfti_cache_capsule()
627+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
628+
status = cdouble_double_mkl_irfft_out(rc_arr, n_, <int> axis_, f_arr, _cache)
629+
else:
630+
raise ValueError("Internal mkl_fft error occurred: Unrecognized rc_type")
499631

500-
if (status):
501-
c_error_msg = mkl_dfti_error(status)
502-
py_error_msg = c_error_msg
503-
raise ValueError("Internal error occurred: {}".format(py_error_msg))
632+
if (status):
633+
c_error_msg = mkl_dfti_error(status)
634+
py_error_msg = c_error_msg
635+
raise ValueError("Internal error occurred: {}".format(str(py_error_msg)))
636+
637+
return f_arr
504638

505-
return f_arr
506639

507640
# this routine is functionally equivalent to numpy.fft.rfft
508641
def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
@@ -582,13 +715,13 @@ cdef int _is_integral(object num):
582715
return _integral
583716

584717

585-
# this routine is functionally equivalent to numpy.fft.rfft
718+
# this routine is functionally equivalent to numpy.fft.irfft
586719
def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
587720
"""
588721
Uses MKL to perform 1D FFT on the real input array x along the given axis,
589722
producing complex output, but giving only half of the harmonics.
590723
591-
cf. numpy.fft.rfft
724+
cf. numpy.fft.irfft
592725
"""
593726
cdef cnp.ndarray x_arr "x_arrayObject"
594727
cdef cnp.ndarray f_arr "f_arrayObject"
@@ -891,8 +1024,8 @@ def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1):
8911024
if _direct:
8921025
return _direct_fftnd(x, overwrite_arg=overwrite_x, direction=direction)
8931026
else:
894-
return _iter_fftnd(x, s=shape, axes=axes,
895-
overwrite_arg=overwrite_x,
1027+
return _iter_fftnd(x, s=shape, axes=axes,
1028+
overwrite_arg=overwrite_x,
8961029
function=fft if direction == 1 else ifft)
8971030

8981031

@@ -933,7 +1066,7 @@ def _remove_axis(s, axes, axis_to_remove):
9331066

9341067

9351068
cdef cnp.ndarray _trim_array(cnp.ndarray arr, object s, object axes):
936-
"""Forms a view into subarray of arr if any element of shape parameter s is
1069+
"""Forms a view into subarray of arr if any element of shape parameter s is
9371070
smaller than the corresponding element of the shape of the input array arr,
9381071
otherwise returns the input array"""
9391072
arr_shape = (<object> arr).shape

0 commit comments

Comments
 (0)