@@ -169,12 +169,16 @@ void dpnp_fft_fft_sycl_c(const void* array1_in,
169
169
template <typename _DataType_input, typename _DataType_output, typename _Descriptor_type>
170
170
void dpnp_fft_fft_mathlib_compute_c (const void * array1_in,
171
171
void * result1,
172
+ const shape_elem_type* input_shape,
172
173
const size_t shape_size,
173
174
const size_t result_size,
174
175
_Descriptor_type& desc,
175
176
const size_t norm)
176
177
{
177
- sycl::event event;
178
+ if (!shape_size)
179
+ {
180
+ return ;
181
+ }
178
182
179
183
DPNPC_ptr_adapter<_DataType_input> input1_ptr (array1_in, result_size);
180
184
DPNPC_ptr_adapter<_DataType_output> result_ptr (result1, result_size);
@@ -187,9 +191,19 @@ void dpnp_fft_fft_mathlib_compute_c(const void* array1_in,
187
191
desc.set_value (mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
188
192
desc.commit (DPNP_QUEUE);
189
193
190
- event = mkl_dft::compute_forward (desc, array_1, result);
194
+ const size_t n_iter =
195
+ std::accumulate (input_shape, input_shape + shape_size - 1 , 1 , std::multiplies<shape_elem_type>());
191
196
192
- event.wait ();
197
+ const size_t shift = input_shape[shape_size - 1 ];
198
+
199
+ std::vector<sycl::event> fft_events;
200
+ fft_events.reserve (n_iter);
201
+
202
+ for (size_t i = 0 ; i < n_iter; ++i) {
203
+ fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * shift, result + i * shift));
204
+ }
205
+
206
+ sycl::event::wait (fft_events);
193
207
194
208
return ;
195
209
}
@@ -207,39 +221,24 @@ void dpnp_fft_fft_mathlib_c(const void* array1_in,
207
221
{
208
222
return ;
209
223
}
210
- std::vector<std::int64_t > dimensions (input_shape, input_shape + shape_size);
224
+ // will be used with strides
225
+ // std::vector<std::int64_t> dimensions(input_shape, input_shape + shape_size);
211
226
212
227
if constexpr (std::is_same<_DataType_input, std::complex<double >>::value &&
213
228
std::is_same<_DataType_output, std::complex<double >>::value)
214
229
{
215
- if (shape_size == 1 )
216
- {
217
- desc_dp_cmplx_t desc (result_size);
218
- dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
219
- array1_in, result1, shape_size, result_size, desc, norm);
220
- }
221
- else
222
- {
223
- desc_dp_cmplx_t desc (dimensions);
224
- dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
225
- array1_in, result1, shape_size, result_size, desc, norm);
226
- }
230
+ desc_dp_cmplx_t desc (input_shape[shape_size - 1 ]);
231
+
232
+ dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
233
+ array1_in, result1, input_shape, shape_size, result_size, desc, norm);
227
234
}
228
235
else if (std::is_same<_DataType_input, std::complex<float >>::value &&
229
236
std::is_same<_DataType_output, std::complex<float >>::value)
230
237
{
231
- if (shape_size == 1 )
232
- {
233
- desc_sp_cmplx_t desc (result_size);
234
- dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
235
- array1_in, result1, shape_size, result_size, desc, norm);
236
- }
237
- else
238
- {
239
- desc_sp_cmplx_t desc (dimensions);
240
- dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
241
- array1_in, result1, shape_size, result_size, desc, norm);
242
- }
238
+ desc_sp_cmplx_t desc (input_shape[shape_size - 1 ]);
239
+
240
+ dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
241
+ array1_in, result1, input_shape, shape_size, result_size, desc, norm);
243
242
}
244
243
return ;
245
244
}
@@ -270,11 +269,10 @@ void dpnp_fft_fft_c(const void* array1_in,
270
269
return ;
271
270
}
272
271
273
- if ((( std::is_same<_DataType_input, std::complex<double >>::value &&
272
+ if ((std::is_same<_DataType_input, std::complex<double >>::value &&
274
273
std::is_same<_DataType_output, std::complex<double >>::value) ||
275
274
(std::is_same<_DataType_input, std::complex<float >>::value &&
276
- std::is_same<_DataType_output, std::complex<float >>::value)) &&
277
- (shape_size <= 3 ))
275
+ std::is_same<_DataType_output, std::complex<float >>::value))
278
276
{
279
277
dpnp_fft_fft_mathlib_c<_DataType_input, _DataType_output>(
280
278
array1_in, result1, input_shape, shape_size, result_size, norm);
0 commit comments