@@ -290,10 +290,6 @@ class DPNPC_id final
290
290
get_shape_offsets_inkernel<size_type>(output_shape, output_shape_size, output_shape_strides);
291
291
292
292
iteration_size = 1 ;
293
-
294
- // make thread private storage for each shape by multiplying memory
295
- sycl_output_xyz =
296
- reinterpret_cast <size_type*>(dpnp_memory_alloc_c (output_size * output_shape_size_in_bytes));
297
293
}
298
294
}
299
295
@@ -400,10 +396,6 @@ class DPNPC_id final
400
396
{
401
397
axes_shape_strides[i] = input_shape_strides[axes[i]];
402
398
}
403
-
404
- // make thread private storage for each shape by multiplying memory
405
- sycl_output_xyz =
406
- reinterpret_cast <size_type*>(dpnp_memory_alloc_c (output_size * output_shape_size_in_bytes));
407
399
}
408
400
}
409
401
@@ -485,35 +477,30 @@ class DPNPC_id final
485
477
{
486
478
assert (output_global_id < output_size);
487
479
488
- // use thread private storage
489
- size_type* sycl_output_xyz_thread = sycl_output_xyz + (output_global_id * output_shape_size);
490
-
491
- get_xyz_by_id_inkernel (output_global_id, output_shape_strides, output_shape_size, sycl_output_xyz_thread);
492
-
493
480
for (size_t iit = 0 , oit = 0 ; iit < input_shape_size; ++iit)
494
481
{
495
482
if (std::find (axes.begin (), axes.end (), iit) == axes.end ())
496
483
{
497
- input_global_id += (sycl_output_xyz_thread[oit] * input_shape_strides[iit]);
484
+ const size_type output_xyz_id = get_xyz_id_by_id_inkernel (output_global_id, output_shape_strides,
485
+ output_shape_size, oit);
486
+ input_global_id += (output_xyz_id * input_shape_strides[iit]);
498
487
++oit;
499
488
}
500
489
}
501
490
}
502
491
else if (broadcast_use)
503
492
{
504
493
assert (output_global_id < output_size);
505
-
506
- // use thread private storage
507
- size_type* sycl_output_xyz_thread = sycl_output_xyz + (output_global_id * output_shape_size);
508
-
509
- get_xyz_by_id_inkernel (output_global_id, output_shape_strides, output_shape_size, sycl_output_xyz_thread);
494
+ assert (input_shape_size <= output_shape_size);
510
495
511
496
for (int irit = input_shape_size - 1 , orit = output_shape_size - 1 ; irit >= 0 ; --irit, --orit)
512
497
{
513
498
size_type* broadcast_axes_end = broadcast_axes + broadcast_axes_size;
514
499
if (std::find (broadcast_axes, broadcast_axes_end, orit) == broadcast_axes_end)
515
500
{
516
- input_global_id += (sycl_output_xyz_thread[orit] * input_shape_strides[irit]);
501
+ const size_type output_xyz_id = get_xyz_id_by_id_inkernel (output_global_id, output_shape_strides,
502
+ output_shape_size, orit);
503
+ input_global_id += (output_xyz_id * input_shape_strides[irit]);
517
504
}
518
505
}
519
506
}
@@ -565,10 +552,8 @@ class DPNPC_id final
565
552
output_shape_size = size_type{};
566
553
dpnp_memory_free_c (output_shape);
567
554
dpnp_memory_free_c (output_shape_strides);
568
- dpnp_memory_free_c (sycl_output_xyz);
569
555
output_shape = nullptr ;
570
556
output_shape_strides = nullptr ;
571
- sycl_output_xyz = nullptr ;
572
557
}
573
558
574
559
void free_memory ()
@@ -602,9 +587,6 @@ class DPNPC_id final
602
587
size_type iteration_shape_size = size_type{};
603
588
size_type* iteration_shape_strides = nullptr ;
604
589
size_type* axes_shape_strides = nullptr ;
605
-
606
- // data allocated to use inside SYCL kernels
607
- size_type* sycl_output_xyz = nullptr ;
608
590
};
609
591
610
592
#endif // DPNP_ITERATOR_H
0 commit comments