Skip to content

Commit 5e876eb

Browse files
authored
simplify trace() kernel (#808)
1 parent 0d592e7 commit 5e876eb

File tree

1 file changed

+13
-20
lines changed

1 file changed

+13
-20
lines changed

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

Lines changed: 13 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -215,49 +215,42 @@ template <typename _DataType, typename _ResultType>
215215
class dpnp_trace_c_kernel;
216216

217217
template <typename _DataType, typename _ResultType>
218-
void dpnp_trace_c(const void* array1_in, void* result1, const size_t* shape_, const size_t ndim)
218+
void dpnp_trace_c(const void* array1_in, void* result_in, const size_t* shape_, const size_t ndim)
219219
{
220-
if ((array1_in == nullptr) || (result1 == nullptr) || (shape_ == nullptr) || (ndim == 0))
220+
if (!array1_in || !result_in || !shape_ || !ndim)
221221
{
222222
return;
223223
}
224224

225-
const _DataType* array_in = reinterpret_cast<const _DataType*>(array1_in);
226-
_ResultType* result = reinterpret_cast<_ResultType*>(result1);
225+
const _DataType* input = reinterpret_cast<const _DataType*>(array1_in);
226+
_ResultType* result = reinterpret_cast<_ResultType*>(result_in);
227+
const size_t last_dim = shape_[ndim - 1];
227228

228-
size_t size = 1;
229-
for (size_t i = 0; i < ndim - 1; ++i)
230-
{
231-
size *= shape_[i];
232-
}
233-
234-
if (size == 0)
229+
const size_t size = std::accumulate(shape_, shape_ + (ndim - 1), 1, std::multiplies<size_t>());
230+
if (!size)
235231
{
236232
return;
237233
}
238234

239-
size_t* shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(size_t)));
240-
auto memcpy_event = DPNP_QUEUE.memcpy(shape, shape_, ndim * sizeof(size_t));
241-
242235
cl::sycl::range<1> gws(size);
243236
auto kernel_parallel_for_func = [=](auto index) {
244237
size_t i = index[0];
245-
result[i] = 0;
246-
for (size_t j = 0; j < shape[ndim - 1]; ++j)
238+
_ResultType acc = _ResultType(0);
239+
240+
for (size_t j = 0; j < last_dim; ++j)
247241
{
248-
result[i] += array_in[i * shape[ndim - 1] + j];
242+
acc += input[i * last_dim + j];
249243
}
244+
245+
result[i] = acc;
250246
};
251247

252248
auto kernel_func = [&](cl::sycl::handler& cgh) {
253-
cgh.depends_on({memcpy_event});
254249
cgh.parallel_for<class dpnp_trace_c_kernel<_DataType, _ResultType>>(gws, kernel_parallel_for_func);
255250
};
256251

257252
auto event = DPNP_QUEUE.submit(kernel_func);
258253
event.wait();
259-
260-
dpnp_memory_free_c(shape);
261254
}
262255

263256
template <typename _DataType>

0 commit comments

Comments
 (0)