Skip to content

Commit ea65198

Browse files
Support for multidim arrays in dpnp.fft.fft (#1112)
* support for multidim arrays in dpnp.fft.fft * fix comments above * add check of shape Co-authored-by: Alexander-Makaryev <alexander.makaryev@gmail.com>
1 parent 3d69185 commit ea65198

File tree

4 files changed

+61
-33
lines changed

4 files changed

+61
-33
lines changed

dpnp/backend/kernels/dpnp_krnl_fft.cpp

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,16 @@ void dpnp_fft_fft_sycl_c(const void* array1_in,
169169
template <typename _DataType_input, typename _DataType_output, typename _Descriptor_type>
170170
void dpnp_fft_fft_mathlib_compute_c(const void* array1_in,
171171
void* result1,
172+
const shape_elem_type* input_shape,
172173
const size_t shape_size,
173174
const size_t result_size,
174175
_Descriptor_type& desc,
175176
const size_t norm)
176177
{
177-
sycl::event event;
178+
if (!shape_size)
179+
{
180+
return;
181+
}
178182

179183
DPNPC_ptr_adapter<_DataType_input> input1_ptr(array1_in, result_size);
180184
DPNPC_ptr_adapter<_DataType_output> result_ptr(result1, result_size);
@@ -187,9 +191,19 @@ void dpnp_fft_fft_mathlib_compute_c(const void* array1_in,
187191
desc.set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
188192
desc.commit(DPNP_QUEUE);
189193

190-
event = mkl_dft::compute_forward(desc, array_1, result);
194+
const size_t n_iter =
195+
std::accumulate(input_shape, input_shape + shape_size - 1, 1, std::multiplies<shape_elem_type>());
191196

192-
event.wait();
197+
const size_t shift = input_shape[shape_size - 1];
198+
199+
std::vector<sycl::event> fft_events;
200+
fft_events.reserve(n_iter);
201+
202+
for (size_t i = 0; i < n_iter; ++i) {
203+
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * shift, result + i * shift));
204+
}
205+
206+
sycl::event::wait(fft_events);
193207

194208
return;
195209
}
@@ -207,39 +221,24 @@ void dpnp_fft_fft_mathlib_c(const void* array1_in,
207221
{
208222
return;
209223
}
210-
std::vector<std::int64_t> dimensions(input_shape, input_shape + shape_size);
224+
//will be used with strides
225+
//std::vector<std::int64_t> dimensions(input_shape, input_shape + shape_size);
211226

212227
if constexpr (std::is_same<_DataType_input, std::complex<double>>::value &&
213228
std::is_same<_DataType_output, std::complex<double>>::value)
214229
{
215-
if (shape_size == 1)
216-
{
217-
desc_dp_cmplx_t desc(result_size);
218-
dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
219-
array1_in, result1, shape_size, result_size, desc, norm);
220-
}
221-
else
222-
{
223-
desc_dp_cmplx_t desc(dimensions);
224-
dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
225-
array1_in, result1, shape_size, result_size, desc, norm);
226-
}
230+
desc_dp_cmplx_t desc(input_shape[shape_size - 1]);
231+
232+
dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
233+
array1_in, result1, input_shape, shape_size, result_size, desc, norm);
227234
}
228235
else if (std::is_same<_DataType_input, std::complex<float>>::value &&
229236
std::is_same<_DataType_output, std::complex<float>>::value)
230237
{
231-
if (shape_size == 1)
232-
{
233-
desc_sp_cmplx_t desc(result_size);
234-
dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
235-
array1_in, result1, shape_size, result_size, desc, norm);
236-
}
237-
else
238-
{
239-
desc_sp_cmplx_t desc(dimensions);
240-
dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
241-
array1_in, result1, shape_size, result_size, desc, norm);
242-
}
238+
desc_sp_cmplx_t desc(input_shape[shape_size - 1]);
239+
240+
dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
241+
array1_in, result1, input_shape, shape_size, result_size, desc, norm);
243242
}
244243
return;
245244
}
@@ -270,11 +269,10 @@ void dpnp_fft_fft_c(const void* array1_in,
270269
return;
271270
}
272271

273-
if (((std::is_same<_DataType_input, std::complex<double>>::value &&
272+
if ((std::is_same<_DataType_input, std::complex<double>>::value &&
274273
std::is_same<_DataType_output, std::complex<double>>::value) ||
275274
(std::is_same<_DataType_input, std::complex<float>>::value &&
276-
std::is_same<_DataType_output, std::complex<float>>::value)) &&
277-
(shape_size <= 3))
275+
std::is_same<_DataType_output, std::complex<float>>::value))
278276
{
279277
dpnp_fft_fft_mathlib_c<_DataType_input, _DataType_output>(
280278
array1_in, result1, input_shape, shape_size, result_size, norm);

dpnp/fft/dpnp_iface_fft.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def fft(x1, n=None, axis=-1, norm=None):
7676
Limitations
7777
-----------
7878
Parameter ``norm`` is unsupported.
79+
Parameter ``axis`` is supported with its default value.
7980
Parameter ``x1`` supports ``dpnp.int32``, ``dpnp.int64``, ``dpnp.float32``, ``dpnp.float64``,
8081
``dpnp.complex64`` and ``dpnp.complex128`` datatypes only.
8182
@@ -105,11 +106,11 @@ def fft(x1, n=None, axis=-1, norm=None):
105106
pass # let fallback to handle exception
106107
elif norm is not None:
107108
pass
109+
elif axis != -1:
110+
pass
108111
else:
109112
output_boundarie = input_boundarie
110-
111113
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, 0).get_pyobj()
112-
113114
return call_origin(numpy.fft.fft, x1, n, axis, norm)
114115

115116

tests/skipped_tests_gpu.tbl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,22 @@ tests/test_fft.py::test_fft[float32]
217217
tests/test_fft.py::test_fft[float64]
218218
tests/test_fft.py::test_fft[int32]
219219
tests/test_fft.py::test_fft[int64]
220+
tests/test_fft.py::test_fft_ndim[shape0-float32]
221+
tests/test_fft.py::test_fft_ndim[shape0-float64]
222+
tests/test_fft.py::test_fft_ndim[shape0-int32]
223+
tests/test_fft.py::test_fft_ndim[shape0-int64]
224+
tests/test_fft.py::test_fft_ndim[shape1-float32]
225+
tests/test_fft.py::test_fft_ndim[shape1-float64]
226+
tests/test_fft.py::test_fft_ndim[shape1-int32]
227+
tests/test_fft.py::test_fft_ndim[shape1-int64]
228+
tests/test_fft.py::test_fft_ndim[shape2-float32]
229+
tests/test_fft.py::test_fft_ndim[shape2-float64]
230+
tests/test_fft.py::test_fft_ndim[shape2-int32]
231+
tests/test_fft.py::test_fft_ndim[shape2-int64]
232+
tests/test_fft.py::test_fft_ndim[shape3-float32]
233+
tests/test_fft.py::test_fft_ndim[shape3-float64]
234+
tests/test_fft.py::test_fft_ndim[shape3-int32]
235+
tests/test_fft.py::test_fft_ndim[shape3-int64]
220236
tests/test_linalg.py::test_cond[-1-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
221237
tests/test_linalg.py::test_cond[1-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
222238
tests/test_linalg.py::test_cond[-2-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]]

tests/test_fft.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,16 @@ def test_fft(type):
1919

2020
numpy.testing.assert_allclose(dpnp_res, np_res, rtol=1e-4, atol=1e-7)
2121
assert dpnp_res.dtype == np_res.dtype
22+
23+
24+
@pytest.mark.parametrize("type", ['complex128', 'complex64', 'float32', 'float64', 'int32', 'int64'])
25+
@pytest.mark.parametrize("shape", [(8, 8), (4, 16), (4, 4, 4), (2, 4, 4, 2)])
26+
def test_fft_ndim(type, shape):
27+
np_data = numpy.arange(64, dtype=numpy.dtype(type)).reshape(shape)
28+
dpnp_data = dpnp.arange(64, dtype=numpy.dtype(type)).reshape(shape)
29+
30+
np_res = numpy.fft.fft(np_data)
31+
dpnp_res = dpnp.fft.fft(dpnp_data)
32+
33+
numpy.testing.assert_allclose(dpnp_res, np_res, rtol=1e-4, atol=1e-7)
34+
assert dpnp_res.dtype == np_res.dtype

0 commit comments

Comments
 (0)