Skip to content

Commit 07881dd

Browse files
authored
use sycl_adapter in krnl_manipulation (#911)
1 parent a7093ed commit 07881dd

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

dpnp/backend/kernels/dpnp_krnl_manipulation.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#include "dpnp_fptr.hpp"
3333
#include "dpnp_utils.hpp"
34+
#include "dpnpc_memory_adapter.hpp"
3435
#include "queue_sycl.hpp"
3536

3637
template <typename _DataType_dst, typename _DataType_src>
@@ -45,12 +46,7 @@ class dpnp_repeat_c_kernel;
4546
template <typename _DataType>
4647
void dpnp_repeat_c(const void* array1_in, void* result1, const size_t repeats, const size_t size)
4748
{
48-
cl::sycl::event event;
49-
50-
const _DataType* array_in = reinterpret_cast<const _DataType*>(array1_in);
51-
_DataType* result = reinterpret_cast<_DataType*>(result1);
52-
53-
if (!array_in || !result)
49+
if (!array1_in || !result1)
5450
{
5551
return;
5652
}
@@ -60,6 +56,11 @@ void dpnp_repeat_c(const void* array1_in, void* result1, const size_t repeats, c
6056
return;
6157
}
6258

59+
cl::sycl::event event;
60+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size);
61+
const _DataType* array_in = input1_ptr.get_ptr();
62+
_DataType* result = reinterpret_cast<_DataType*>(result1);
63+
6364
cl::sycl::range<2> gws(size, repeats);
6465
auto kernel_parallel_for_func = [=](cl::sycl::id<2> global_id) {
6566
size_t idx1 = global_id[0];
@@ -94,7 +95,8 @@ void dpnp_elemwise_transpose_c(void* array1_in,
9495
}
9596

9697
cl::sycl::event event;
97-
_DataType* array1 = reinterpret_cast<_DataType*>(array1_in);
98+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size);
99+
_DataType* array1 = input1_ptr.get_ptr();
98100
_DataType* result = reinterpret_cast<_DataType*>(result1);
99101

100102
size_t* input_offset_shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(long)));

0 commit comments

Comments
 (0)