Skip to content

Commit d7bbf5c

Browse files
authored
FFT. Inverse version for fft(), fft2() and fftn() (#350)
1 parent df174f8 commit d7bbf5c

File tree

4 files changed

+33
-19
lines changed

4 files changed

+33
-19
lines changed

dpnp/backend/backend_iface_fft.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
* @param[in] shape_size Number of elements in @ref input_shape or @ref output_shape arrays.
7070
* @param[in] axis Axis ID to compute by.
7171
* @param[in] input_boundarie Limit number of elements for @ref axis.
72+
* @param[in] inverse Using inverse algorithm.
7273
*/
7374
template <typename _DataType>
7475
INP_DLLEXPORT void dpnp_fft_fft_c(const void* array_in,
@@ -77,6 +78,7 @@ INP_DLLEXPORT void dpnp_fft_fft_c(const void* array_in,
7778
const long* output_shape,
7879
size_t shape_size,
7980
long axis,
80-
long input_boundarie);
81+
long input_boundarie,
82+
size_t inverse);
8183

8284
#endif // BACKEND_IFACE_FFT_H

dpnp/backend/dpnp_kernels_fft.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ void dpnp_fft_fft_c(const void* array1_in,
4242
const long* output_shape,
4343
size_t shape_size,
4444
long axis,
45-
long input_boundarie)
45+
long input_boundarie,
46+
size_t inverse)
4647
{
4748
const size_t result_size = std::accumulate(output_shape, output_shape + shape_size, 1, std::multiplies<size_t>());
4849
if (!(result_size && shape_size))
@@ -51,6 +52,7 @@ void dpnp_fft_fft_c(const void* array1_in,
5152
}
5253

5354
cl::sycl::event event;
55+
const double kernel_pi = inverse ? -M_PI : M_PI;
5456

5557
const _DataType_input* array_1 = reinterpret_cast<const _DataType_input*>(array1_in);
5658
_DataType_output* result = reinterpret_cast<_DataType_output*>(result1);
@@ -107,7 +109,7 @@ void dpnp_fft_fft_c(const void* array1_in,
107109
}
108110

109111
const size_t output_local_id = xyz_thread[axis];
110-
const double angle = 2.0 * M_PI * it * output_local_id / axis_length;
112+
const double angle = 2.0 * kernel_pi * it * output_local_id / axis_length;
111113

112114
const double angle_cos = cl::sycl::cos(angle);
113115
const double angle_sin = cl::sycl::sin(angle);
@@ -116,6 +118,12 @@ void dpnp_fft_fft_c(const void* array1_in,
116118
sum_imag += -in_real * angle_sin + in_imag * angle_cos;
117119
}
118120

121+
if (inverse)
122+
{
123+
sum_real = sum_real / input_boundarie;
124+
sum_imag = sum_imag / input_boundarie;
125+
}
126+
119127
result[output_id] = _DataType_output(sum_real, sum_imag);
120128
};
121129

dpnp/fft/dpnp_algo_fft.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ __all__ = [
4242
"dpnp_fft"
4343
]
4444

45-
ctypedef void(*fptr_dpnp_fft_fft_t)(void *, void * , long * , long * , size_t, long, long)
45+
ctypedef void(*fptr_dpnp_fft_fft_t)(void *, void * , long * , long * , size_t, long, long, size_t)
4646

4747

48-
cpdef dparray dpnp_fft(dparray input, size_t input_boundarie, size_t output_boundarie, long axis):
48+
cpdef dparray dpnp_fft(dparray input, size_t input_boundarie, size_t output_boundarie, long axis, size_t inverse):
4949

5050
cdef dparray_shape_type input_shape = input.shape
5151
cdef dparray_shape_type output_shape = input_shape
@@ -66,6 +66,6 @@ cpdef dparray dpnp_fft(dparray input, size_t input_boundarie, size_t output_boun
6666
cdef fptr_dpnp_fft_fft_t func = <fptr_dpnp_fft_fft_t > kernel_data.ptr
6767
# call FPTR function
6868
func(input.get_data(), result.get_data(), input_shape.data(),
69-
output_shape.data(), input_shape.size(), axis_norm, input_boundarie)
69+
output_shape.data(), input_shape.size(), axis_norm, input_boundarie, inverse)
7070

7171
return result

dpnp/fft/dpnp_iface_fft.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def fft(x1, n=None, axis=-1, norm=None):
100100
else:
101101
output_boundarie = input_boundarie
102102

103-
return dpnp_fft(x1, input_boundarie, output_boundarie, axis_param)
103+
return dpnp_fft(x1, input_boundarie, output_boundarie, axis_param, False)
104104

105105
return call_origin(numpy.fft.fft, x1, n, axis, norm)
106106

@@ -197,7 +197,7 @@ def ifft(x1, n=None, axis=-1, norm=None):
197197

198198
is_x1_dparray = isinstance(x1, dparray)
199199

200-
if (not use_origin_backend(x1) and is_x1_dparray and 0):
200+
if (not use_origin_backend(x1) and is_x1_dparray):
201201
if axis is None:
202202
axis_param = -1 # the most right dimension (default value)
203203
else:
@@ -217,7 +217,7 @@ def ifft(x1, n=None, axis=-1, norm=None):
217217
else:
218218
output_boundarie = input_boundarie
219219

220-
return dpnp_fft(x1, input_boundarie, output_boundarie, axis_param)
220+
return dpnp_fft(x1, input_boundarie, output_boundarie, axis_param, True)
221221

222222
return call_origin(numpy.fft.ifft, x1, n, axis, norm)
223223

@@ -240,11 +240,11 @@ def ifft2(x1, s=None, axes=(-2, -1), norm=None):
240240

241241
is_x1_dparray = isinstance(x1, dparray)
242242

243-
if (not use_origin_backend(x1) and is_x1_dparray and 0):
243+
if (not use_origin_backend(x1) and is_x1_dparray):
244244
if norm is not None:
245245
pass
246246
else:
247-
return fftn(x1, s, axes, norm)
247+
return ifftn(x1, s, axes, norm)
248248

249249
return call_origin(numpy.fft.ifft2, x1, s, axes, norm)
250250

@@ -267,7 +267,7 @@ def ifftn(x1, s=None, axes=None, norm=None):
267267

268268
is_x1_dparray = isinstance(x1, dparray)
269269

270-
if (not use_origin_backend(x1) and is_x1_dparray and 0):
270+
if (not use_origin_backend(x1) and is_x1_dparray):
271271
if s is None:
272272
boundaries = tuple([x1.shape[i] for i in range(x1.ndim)])
273273
else:
@@ -291,7 +291,7 @@ def ifftn(x1, s=None, axes=None, norm=None):
291291
except IndexError:
292292
checker_throw_axis_error("fft.ifftn", "is out of bounds", param_axis, f"< {len(boundaries)}")
293293

294-
x1_iter = fft(x1_iter, n=param_n, axis=param_axis, norm=norm)
294+
x1_iter = ifft(x1_iter, n=param_n, axis=param_axis, norm=norm)
295295

296296
return x1_iter
297297

@@ -332,9 +332,13 @@ def irfft(x1, n=None, axis=-1, norm=None):
332332
elif norm is not None:
333333
pass
334334
else:
335-
output_boundarie = input_boundarie
335+
output_boundarie = 2 * (input_boundarie - 1)
336336

337-
return dpnp_fft(x1, input_boundarie, output_boundarie, axis_param)
337+
result = dpnp_fft(x1, input_boundarie, output_boundarie, axis_param, True)
338+
tmp = dparray(result.shape, dtype=dpnp.float64)
339+
for it in range(tmp.size):
340+
tmp[it] = result[it].real
341+
return tmp
338342

339343
return call_origin(numpy.fft.irfft, x1, n, axis, norm)
340344

@@ -357,11 +361,11 @@ def irfft2(x1, s=None, axes=(-2, -1), norm=None):
357361

358362
is_x1_dparray = isinstance(x1, dparray)
359363

360-
if (not use_origin_backend(x1) and is_x1_dparray and 0):
364+
if (not use_origin_backend(x1) and is_x1_dparray):
361365
if norm is not None:
362366
pass
363367
else:
364-
return fftn(x1, s, axes, norm)
368+
return irfftn(x1, s, axes, norm)
365369

366370
return call_origin(numpy.fft.irfft2, x1, s, axes, norm)
367371

@@ -408,7 +412,7 @@ def irfftn(x1, s=None, axes=None, norm=None):
408412
except IndexError:
409413
checker_throw_axis_error("fft.irfftn", "is out of bounds", param_axis, f"< {len(boundaries)}")
410414

411-
x1_iter = fft(x1_iter, n=param_n, axis=param_axis, norm=norm)
415+
x1_iter = irfft(x1_iter, n=param_n, axis=param_axis, norm=norm)
412416

413417
return x1_iter
414418

@@ -451,7 +455,7 @@ def rfft(x1, n=None, axis=-1, norm=None):
451455
else:
452456
output_boundarie = input_boundarie // 2 + 1 # rfft specific requirenment
453457

454-
return dpnp_fft(x1, input_boundarie, output_boundarie, axis_param)
458+
return dpnp_fft(x1, input_boundarie, output_boundarie, axis_param, False)
455459

456460
return call_origin(numpy.fft.rfft, x1, n, axis, norm)
457461

0 commit comments

Comments
 (0)