Skip to content

Commit 2cfe4a3

Browse files
authored
use sycl_adapter in krnl_fft (#901)
1 parent e205cae commit 2cfe4a3

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

dpnp/backend/kernels/dpnp_krnl_fft.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <dpnp_iface.hpp>
2929
#include "dpnp_fptr.hpp"
3030
#include "dpnp_utils.hpp"
31+
#include "dpnpc_memory_adapter.hpp"
3132
#include "queue_sycl.hpp"
3233

3334
namespace mkl_dft = oneapi::mkl::dft;
@@ -51,16 +52,18 @@ void dpnp_fft_fft_c(const void* array1_in,
5152
long input_boundarie,
5253
size_t inverse)
5354
{
55+
const size_t input_size = std::accumulate(input_shape, input_shape + shape_size, 1, std::multiplies<size_t>());
5456
const size_t result_size = std::accumulate(output_shape, output_shape + shape_size, 1, std::multiplies<size_t>());
55-
if (!(result_size && shape_size))
57+
if (!(input_size && result_size && shape_size))
5658
{
5759
return;
5860
}
5961

6062
cl::sycl::event event;
6163
const double kernel_pi = inverse ? -M_PI : M_PI;
6264

63-
const _DataType_input* array_1 = reinterpret_cast<const _DataType_input*>(array1_in);
65+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(array1_in, input_size);
66+
const _DataType_input* array_1 = input1_ptr.get_ptr();
6467
_DataType_output* result = reinterpret_cast<_DataType_output*>(result1);
6568

6669
// kernel specific temporal data

0 commit comments

Comments
 (0)