Skip to content

Commit 5c412df

Browse files
Save functor type into typename to simplify invocation in parallel_for
1 parent 0399b93 commit 5c412df

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@
2626
#include <cstddef>
2727
#include <cstdint>
2828
#include <limits>
29-
#include <sycl/sycl.hpp>
3029
#include <utility>
3130
#include <vector>
3231

32+
#include <sycl/sycl.hpp>
33+
3334
#include "dpctl_tensor_types.hpp"
3435
#include "utils/offset_utils.hpp"
3536
#include "utils/type_dispatch_building.hpp"
@@ -599,6 +600,10 @@ sycl::event masked_place_all_slices_strided_impl(
599600
sycl::nd_range<2> ndRange{gRange, lRange};
600601

601602
using LocalAccessorT = sycl::local_accessor<indT, 1>;
603+
using Impl =
604+
MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
605+
Strided1DCyclicIndexer, dataT, indT,
606+
LocalAccessorT>;
602607

603608
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
604609
const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
@@ -611,13 +616,9 @@ sycl::event masked_place_all_slices_strided_impl(
611616
LocalAccessorT lacc(lacc_size, cgh);
612617

613618
cgh.parallel_for<KernelName>(
614-
ndRange,
615-
MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
616-
Strided1DCyclicIndexer, dataT, indT,
617-
LocalAccessorT>(
618-
dst_tp, cumsum_tp, rhs_tp, iteration_size,
619-
orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
620-
lacc));
619+
ndRange, Impl(dst_tp, cumsum_tp, rhs_tp, iteration_size,
620+
orthog_dst_rhs_indexer, masked_dst_indexer,
621+
masked_rhs_indexer, lacc));
621622
});
622623

623624
return comp_ev;
@@ -696,6 +697,10 @@ sycl::event masked_place_some_slices_strided_impl(
696697
sycl::nd_range<2> ndRange{gRange, lRange};
697698

698699
using LocalAccessorT = sycl::local_accessor<indT, 1>;
700+
using Impl =
701+
MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
702+
Strided1DCyclicIndexer, dataT, indT,
703+
LocalAccessorT>;
699704

700705
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
701706
const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
@@ -708,13 +713,9 @@ sycl::event masked_place_some_slices_strided_impl(
708713
LocalAccessorT lacc(lacc_size, cgh);
709714

710715
cgh.parallel_for<KernelName>(
711-
ndRange,
712-
MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
713-
Strided1DCyclicIndexer, dataT, indT,
714-
LocalAccessorT>(
715-
dst_tp, cumsum_tp, rhs_tp, masked_nelems,
716-
orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
717-
lacc));
716+
ndRange, Impl(dst_tp, cumsum_tp, rhs_tp, masked_nelems,
717+
orthog_dst_rhs_indexer, masked_dst_indexer,
718+
masked_rhs_indexer, lacc));
718719
});
719720

720721
return comp_ev;

0 commit comments

Comments
 (0)