Skip to content

Commit e567e33

Browse files
authored
use sycl_adapter in krnl_elemwise (#896)
1 parent 2d76330 commit e567e33

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
#include <iostream>
2828

2929
#include <dpnp_iface.hpp>
30+
3031
#include "dpnp_fptr.hpp"
3132
#include "dpnp_iterator.hpp"
3233
#include "dpnp_utils.hpp"
34+
#include "dpnpc_memory_adapter.hpp"
3335
#include "queue_sycl.hpp"
3436

3537
#define MACRO_1ARG_2TYPES_OP(__name__, __operation1__, __operation2__) \
@@ -40,8 +42,9 @@
4042
void __name__(void* array1_in, void* result1, size_t size) \
4143
{ \
4244
cl::sycl::event event; \
45+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(array1_in, size); \
4346
\
44-
_DataType_input* array1 = reinterpret_cast<_DataType_input*>(array1_in); \
47+
_DataType_input* array1 = input1_ptr.get_ptr(); \
4548
_DataType_output* result = reinterpret_cast<_DataType_output*>(result1); \
4649
\
4750
cl::sycl::range<1> gws(size); \
@@ -259,7 +262,8 @@ static void func_map_init_elemwise_1arg_2type(func_map_t& fmap)
259262
return; \
260263
} \
261264
\
262-
_DataType* array1 = reinterpret_cast<_DataType*>(array1_in); \
265+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size); \
266+
_DataType* array1 = input1_ptr.get_ptr(); \
263267
_DataType* result = reinterpret_cast<_DataType*>(result1); \
264268
\
265269
cl::sycl::range<1> gws(size); \
@@ -366,8 +370,10 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
366370
return; \
367371
} \
368372
\
369-
_DataType_input1* input1_data = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in)); \
370-
_DataType_input2* input2_data = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in)); \
373+
DPNPC_ptr_adapter<_DataType_input1> input1_ptr(input1_in, input1_size); \
374+
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(input2_in, input2_size); \
375+
_DataType_input1* input1_data = input1_ptr.get_ptr(); \
376+
_DataType_input2* input2_data = input2_ptr.get_ptr(); \
371377
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out); \
372378
\
373379
std::vector<size_t> result_shape = \

0 commit comments

Comments
 (0)