Skip to content

Commit 6325553

Browse files
authored
use sycl_adapter in krnl_reduction (#913)
1 parent 1bef06b commit 6325553

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

dpnp/backend/kernels/dpnp_krnl_reduction.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "dpnp_fptr.hpp"
3131
#include "dpnp_iterator.hpp"
3232
#include "dpnp_utils.hpp"
33+
#include "dpnpc_memory_adapter.hpp"
3334
#include "queue_sycl.hpp"
3435

3536
namespace mkl_stats = oneapi::mkl::stats;
@@ -74,7 +75,11 @@ void dpnp_sum_c(void* result_out,
7475

7576
const _DataType_output init = get_initial_value<_DataType_output>(initial, 0);
7677

77-
_DataType_input* input = get_array_ptr<_DataType_input>(input_in);
78+
const size_t input_size =
79+
std::accumulate(input_shape, input_shape + input_shape_ndim, size_t(1), std::multiplies<size_t>());
80+
81+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(input_in, input_size);
82+
_DataType_input* input = input1_ptr.get_ptr();
7883
_DataType_output* result = get_array_ptr<_DataType_output>(result_out);
7984

8085
if (!input_shape && !input_shape_ndim)
@@ -98,8 +103,6 @@ void dpnp_sum_c(void* result_out,
98103
// - float64 and float32 types only
99104
if (axes_ndim < 1)
100105
{
101-
const size_t input_size =
102-
std::accumulate(input_shape, input_shape + input_shape_ndim, size_t(1), std::multiplies<size_t>());
103106
auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>(1, input_size, input);
104107
cl::sycl::event event = mkl_stats::raw_sum(DPNP_QUEUE, dataset, result);
105108
event.wait();
@@ -121,7 +124,7 @@ void dpnp_sum_c(void* result_out,
121124
policy, input_it.begin(output_id), input_it.end(output_id), init, std::plus<_DataType_output>());
122125
policy.queue().wait(); // TODO move out of the loop
123126

124-
dpnp_memory_memcpy_c(&(result[output_id]), &accumulator, sizeof(_DataType_output)); // result[output_id] = accumulator;
127+
dpnp_memory_memcpy_c(result + output_id, &accumulator, sizeof(_DataType_output)); // result[output_id] = accumulator;
125128
}
126129

127130
return;
@@ -149,7 +152,11 @@ void dpnp_prod_c(void* result_out,
149152

150153
const _DataType_output init = get_initial_value<_DataType_output>(initial, 1);
151154

152-
_DataType_input* input = get_array_ptr<_DataType_input>(input_in);
155+
const size_t input_size =
156+
std::accumulate(input_shape, input_shape + input_shape_ndim, size_t(1), std::multiplies<size_t>());
157+
158+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(input_in, input_size);
159+
_DataType_input* input = input1_ptr.get_ptr();
153160
_DataType_output* result = get_array_ptr<_DataType_output>(result_out);
154161

155162
if (!input_shape && !input_shape_ndim)
@@ -177,7 +184,7 @@ void dpnp_prod_c(void* result_out,
177184
policy, input_it.begin(output_id), input_it.end(output_id), init, std::multiplies<_DataType_output>());
178185
policy.queue().wait(); // TODO move out of the loop
179186

180-
dpnp_memory_memcpy_c(&(result[output_id]), &accumulator, sizeof(_DataType_output)); // result[output_id] = accumulator;
187+
dpnp_memory_memcpy_c(result + output_id, &accumulator, sizeof(_DataType_output)); // result[output_id] = accumulator;
181188
}
182189

183190
return;

0 commit comments

Comments
 (0)