Skip to content

Commit e9e19f0

Browse files
authored
use sycl_adapter in krnl_arraycreation (#898)
1 parent e567e33 commit e9e19f0

File tree

1 file changed, with 35 additions and 20 deletions.

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

Lines changed: 35 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727

2828
#include "dpnp_fptr.hpp"
2929
#include "dpnp_iface.hpp"
30+
#include "dpnpc_memory_adapter.hpp"
3031
#include "queue_sycl.hpp"
3132

3233
template <typename _KernelNameSpecialization>
@@ -70,8 +71,12 @@ void dpnp_diag_c(
7071
// avoid warning unused variable
7172
(void)res_ndim;
7273

73-
_DataType* v = reinterpret_cast<_DataType*>(v_in);
74-
_DataType* result = reinterpret_cast<_DataType*>(result1);
74+
const size_t input1_size = std::accumulate(shape, shape + ndim, 1, std::multiplies<size_t>());
75+
const size_t result_size = std::accumulate(res_shape, res_shape + res_ndim, 1, std::multiplies<size_t>());
76+
DPNPC_ptr_adapter<_DataType> input1_ptr(v_in, input1_size, true);
77+
DPNPC_ptr_adapter<_DataType> result_ptr(result1, result_size, true, true);
78+
_DataType* v = input1_ptr.get_ptr();
79+
_DataType* result = result_ptr.get_ptr();
7580

7681
size_t init0 = std::max(0, -k);
7782
size_t init1 = std::max(0, k);
@@ -167,8 +172,10 @@ void dpnp_vander_c(const void* array1_in, void* result1, const size_t size_in, c
167172
if (!size_in || !N)
168173
return;
169174

170-
const _DataType_input* array_in = reinterpret_cast<const _DataType_input*>(array1_in);
171-
_DataType_output* result = reinterpret_cast<_DataType_output*>(result1);
175+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(array1_in, size_in, true);
176+
DPNPC_ptr_adapter<_DataType_output> result_ptr(result1, size_in * N, true, true);
177+
const _DataType_input* array_in = input1_ptr.get_ptr();
178+
_DataType_output* result = result_ptr.get_ptr();
172179

173180
if (N == 1)
174181
{
@@ -222,16 +229,17 @@ void dpnp_trace_c(const void* array1_in, void* result_in, const size_t* shape_,
222229
return;
223230
}
224231

225-
const _DataType* input = reinterpret_cast<const _DataType*>(array1_in);
226-
_ResultType* result = reinterpret_cast<_ResultType*>(result_in);
227232
const size_t last_dim = shape_[ndim - 1];
228-
229233
const size_t size = std::accumulate(shape_, shape_ + (ndim - 1), 1, std::multiplies<size_t>());
230234
if (!size)
231235
{
232236
return;
233237
}
234238

239+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size * last_dim);
240+
const _DataType* input = input1_ptr.get_ptr();
241+
_ResultType* result = reinterpret_cast<_ResultType*>(result_in);
242+
235243
cl::sycl::range<1> gws(size);
236244
auto kernel_parallel_for_func = [=](auto index) {
237245
size_t i = index[0];
@@ -312,9 +320,6 @@ void dpnp_tril_c(void* array_in,
312320
return;
313321
}
314322

315-
_DataType* array_m = reinterpret_cast<_DataType*>(array_in);
316-
_DataType* result = reinterpret_cast<_DataType*>(result1);
317-
318323
if ((shape == nullptr) || (res_shape == nullptr))
319324
{
320325
return;
@@ -325,17 +330,23 @@ void dpnp_tril_c(void* array_in,
325330
return;
326331
}
327332

328-
size_t res_size = 1;
329-
for (size_t i = 0; i < res_ndim; ++i)
333+
const size_t res_size = std::accumulate(res_shape, res_shape + res_ndim, 1, std::multiplies<size_t>());
334+
if (res_size == 0)
330335
{
331-
res_size *= res_shape[i];
336+
return;
332337
}
333338

334-
if (res_size == 0)
339+
const size_t input_size = std::accumulate(shape, shape + ndim, 1, std::multiplies<size_t>());
340+
if (input_size == 0)
335341
{
336342
return;
337343
}
338344

345+
DPNPC_ptr_adapter<_DataType> input1_ptr(array_in, input_size, true);
346+
DPNPC_ptr_adapter<_DataType> result_ptr(result1, res_size, true, true);
347+
_DataType* array_m = input1_ptr.get_ptr();
348+
_DataType* result = result_ptr.get_ptr();
349+
339350
if (ndim == 1)
340351
{
341352
for (size_t i = 0; i < res_size; ++i)
@@ -416,8 +427,6 @@ void dpnp_triu_c(void* array_in,
416427
{
417428
return;
418429
}
419-
_DataType* array_m = reinterpret_cast<_DataType*>(array_in);
420-
_DataType* result = reinterpret_cast<_DataType*>(result1);
421430

422431
if ((shape == nullptr) || (res_shape == nullptr))
423432
{
@@ -429,17 +438,23 @@ void dpnp_triu_c(void* array_in,
429438
return;
430439
}
431440

432-
size_t res_size = 1;
433-
for (size_t i = 0; i < res_ndim; ++i)
441+
const size_t res_size = std::accumulate(res_shape, res_shape + res_ndim, 1, std::multiplies<size_t>());
442+
if (res_size == 0)
434443
{
435-
res_size *= res_shape[i];
444+
return;
436445
}
437446

438-
if (res_size == 0)
447+
const size_t input_size = std::accumulate(shape, shape + ndim, 1, std::multiplies<size_t>());
448+
if (input_size == 0)
439449
{
440450
return;
441451
}
442452

453+
DPNPC_ptr_adapter<_DataType> input1_ptr(array_in, input_size, true);
454+
DPNPC_ptr_adapter<_DataType> result_ptr(result1, res_size, true, true);
455+
_DataType* array_m = input1_ptr.get_ptr();
456+
_DataType* result = result_ptr.get_ptr();
457+
443458
if (ndim == 1)
444459
{
445460
for (size_t i = 0; i < res_size; ++i)

0 commit comments

Comments
 (0)