Skip to content

Commit 1464975

Browse files
_direct_fftnd checks for dtypes and converts if not as expected
1 parent a63ce01 commit 1464975

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,14 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
981981
in_place = 1 # a copy was made, so we can work in place.
982982

983983
x_type = cnp.PyArray_TYPE(x_arr)
984-
assert( x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_CFLOAT or x_type == cnp.NPY_DOUBLE or x_type == cnp.NPY_FLOAT);
984+
if (x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_CFLOAT or x_type == cnp.NPY_DOUBLE or x_type == cnp.NPY_FLOAT):
985+
pass
986+
else:
987+
x_arr = <cnp.ndarray> cnp.PyArray_FROM_OTF(
988+
x_arr, cnp.NPY_CDOUBLE, cnp.NPY_BEHAVED | cnp.NPY_ENSURECOPY)
989+
x_type = cnp.PyArray_TYPE(x_arr)
990+
assert x_type == cnp.NPY_CDOUBLE
991+
in_place = 1
985992

986993
if in_place:
987994
in_place = 1 if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_CFLOAT else 0
@@ -1076,7 +1083,7 @@ def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, doubl
10761083
if _direct:
10771084
return _direct_fftnd(x, overwrite_arg=overwrite_x, direction=direction, fsc=fsc)
10781085
else:
1079-
if (shape is None and x.dtype in [np.complex64, np.complex128, np.float32, np.float64]):
1086+
if (shape is None and x.dtype in [np.csingle, np.cdouble, np.single, np.double]):
10801087
x = np.asarray(x)
10811088
res = np.empty(x.shape, dtype=_output_dtype(x.dtype))
10821089
return iter_complementary(

0 commit comments

Comments
 (0)