Skip to content

Commit e205cae

Browse files
authored
use sycl_adapter in krnl_common (#900)
* use sycl_adapter in krnl_common
1 parent ce0851b commit e205cae

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <dpnp_iface.hpp>
3131
#include "dpnp_fptr.hpp"
3232
#include "dpnp_utils.hpp"
33+
#include "dpnpc_memory_adapter.hpp"
3334
#include "queue_sycl.hpp"
3435

3536
namespace mkl_blas = oneapi::mkl::blas;
@@ -42,8 +43,8 @@ template <typename _DataType, typename _ResultType>
4243
void dpnp_astype_c(const void* array1_in, void* result1, const size_t size)
4344
{
4445
cl::sycl::event event;
45-
46-
const _DataType* array_in = reinterpret_cast<const _DataType*>(array1_in);
46+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size);
47+
const _DataType* array_in = input1_ptr.get_ptr();
4748
_ResultType* result = reinterpret_cast<_ResultType*>(result1);
4849

4950
if ((array_in == nullptr) || (result == nullptr))
@@ -88,14 +89,16 @@ void dpnp_dot_c(void* result_out,
8889
{
8990
(void)input1_shape;
9091
(void)input1_shape_ndim;
91-
(void)input2_size;
9292
(void)input2_shape;
9393
(void)input2_shape_ndim;
9494
(void)where;
9595

9696
cl::sycl::event event;
97-
_DataType_input1* input1 = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in));
98-
_DataType_input2* input2 = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in));
97+
DPNPC_ptr_adapter<_DataType_input1> input1_ptr(input1_in, input1_size);
98+
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(input2_in, input2_size);
99+
100+
_DataType_input1* input1 = input1_ptr.get_ptr();
101+
_DataType_input2* input2 = input2_ptr.get_ptr();
99102
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out);
100103

101104
if (!input1_size)
@@ -146,7 +149,7 @@ void dpnp_dot_c(void* result_out,
146149
std::reduce(policy, local_mem, local_mem + input1_size, _DataType_output(0), std::plus<_DataType_output>());
147150
policy.queue().wait();
148151

149-
result[0] = accumulator;
152+
result[0] = accumulator; // TODO use memcpy_c
150153

151154
free(local_mem, DPNP_QUEUE);
152155
}
@@ -166,8 +169,8 @@ void dpnp_eig_c(const void* array_in, void* result1, void* result2, size_t size)
166169
}
167170

168171
cl::sycl::event event;
169-
170-
const _DataType* array = reinterpret_cast<const _DataType*>(array_in);
172+
DPNPC_ptr_adapter<_DataType> input1_ptr(array_in, size);
173+
const _DataType* array = input1_ptr.get_ptr();
171174
_ResultType* result_val = reinterpret_cast<_ResultType*>(result1);
172175
_ResultType* result_vec = reinterpret_cast<_ResultType*>(result2);
173176

@@ -177,7 +180,7 @@ void dpnp_eig_c(const void* array_in, void* result1, void* result2, size_t size)
177180
// type conversion. Also, math library requires copy memory because override
178181
for (size_t it = 0; it < (size * size); ++it)
179182
{
180-
result_vec_kern[it] = array[it];
183+
result_vec_kern[it] = array[it]; // TODO use memcpy_c or input1_ptr(array_in, size, true)
181184
}
182185

183186
const std::int64_t lda = std::max<size_t>(1UL, size);
@@ -202,7 +205,7 @@ void dpnp_eig_c(const void* array_in, void* result1, void* result2, size_t size)
202205

203206
for (size_t it1 = 0; it1 < size; ++it1)
204207
{
205-
result_val[it1] = result_val_kern[it1];
208+
result_val[it1] = result_val_kern[it1]; // TODO use memcpy_c or dpnpc_transpose_c
206209
for (size_t it2 = 0; it2 < size; ++it2)
207210
{
208211
// copy + transpose
@@ -228,8 +231,8 @@ void dpnp_eigvals_c(const void* array_in, void* result1, size_t size)
228231
}
229232

230233
cl::sycl::event event;
231-
232-
const _DataType* array = reinterpret_cast<const _DataType*>(array_in);
234+
DPNPC_ptr_adapter<_DataType> input1_ptr(array_in, size);
235+
const _DataType* array = input1_ptr.get_ptr();
233236
_ResultType* result_val = reinterpret_cast<_ResultType*>(result1);
234237

235238
double* result_val_kern = reinterpret_cast<double*>(dpnp_memory_alloc_c(size * sizeof(double)));
@@ -238,7 +241,7 @@ void dpnp_eigvals_c(const void* array_in, void* result1, size_t size)
238241
// type conversion. Also, math library requires copy memory because override
239242
for (size_t it = 0; it < (size * size); ++it)
240243
{
241-
result_vec_kern[it] = array[it];
244+
result_vec_kern[it] = array[it]; // TODO same as previous kernel
242245
}
243246

244247
const std::int64_t lda = std::max<size_t>(1UL, size);
@@ -304,16 +307,18 @@ class dpnp_matmul_c_kernel;
304307
template <typename _DataType>
305308
void dpnp_matmul_c(void* array1_in, void* array2_in, void* result1, size_t size_m, size_t size_n, size_t size_k)
306309
{
307-
cl::sycl::event event;
308-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
309-
_DataType* array_2 = reinterpret_cast<_DataType*>(array2_in);
310-
_DataType* result = reinterpret_cast<_DataType*>(result1);
311-
312310
if (!size_m || !size_n || !size_k)
313311
{
314312
return;
315313
}
316314

315+
cl::sycl::event event;
316+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size_m * size_k);
317+
DPNPC_ptr_adapter<_DataType> input2_ptr(array2_in, size_k * size_n);
318+
_DataType* array_1 = input1_ptr.get_ptr();
319+
_DataType* array_2 = input2_ptr.get_ptr();
320+
_DataType* result = reinterpret_cast<_DataType*>(result1);
321+
317322
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
318323
{
319324
// using std::max for these ldx variables is required by math library

0 commit comments

Comments
 (0)