Skip to content

Commit 812f004

Browse files
authored
fix python layer in fft module (#1127)
* fix python layer in fft module
1 parent 90fbe2e commit 812f004

File tree

4 files changed

+52
-200
lines changed

4 files changed

+52
-200
lines changed

dpnp/backend/include/dpnp_iface_fft.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
* @param[in] axis Axis ID to compute by.
6161
* @param[in] input_boundarie Limit number of elements for @ref axis.
6262
* @param[in] inverse Using inverse algorithm.
63-
* @param[in] norm Normalization mode. 0 - backward, 1 - forward.
63+
* @param[in] norm Normalization mode. 0 - backward, 1 - forward, 2 - ortho.
6464
*/
6565
template <typename _DataType>
6666
INP_DLLEXPORT void dpnp_fft_fft_c(const void* array_in,

dpnp/fft/dpnp_iface_fft.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
from dpnp.dpnp_utils import *
4747
from dpnp.fft.dpnp_algo_fft import *
48+
from enum import Enum
4849

4950

5051
__all__ = [
@@ -69,6 +70,21 @@
6970
]
7071

7172

73+
class Norm(Enum):
74+
backward = 0
75+
forward = 1
76+
ortho = 2
77+
78+
def get_validated_norm(norm):
79+
if norm is None or norm == "backward":
80+
return Norm.backward
81+
if norm == "forward":
82+
return Norm.forward
83+
if norm == "ortho":
84+
return Norm.ortho
85+
raise ValueError("Unknown norm value.")
86+
87+
7288
def fft(x1, n=None, axis=-1, norm=None):
7389
"""
7490
Compute the one-dimensional discrete Fourier Transform.
@@ -86,10 +102,8 @@ def fft(x1, n=None, axis=-1, norm=None):
86102

87103
x1_desc = dpnp.get_dpnp_descriptor(x1)
88104
if x1_desc:
89-
# if norm is None or norm is 'backward':
90-
# norm_val = 0
91-
# else:
92-
# norm_val = 1
105+
norm_ = get_validated_norm(norm)
106+
93107
if axis is None:
94108
axis_param = -1 # the most right dimension (default value)
95109
else:
@@ -108,9 +122,11 @@ def fft(x1, n=None, axis=-1, norm=None):
108122
pass
109123
elif axis != -1:
110124
pass
125+
elif x1_desc.dtype not in (numpy.complex128, numpy.complex64):
126+
pass
111127
else:
112128
output_boundarie = input_boundarie
113-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, 0).get_pyobj()
129+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value).get_pyobj()
114130
return call_origin(numpy.fft.fft, x1, n, axis, norm)
115131

116132

@@ -219,6 +235,9 @@ def fftshift(x1, axes=None):
219235

220236
x1_desc = dpnp.get_dpnp_descriptor(x1)
221237
if x1_desc and 0:
238+
239+
norm_= Norm.backward
240+
222241
if axis is None:
223242
axis_param = -1 # the most right dimension (default value)
224243
else:
@@ -227,7 +246,7 @@ def fftshift(x1, axes=None):
227246
if x1_desc.size < 1:
228247
pass # let fallback to handle exception
229248
else:
230-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False).get_pyobj()
249+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value).get_pyobj()
231250

232251
return call_origin(numpy.fft.fftshift, x1, axes)
233252

@@ -248,6 +267,8 @@ def hfft(x1, n=None, axis=-1, norm=None):
248267

249268
x1_desc = dpnp.get_dpnp_descriptor(x1)
250269
if x1_desc and 0:
270+
norm_ = get_validated_norm(norm)
271+
251272
if axis is None:
252273
axis_param = -1 # the most right dimension (default value)
253274
else:
@@ -267,7 +288,7 @@ def hfft(x1, n=None, axis=-1, norm=None):
267288
else:
268289
output_boundarie = input_boundarie
269290

270-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False).get_pyobj()
291+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value).get_pyobj()
271292

272293
return call_origin(numpy.fft.hfft, x1, n, axis, norm)
273294

@@ -287,7 +308,9 @@ def ifft(x1, n=None, axis=-1, norm=None):
287308
"""
288309

289310
x1_desc = dpnp.get_dpnp_descriptor(x1)
290-
if x1_desc:
311+
if x1_desc and 0:
312+
norm_ = get_validated_norm(norm)
313+
291314
if axis is None:
292315
axis_param = -1 # the most right dimension (default value)
293316
else:
@@ -307,7 +330,7 @@ def ifft(x1, n=None, axis=-1, norm=None):
307330
else:
308331
output_boundarie = input_boundarie
309332

310-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True).get_pyobj()
333+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True, norm_.value).get_pyobj()
311334

312335
return call_origin(numpy.fft.ifft, x1, n, axis, norm)
313336

@@ -354,6 +377,9 @@ def ifftshift(x1, axes=None):
354377

355378
x1_desc = dpnp.get_dpnp_descriptor(x1)
356379
if x1_desc and 0:
380+
381+
norm_ = Norm.backward
382+
357383
if axis is None:
358384
axis_param = -1 # the most right dimension (default value)
359385
else:
@@ -362,7 +388,7 @@ def ifftshift(x1, axes=None):
362388
if x1_desc.size < 1:
363389
pass # let fallback to handle exception
364390
else:
365-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False).get_pyobj()
391+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value).get_pyobj()
366392

367393
return call_origin(numpy.fft.ifftshift, x1, axes)
368394

@@ -384,7 +410,7 @@ def ifftn(x1, s=None, axes=None, norm=None):
384410
"""
385411

386412
x1_desc = dpnp.get_dpnp_descriptor(x1)
387-
if x1_desc:
413+
if x1_desc and 0:
388414
if s is None:
389415
boundaries = tuple([x1_desc.shape[i] for i in range(x1_desc.ndim)])
390416
else:
@@ -432,6 +458,8 @@ def ihfft(x1, n=None, axis=-1, norm=None):
432458

433459
x1_desc = dpnp.get_dpnp_descriptor(x1)
434460
if x1_desc and 0:
461+
norm_ = get_validated_norm(norm)
462+
435463
if axis is None:
436464
axis_param = -1 # the most right dimension (default value)
437465
else:
@@ -451,7 +479,7 @@ def ihfft(x1, n=None, axis=-1, norm=None):
451479
else:
452480
output_boundarie = input_boundarie
453481

454-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False).get_pyobj()
482+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value).get_pyobj()
455483

456484
return call_origin(numpy.fft.ihfft, x1, n, axis, norm)
457485

@@ -472,6 +500,8 @@ def irfft(x1, n=None, axis=-1, norm=None):
472500

473501
x1_desc = dpnp.get_dpnp_descriptor(x1)
474502
if x1_desc and 0:
503+
norm_ = get_validated_norm(norm)
504+
475505
if axis is None:
476506
axis_param = -1 # the most right dimension (default value)
477507
else:
@@ -491,7 +521,7 @@ def irfft(x1, n=None, axis=-1, norm=None):
491521
else:
492522
output_boundarie = 2 * (input_boundarie - 1)
493523

494-
result = dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True).get_pyobj()
524+
result = dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True, norm_.value).get_pyobj()
495525
# TODO tmp = utils.create_output_array(result_shape, result_c_type, out)
496526
# tmp = dparray(result.shape, dtype=dpnp.float64)
497527
# for it in range(tmp.size):
@@ -592,6 +622,8 @@ def rfft(x1, n=None, axis=-1, norm=None):
592622

593623
x1_desc = dpnp.get_dpnp_descriptor(x1)
594624
if x1_desc:
625+
norm_ = get_validated_norm(norm)
626+
595627
if axis is None:
596628
axis_param = -1 # the most right dimension (default value)
597629
else:
@@ -608,10 +640,14 @@ def rfft(x1, n=None, axis=-1, norm=None):
608640
pass # let fallback to handle exception
609641
elif norm is not None:
610642
pass
643+
elif x1_desc.ndim > 1:
644+
pass
645+
elif x1_desc.dtype not in (numpy.complex128, numpy.complex64):
646+
pass
611647
else:
612648
output_boundarie = input_boundarie // 2 + 1 # rfft specific requirenment
613649

614-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False).get_pyobj()
650+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, norm_.value).get_pyobj()
615651

616652
return call_origin(numpy.fft.rfft, x1, n, axis, norm)
617653

@@ -674,7 +710,7 @@ def rfftn(x1, s=None, axes=None, norm=None):
674710
"""
675711

676712
x1_desc = dpnp.get_dpnp_descriptor(x1)
677-
if x1_desc:
713+
if x1_desc and 0:
678714
if s is None:
679715
boundaries = tuple([x1_desc.shape[i] for i in range(x1_desc.ndim)])
680716
else:

0 commit comments

Comments
 (0)