Skip to content

Commit 2968c21

Browse files
authored
fix dpnp.trace() (#701)
* fix dpnp.trace()
1 parent b574cc5 commit 2968c21

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ template <typename _DataType, typename _ResultType>
215215
class dpnp_trace_c_kernel;
216216

217217
template <typename _DataType, typename _ResultType>
218-
void dpnp_trace_c(const void* array1_in, void* result1, const size_t* shape, const size_t ndim)
218+
void dpnp_trace_c(const void* array1_in, void* result1, const size_t* shape_, const size_t ndim)
219219
{
220220
cl::sycl::event event;
221221

@@ -227,7 +227,7 @@ void dpnp_trace_c(const void* array1_in, void* result1, const size_t* shape, con
227227
const _DataType* array_in = reinterpret_cast<const _DataType*>(array1_in);
228228
_ResultType* result = reinterpret_cast<_ResultType*>(result1);
229229

230-
if (shape == nullptr)
230+
if (shape_ == nullptr)
231231
{
232232
return;
233233
}
@@ -240,32 +240,37 @@ void dpnp_trace_c(const void* array1_in, void* result1, const size_t* shape, con
240240
size_t size = 1;
241241
for (size_t i = 0; i < ndim - 1; ++i)
242242
{
243-
size *= shape[i];
243+
size *= shape_[i];
244244
}
245245

246246
if (size == 0)
247247
{
248248
return;
249249
}
250250

251+
size_t* shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(size_t)));
252+
auto memcpy_event = DPNP_QUEUE.memcpy(shape, shape_, ndim * sizeof(size_t));
253+
251254
cl::sycl::range<1> gws(size);
252255
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
253256
size_t i = global_id[0];
254-
_DataType elem = 0;
257+
result[i] = 0;
255258
for (size_t j = 0; j < shape[ndim - 1]; ++j)
256259
{
257-
elem += array_in[i * shape[ndim - 1] + j];
260+
result[i] += array_in[i * shape[ndim - 1] + j];
258261
}
259-
result[i] = elem;
260262
};
261263

262264
auto kernel_func = [&](cl::sycl::handler& cgh) {
265+
cgh.depends_on({memcpy_event});
263266
cgh.parallel_for<class dpnp_trace_c_kernel<_DataType, _ResultType>>(gws, kernel_parallel_for_func);
264267
};
265268

266269
event = DPNP_QUEUE.submit(kernel_func);
267270

268271
event.wait();
272+
273+
dpnp_memory_free_c(shape);
269274
}
270275

271276
template <typename _DataType>

0 commit comments

Comments
 (0)