@@ -215,7 +215,7 @@ template <typename _DataType, typename _ResultType>
215
215
class dpnp_trace_c_kernel ;
216
216
217
217
template <typename _DataType, typename _ResultType>
218
- void dpnp_trace_c (const void * array1_in, void * result1, const size_t * shape , const size_t ndim)
218
+ void dpnp_trace_c (const void * array1_in, void * result1, const size_t * shape_ , const size_t ndim)
219
219
{
220
220
cl::sycl::event event;
221
221
@@ -227,7 +227,7 @@ void dpnp_trace_c(const void* array1_in, void* result1, const size_t* shape, con
227
227
const _DataType* array_in = reinterpret_cast <const _DataType*>(array1_in);
228
228
_ResultType* result = reinterpret_cast <_ResultType*>(result1);
229
229
230
- if (shape == nullptr )
230
+ if (shape_ == nullptr )
231
231
{
232
232
return ;
233
233
}
@@ -240,32 +240,37 @@ void dpnp_trace_c(const void* array1_in, void* result1, const size_t* shape, con
240
240
size_t size = 1 ;
241
241
for (size_t i = 0 ; i < ndim - 1 ; ++i)
242
242
{
243
- size *= shape [i];
243
+ size *= shape_ [i];
244
244
}
245
245
246
246
if (size == 0 )
247
247
{
248
248
return ;
249
249
}
250
250
251
+ size_t * shape = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (size_t )));
252
+ auto memcpy_event = DPNP_QUEUE.memcpy (shape, shape_, ndim * sizeof (size_t ));
253
+
251
254
cl::sycl::range<1 > gws (size);
252
255
auto kernel_parallel_for_func = [=](cl::sycl::id<1 > global_id) {
253
256
size_t i = global_id[0 ];
254
- _DataType elem = 0 ;
257
+ result[i] = 0 ;
255
258
for (size_t j = 0 ; j < shape[ndim - 1 ]; ++j)
256
259
{
257
- elem += array_in[i * shape[ndim - 1 ] + j];
260
+ result[i] += array_in[i * shape[ndim - 1 ] + j];
258
261
}
259
- result[i] = elem;
260
262
};
261
263
262
264
auto kernel_func = [&](cl::sycl::handler& cgh) {
265
+ cgh.depends_on ({memcpy_event});
263
266
cgh.parallel_for <class dpnp_trace_c_kernel <_DataType, _ResultType>>(gws, kernel_parallel_for_func);
264
267
};
265
268
266
269
event = DPNP_QUEUE.submit (kernel_func);
267
270
268
271
event.wait ();
272
+
273
+ dpnp_memory_free_c (shape);
269
274
}
270
275
271
276
template <typename _DataType>
0 commit comments