Skip to content

Commit 5d5e9f4

Browse files
[SYCL] Support kernels accepting item in range reduction parallel_for (#7478)
Previously, only kernels that took a sycl::id argument worked; kernels taking a sycl::item are now supported as well.
1 parent 39b6672 commit 5d5e9f4

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

sycl/include/sycl/item.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@ template <typename TransformedArgType, int Dims, typename KernelType>
2525
class RoundedRangeKernel;
2626
template <typename TransformedArgType, int Dims, typename KernelType>
2727
class RoundedRangeKernelWithKH;
28+
29+
namespace reduction {
30+
// Forward declaration only: class item (below) names this function as a
// friend, so it has to be declared before the class. The definition is in
// reduction_forward.hpp.
template <int Dims>
31+
item<Dims, false> getDelinearizedItem(range<Dims> Range, id<Dims> Id);
32+
} // namespace reduction
2833
} // namespace detail
29-
template <int dimensions> class id;
30-
template <int dimensions> class range;
3134

3235
/// Identifies an instance of the function object executing at each point
3336
/// in a range.
@@ -130,6 +133,10 @@ template <int dimensions = 1, bool with_offset = true> class item {
130133
friend class detail::RoundedRangeKernelWithKH;
131134
void set_allowed_range(const range<dimensions> rnwi) { MImpl.MExtent = rnwi; }
132135

136+
template <int Dims>
137+
friend item<Dims, false>
138+
detail::reduction::getDelinearizedItem(range<Dims> Range, id<Dims> Id);
139+
133140
detail::ItemBase<dimensions, with_offset> MImpl;
134141
};
135142

sycl/include/sycl/reduction.hpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2368,8 +2368,18 @@ void reduction_parallel_for(handler &CGH, range<Dims> Range,
23682368
size_t Start = GroupStart + NDId.get_local_id(0);
23692369
size_t End = GroupEnd;
23702370
size_t Stride = NDId.get_local_range(0);
2371+
// Turns the linear index I into the first argument passed to KernelFunc.
// The index is delinearized against Range into an id<Dims>; if KernelFunc is
// invocable with (id<Dims>, Reducers...) that id is returned as-is, otherwise
// it is wrapped into an item<Dims, false> built from Range and the id.
// The branch is chosen at compile time, so the lambda's return type depends
// on KernelFunc's signature.
auto GetDelinearized = [&](size_t I) {
2372+
auto Id = getDelinearizedId(Range, I);
2373+
if constexpr (std::is_invocable_v<decltype(KernelFunc), id<Dims>,
2374+
decltype(Reducers)...>)
2375+
return Id;
2376+
else
2377+
// SYCL doesn't provide parallel_for accepting offset in presence of
2378+
// reductions, so use with_offset==false.
2379+
return reduction::getDelinearizedItem(Range, Id);
2380+
};
23712381
for (size_t I = Start; I < End; I += Stride)
2372-
KernelFunc(getDelinearizedId(Range, I), Reducers...);
2382+
KernelFunc(GetDelinearized(I), Reducers...);
23732383
};
23742384
if constexpr (NumArgs == 2) {
23752385
using Reduction = std::tuple_element_t<0, decltype(ReduTuple)>;

sycl/include/sycl/reduction_forward.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ enum class strategy : int {
4242
// are limited to those below.
4343
inline void finalizeHandler(handler &CGH);
4444
template <class FunctorTy> void withAuxHandler(handler &CGH, FunctorTy Func);
45+
46+
/// Builds an offset-less item (item<Dims, false>) from a range and an id.
///
/// \param Range the range passed as the first initializer of the item.
/// \param Id the id passed as the second initializer of the item.
/// \return an item<Dims, false> list-initialized from {Range, Id}.
///
/// item declares this function as a friend (see item.hpp), presumably because
/// the constructor used here is not publicly accessible - TODO confirm.
template <int Dims>
47+
item<Dims, false> getDelinearizedItem(range<Dims> Range, id<Dims> Id) {
48+
return {Range, Id};
49+
}
4550
} // namespace reduction
4651

4752
template <typename KernelName,

0 commit comments

Comments
 (0)