Skip to content

Commit d47a7da

Browse files
authored
FFT. rfft(), rfft2() rfftn() kernel (#345)
1 parent 9a8771c commit d47a7da

File tree

4 files changed

+55
-40
lines changed

4 files changed

+55
-40
lines changed

dpnp/backend/backend_iface_fft.hpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,21 @@
6262
*
6363
* Compute the one-dimensional discrete Fourier Transform.
6464
*
65-
* @param[in] array1_in Input array 1.
66-
* @param[out] result1 Output array.
67-
* @param[in] input_shape Array with shape information for input array.
68-
* @param[in] output_shape Array with shape information for output array.
69-
* @param[in] shape_size Number of elements in @ref input_shape or @ref output_shape arrays.
70-
* @param[in] axis Axis ID to compute by.
65+
* @param[in] array_in Input array.
66+
* @param[out] result Output array.
67+
* @param[in] input_shape Array with shape information for input array.
68+
* @param[in] output_shape Array with shape information for output array.
69+
* @param[in] shape_size Number of elements in @ref input_shape or @ref output_shape arrays.
70+
* @param[in] axis Axis ID to compute by.
71+
* @param[in] input_boundarie Limit number of elements for @ref axis.
7172
*/
7273
template <typename _DataType>
73-
INP_DLLEXPORT void dpnp_fft_fft_c(const void* array1_in,
74-
void* result1,
74+
INP_DLLEXPORT void dpnp_fft_fft_c(const void* array_in,
75+
void* result,
7576
const long* input_shape,
7677
const long* output_shape,
7778
size_t shape_size,
78-
long axis);
79+
long axis,
80+
long input_boundarie);
7981

8082
#endif // BACKEND_IFACE_FFT_H

dpnp/backend/dpnp_kernels_fft.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ void dpnp_fft_fft_c(const void* array1_in,
4141
const long* input_shape,
4242
const long* output_shape,
4343
size_t shape_size,
44-
long axis)
44+
long axis,
45+
long input_boundarie)
4546
{
4647
const size_t result_size = std::accumulate(output_shape, output_shape + shape_size, 1, std::multiplies<size_t>());
4748
if (!(result_size && shape_size))
@@ -80,7 +81,7 @@ void dpnp_fft_fft_c(const void* array1_in,
8081
axis_iterator_thread[i] = xyz_thread[i];
8182
}
8283

83-
const long axis_length = output_shape[axis];
84+
const long axis_length = input_boundarie;
8485
for (long it = 0; it < axis_length; ++it)
8586
{
8687
double in_real = 0.0;

dpnp/fft/dpnp_algo_fft.pyx

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,29 +42,30 @@ __all__ = [
4242
"dpnp_fft"
4343
]
4444

45-
ctypedef void(*fptr_dpnp_fft_fft_t)(void * , void * , long * , long * , size_t, long)
45+
ctypedef void(*fptr_dpnp_fft_fft_t)(void *, void * , long * , long * , size_t, long, long)
4646

4747

48-
cpdef dparray dpnp_fft(dparray input, size_t axis_boundarie, long axis):
48+
cpdef dparray dpnp_fft(dparray input, size_t input_boundarie, size_t output_boundarie, long axis):
4949

5050
cdef dparray_shape_type input_shape = input.shape
5151
cdef dparray_shape_type output_shape = input_shape
5252

5353
cdef long axis_norm = normalize_axis((axis,), input_shape.size())[0]
54-
output_shape[axis_norm] = axis_boundarie
54+
output_shape[axis_norm] = output_boundarie
5555

5656
# convert string type names (dparray.dtype) to C enum DPNPFuncType
5757
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
5858

5959
# get the FPTR data structure
6060
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_FFT_FFT, param1_type, param1_type)
6161

62-
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
62+
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
6363
# ceate result array with type given by FPTR data
6464
cdef dparray result = dparray(output_shape, dtype=result_type)
6565

6666
cdef fptr_dpnp_fft_fft_t func = <fptr_dpnp_fft_fft_t > kernel_data.ptr
6767
# call FPTR function
68-
func(input.get_data(), result.get_data(), input_shape.data(), output_shape.data(), input_shape.size(), axis_norm)
68+
func(input.get_data(), result.get_data(), input_shape.data(),
69+
output_shape.data(), input_shape.size(), axis_norm, input_boundarie)
6970

7071
return result

dpnp/fft/dpnp_iface_fft.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -87,18 +87,20 @@ def fft(x1, n=None, axis=-1, norm=None):
8787
axis_param = axis
8888

8989
if n is None:
90-
boundarie = x1.shape[axis_param]
90+
input_boundarie = x1.shape[axis_param]
9191
else:
92-
boundarie = n
92+
input_boundarie = n
9393

9494
if x1.size < 1:
9595
pass # let fallback to handle exception
96-
elif boundarie < 1:
96+
elif input_boundarie < 1:
9797
pass # let fallback to handle exception
9898
elif norm is not None:
9999
pass
100100
else:
101-
return dpnp_fft(x1, boundarie, axis_param)
101+
output_boundarie = input_boundarie
102+
103+
return dpnp_fft(x1, input_boundarie, output_boundarie, axis_param)
102104

103105
return call_origin(numpy.fft.fft, x1, n, axis, norm)
104106

@@ -202,18 +204,20 @@ def ifft(x1, n=None, axis=-1, norm=None):
202204
axis_param = axis
203205

204206
if n is None:
205-
boundarie = x1.shape[axis_param]
207+
input_boundarie = x1.shape[axis_param]
206208
else:
207-
boundarie = n
209+
input_boundarie = n
208210

209211
if x1.size < 1:
210212
pass # let fallback to handle exception
211-
elif boundarie < 1:
213+
elif input_boundarie < 1:
212214
pass # let fallback to handle exception
213215
elif norm is not None:
214216
pass
215217
else:
216-
return dpnp_fft(x1, boundarie, axis_param)
218+
output_boundarie = input_boundarie
219+
220+
return dpnp_fft(x1, input_boundarie, output_boundarie, axis_param)
217221

218222
return call_origin(numpy.fft.ifft, x1, n, axis, norm)
219223

@@ -317,18 +321,20 @@ def irfft(x1, n=None, axis=-1, norm=None):
317321
axis_param = axis
318322

319323
if n is None:
320-
boundarie = x1.shape[axis_param]
324+
input_boundarie = x1.shape[axis_param]
321325
else:
322-
boundarie = n
326+
input_boundarie = n
323327

324328
if x1.size < 1:
325329
pass # let fallback to handle exception
326-
elif boundarie < 1:
330+
elif input_boundarie < 1:
327331
pass # let fallback to handle exception
328332
elif norm is not None:
329333
pass
330334
else:
331-
return dpnp_fft(x1, boundarie, axis_param)
335+
output_boundarie = input_boundarie
336+
337+
return dpnp_fft(x1, input_boundarie, output_boundarie, axis_param)
332338

333339
return call_origin(numpy.fft.irfft, x1, n, axis, norm)
334340

@@ -408,6 +414,7 @@ def irfftn(x1, s=None, axes=None, norm=None):
408414

409415
return call_origin(numpy.fft.irfftn, x1, s, axes, norm)
410416

417+
411418
def rfft(x1, n=None, axis=-1, norm=None):
412419
"""
413420
Compute the one-dimensional discrete Fourier Transform for real input.
@@ -424,25 +431,27 @@ def rfft(x1, n=None, axis=-1, norm=None):
424431

425432
is_x1_dparray = isinstance(x1, dparray)
426433

427-
if (not use_origin_backend(x1) and is_x1_dparray and 0):
434+
if (not use_origin_backend(x1) and is_x1_dparray):
428435
if axis is None:
429-
axis_param = -1 # the most right dimension (default value)
436+
axis_param = -1 # the most right dimension (default value)
430437
else:
431438
axis_param = axis
432439

433440
if n is None:
434-
boundarie = x1.shape[axis_param]
441+
input_boundarie = x1.shape[axis_param]
435442
else:
436-
boundarie = n
443+
input_boundarie = n
437444

438445
if x1.size < 1:
439-
pass # let fallback to handle exception
440-
elif boundarie < 1:
441-
pass # let fallback to handle exception
446+
pass # let fallback to handle exception
447+
elif input_boundarie < 1:
448+
pass # let fallback to handle exception
442449
elif norm is not None:
443450
pass
444451
else:
445-
return dpnp_fft(x1, boundarie, axis_param)
452+
output_boundarie = input_boundarie // 2 + 1 # rfft specific requirenment
453+
454+
return dpnp_fft(x1, input_boundarie, output_boundarie, axis_param)
446455

447456
return call_origin(numpy.fft.rfft, x1, n, axis, norm)
448457

@@ -465,11 +474,11 @@ def rfft2(x1, s=None, axes=(-2, -1), norm=None):
465474

466475
is_x1_dparray = isinstance(x1, dparray)
467476

468-
if (not use_origin_backend(x1) and is_x1_dparray and 0):
477+
if (not use_origin_backend(x1) and is_x1_dparray):
469478
if norm is not None:
470479
pass
471480
else:
472-
return fftn(x1, s, axes, norm)
481+
return rfftn(x1, s, axes, norm)
473482

474483
return call_origin(numpy.fft.rfft2, x1, s, axes, norm)
475484

@@ -492,7 +501,7 @@ def rfftn(x1, s=None, axes=None, norm=None):
492501

493502
is_x1_dparray = isinstance(x1, dparray)
494503

495-
if (not use_origin_backend(x1) and is_x1_dparray and 0):
504+
if (not use_origin_backend(x1) and is_x1_dparray):
496505
if s is None:
497506
boundaries = tuple([x1.shape[i] for i in range(x1.ndim)])
498507
else:
@@ -505,6 +514,8 @@ def rfftn(x1, s=None, axes=None, norm=None):
505514

506515
if norm is not None:
507516
pass
517+
elif len(axes) < 1:
518+
pass # let fallback to handle exception
508519
else:
509520
x1_iter = x1
510521
iteration_list = list(range(len(axes_param)))
@@ -516,7 +527,7 @@ def rfftn(x1, s=None, axes=None, norm=None):
516527
except IndexError:
517528
checker_throw_axis_error("fft.rfftn", "is out of bounds", param_axis, f"< {len(boundaries)}")
518529

519-
x1_iter = fft(x1_iter, n=param_n, axis=param_axis, norm=norm)
530+
x1_iter = rfft(x1_iter, n=param_n, axis=param_axis, norm=norm)
520531

521532
return x1_iter
522533

0 commit comments

Comments
 (0)