Skip to content

Commit a7093ed

Browse files
authored
use sycl_adapter in krnl_linalg (#909)
1 parent 4026ca2 commit a7093ed

File tree

1 file changed

+71
-28
lines changed

1 file changed

+71
-28
lines changed

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 71 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <dpnp_iface.hpp>
3030
#include "dpnp_fptr.hpp"
3131
#include "dpnp_utils.hpp"
32+
#include "dpnpc_memory_adapter.hpp"
3233
#include "queue_sycl.hpp"
3334

3435
namespace mkl_blas = oneapi::mkl::blas::row_major;
@@ -39,8 +40,10 @@ void dpnp_cholesky_c(void* array1_in, void* result1, const size_t size, const si
3940
{
4041
cl::sycl::event event;
4142

42-
_DataType* in_array = reinterpret_cast<_DataType*>(array1_in);
43-
_DataType* result = reinterpret_cast<_DataType*>(result1);
43+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size, true);
44+
DPNPC_ptr_adapter<_DataType> result_ptr(result1, size, true, true);
45+
_DataType* in_array = input1_ptr.get_ptr();
46+
_DataType* result = result_ptr.get_ptr();
4447

4548
size_t iters = size / (data_size * data_size);
4649

@@ -97,8 +100,11 @@ void dpnp_cholesky_c(void* array1_in, void* result1, const size_t size, const si
97100
template <typename _DataType>
98101
void dpnp_det_c(void* array1_in, void* result1, size_t* shape, size_t ndim)
99102
{
100-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
101-
_DataType* result = reinterpret_cast<_DataType*>(result1);
103+
const size_t input_size = std::accumulate(shape, shape + ndim, 1, std::multiplies<size_t>());
104+
if (!input_size)
105+
{
106+
return;
107+
}
102108

103109
size_t n = shape[ndim - 1];
104110
size_t size_out = 1;
@@ -110,6 +116,11 @@ void dpnp_det_c(void* array1_in, void* result1, size_t* shape, size_t ndim)
110116
}
111117
}
112118

119+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, input_size, true);
120+
DPNPC_ptr_adapter<_DataType> result_ptr(result1, size_out, true, true);
121+
_DataType* array_1 = input1_ptr.get_ptr();
122+
_DataType* result = result_ptr.get_ptr();
123+
113124
for (size_t i = 0; i < size_out; i++)
114125
{
115126
_DataType matrix[n][n];
@@ -194,8 +205,17 @@ template <typename _DataType, typename _ResultType>
194205
void dpnp_inv_c(void* array1_in, void* result1, size_t* shape, size_t ndim)
195206
{
196207
(void)ndim; // avoid warning unused variable
197-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
198-
_ResultType* result = reinterpret_cast<_ResultType*>(result1);
208+
209+
const size_t input_size = std::accumulate(shape, shape + ndim, 1, std::multiplies<size_t>());
210+
if (!input_size)
211+
{
212+
return;
213+
}
214+
215+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, input_size, true);
216+
DPNPC_ptr_adapter<_ResultType> result_ptr(result1, input_size, true, true);
217+
_DataType* array_1 = input1_ptr.get_ptr();
218+
_ResultType* result = result_ptr.get_ptr();
199219

200220
size_t n = shape[0];
201221

@@ -298,16 +318,21 @@ void dpnp_kron_c(void* array1_in,
298318
size_t* res_shape,
299319
size_t ndim)
300320
{
301-
_DataType1* array1 = reinterpret_cast<_DataType1*>(array1_in);
302-
_DataType2* array2 = reinterpret_cast<_DataType2*>(array2_in);
303-
_ResultType* result = reinterpret_cast<_ResultType*>(result1);
304-
305-
size_t size = 1;
306-
for (size_t i = 0; i < ndim; ++i)
321+
const size_t input1_size = std::accumulate(in1_shape, in1_shape + ndim, 1, std::multiplies<size_t>());
322+
const size_t input2_size = std::accumulate(in2_shape, in2_shape + ndim, 1, std::multiplies<size_t>());
323+
const size_t result_size = std::accumulate(res_shape, res_shape + ndim, 1, std::multiplies<size_t>());
324+
if (!(result_size && input1_size && input2_size))
307325
{
308-
size *= res_shape[i];
326+
return;
309327
}
310328

329+
DPNPC_ptr_adapter<_DataType1> input1_ptr(array1_in, input1_size);
330+
DPNPC_ptr_adapter<_DataType2> input2_ptr(array2_in, input2_size);
331+
DPNPC_ptr_adapter<_ResultType> result_ptr(result1, result_size);
332+
_DataType1* array1 = input1_ptr.get_ptr();
333+
_DataType2* array2 = input2_ptr.get_ptr();
334+
_ResultType* result = result_ptr.get_ptr();
335+
311336
size_t* _in1_shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(size_t)));
312337
size_t* _in2_shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(size_t)));
313338

@@ -322,7 +347,7 @@ void dpnp_kron_c(void* array1_in,
322347
get_shape_offsets_inkernel<size_t>(in2_shape, ndim, in2_offsets);
323348
get_shape_offsets_inkernel<size_t>(res_shape, ndim, res_offsets);
324349

325-
cl::sycl::range<1> gws(size);
350+
cl::sycl::range<1> gws(result_size);
326351
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
327352
const size_t idx = global_id[0];
328353

@@ -356,12 +381,18 @@ void dpnp_kron_c(void* array1_in,
356381
template <typename _DataType>
357382
void dpnp_matrix_rank_c(void* array1_in, void* result1, size_t* shape, size_t ndim)
358383
{
359-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
360-
_DataType* result = reinterpret_cast<_DataType*>(result1);
384+
const size_t input_size = std::accumulate(shape, shape + ndim, 1, std::multiplies<size_t>());
385+
if (!input_size)
386+
{
387+
return;
388+
}
389+
390+
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, input_size);
391+
DPNPC_ptr_adapter<_DataType> result_ptr(result1, 1);
392+
_DataType* array_1 = input1_ptr.get_ptr();
393+
_DataType* result = result_ptr.get_ptr();
361394

362395
size_t elems = 1;
363-
const _DataType init_val = 0;
364-
dpnp_memory_memcpy_c(result, &init_val, sizeof(_DataType)); // result[0] = 0;
365396
if (ndim > 1)
366397
{
367398
elems = shape[0];
@@ -373,15 +404,18 @@ void dpnp_matrix_rank_c(void* array1_in, void* result1, size_t* shape, size_t nd
373404
}
374405
}
375406
}
407+
408+
_DataType acc = 0;
376409
for (size_t i = 0; i < elems; i++)
377410
{
378411
size_t ind = 0;
379412
for (size_t j = 0; j < ndim; j++)
380413
{
381414
ind += (shape[j] - 1) * i;
382415
}
383-
result[0] += array_1[ind];
416+
acc += array_1[ind];
384417
}
418+
result[0] = acc;
385419

386420
return;
387421
}
@@ -391,7 +425,8 @@ void dpnp_qr_c(void* array1_in, void* result1, void* result2, void* result3, siz
391425
{
392426
cl::sycl::event event;
393427

394-
_InputDT* in_array = reinterpret_cast<_InputDT*>(array1_in);
428+
DPNPC_ptr_adapter<_InputDT> input1_ptr(array1_in, size_m * size_n, true);
429+
_InputDT* in_array = input1_ptr.get_ptr();
395430

396431
// math lib func overrides input
397432
_ComputeDT* in_a = reinterpret_cast<_ComputeDT*>(dpnp_memory_alloc_c(size_m * size_n * sizeof(_ComputeDT)));
@@ -400,13 +435,17 @@ void dpnp_qr_c(void* array1_in, void* result1, void* result2, void* result3, siz
400435
{
401436
for (size_t j = 0; j < size_n; ++j)
402437
{
438+
// TODO transpose? use dpnp_transpose_c()
403439
in_a[j * size_m + i] = in_array[i * size_n + j];
404440
}
405441
}
406442

407-
_ComputeDT* res_q = reinterpret_cast<_ComputeDT*>(result1);
408-
_ComputeDT* res_r = reinterpret_cast<_ComputeDT*>(result2);
409-
_ComputeDT* tau = reinterpret_cast<_ComputeDT*>(result3);
443+
DPNPC_ptr_adapter<_ComputeDT> result1_ptr(result1, size_m * size_m, true, true);
444+
DPNPC_ptr_adapter<_ComputeDT> result2_ptr(result2, size_m * size_n, true, true);
445+
DPNPC_ptr_adapter<_ComputeDT> result3_ptr(result3, std::min(size_m, size_n), true, true);
446+
_ComputeDT* res_q = result1_ptr.get_ptr();
447+
_ComputeDT* res_r = result2_ptr.get_ptr();
448+
_ComputeDT* tau = result3_ptr.get_ptr();
410449

411450
const std::int64_t lda = size_m;
412451

@@ -487,18 +526,22 @@ void dpnp_svd_c(void* array1_in, void* result1, void* result2, void* result3, si
487526
{
488527
cl::sycl::event event;
489528

490-
_InputDT* in_array = reinterpret_cast<_InputDT*>(array1_in);
529+
DPNPC_ptr_adapter<_InputDT> input1_ptr(array1_in, size_m * size_n, true); // TODO no need this if use dpnp_copy_to()
530+
_InputDT* in_array = input1_ptr.get_ptr();
491531

492532
// math lib gesvd func overrides input
493533
_ComputeDT* in_a = reinterpret_cast<_ComputeDT*>(dpnp_memory_alloc_c(size_m * size_n * sizeof(_ComputeDT)));
494534
for (size_t it = 0; it < size_m * size_n; ++it)
495535
{
496-
in_a[it] = in_array[it];
536+
in_a[it] = in_array[it]; // TODO Type conversion. memcpy can not be used directly. dpnp_copy_to() ?
497537
}
498538

499-
_ComputeDT* res_u = reinterpret_cast<_ComputeDT*>(result1);
500-
_SVDT* res_s = reinterpret_cast<_SVDT*>(result2);
501-
_ComputeDT* res_vt = reinterpret_cast<_ComputeDT*>(result3);
539+
DPNPC_ptr_adapter<_ComputeDT> result1_ptr(result1, size_m * size_m, true, true);
540+
DPNPC_ptr_adapter<_SVDT> result2_ptr(result2, std::min(size_m, size_n), true, true);
541+
DPNPC_ptr_adapter<_ComputeDT> result3_ptr(result3, size_n * size_n, true, true);
542+
_ComputeDT* res_u = result1_ptr.get_ptr();
543+
_SVDT* res_s = result2_ptr.get_ptr();
544+
_ComputeDT* res_vt = result3_ptr.get_ptr();
502545

503546
const std::int64_t m = size_m;
504547
const std::int64_t n = size_n;

0 commit comments

Comments
 (0)