26
26
#include < cstddef>
27
27
#include < cstdint>
28
28
#include < limits>
29
- #include < sycl/sycl.hpp>
30
29
#include < utility>
31
30
#include < vector>
32
31
32
+ #include < sycl/sycl.hpp>
33
+
33
34
#include " dpctl_tensor_types.hpp"
34
35
#include " utils/offset_utils.hpp"
35
36
#include " utils/type_dispatch_building.hpp"
@@ -599,6 +600,10 @@ sycl::event masked_place_all_slices_strided_impl(
599
600
sycl::nd_range<2 > ndRange{gRange , lRange};
600
601
601
602
using LocalAccessorT = sycl::local_accessor<indT, 1 >;
603
+ using Impl =
604
+ MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
605
+ Strided1DCyclicIndexer, dataT, indT,
606
+ LocalAccessorT>;
602
607
603
608
dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
604
609
const dataT *rhs_tp = reinterpret_cast <const dataT *>(rhs_p);
@@ -611,13 +616,9 @@ sycl::event masked_place_all_slices_strided_impl(
611
616
LocalAccessorT lacc (lacc_size, cgh);
612
617
613
618
cgh.parallel_for <KernelName>(
614
- ndRange,
615
- MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
616
- Strided1DCyclicIndexer, dataT, indT,
617
- LocalAccessorT>(
618
- dst_tp, cumsum_tp, rhs_tp, iteration_size,
619
- orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
620
- lacc));
619
+ ndRange, Impl (dst_tp, cumsum_tp, rhs_tp, iteration_size,
620
+ orthog_dst_rhs_indexer, masked_dst_indexer,
621
+ masked_rhs_indexer, lacc));
621
622
});
622
623
623
624
return comp_ev;
@@ -696,6 +697,10 @@ sycl::event masked_place_some_slices_strided_impl(
696
697
sycl::nd_range<2 > ndRange{gRange , lRange};
697
698
698
699
using LocalAccessorT = sycl::local_accessor<indT, 1 >;
700
+ using Impl =
701
+ MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
702
+ Strided1DCyclicIndexer, dataT, indT,
703
+ LocalAccessorT>;
699
704
700
705
dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
701
706
const dataT *rhs_tp = reinterpret_cast <const dataT *>(rhs_p);
@@ -708,13 +713,9 @@ sycl::event masked_place_some_slices_strided_impl(
708
713
LocalAccessorT lacc (lacc_size, cgh);
709
714
710
715
cgh.parallel_for <KernelName>(
711
- ndRange,
712
- MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
713
- Strided1DCyclicIndexer, dataT, indT,
714
- LocalAccessorT>(
715
- dst_tp, cumsum_tp, rhs_tp, masked_nelems,
716
- orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
717
- lacc));
716
+ ndRange, Impl (dst_tp, cumsum_tp, rhs_tp, masked_nelems,
717
+ orthog_dst_rhs_indexer, masked_dst_indexer,
718
+ masked_rhs_indexer, lacc));
718
719
});
719
720
720
721
return comp_ev;
0 commit comments