Skip to content

Commit 572bc50

Browse files
[SYCL][Reduction] Support range version with multiple reductions (#7456)
1 parent f5f512b commit 572bc50

File tree

3 files changed

+64
-44
lines changed

3 files changed

+64
-44
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,25 +2029,24 @@ class __SYCL_EXPORT handler {
20292029

20302030
/// Reductions @{
20312031

2032-
template <typename KernelName = detail::auto_name, typename KernelType,
2033-
typename PropertiesT, int Dims, typename Reduction>
2032+
template <typename KernelName = detail::auto_name, int Dims,
2033+
typename PropertiesT, typename... RestT>
20342034
std::enable_if_t<
2035-
detail::IsReduction<Reduction>::value &&
2035+
(sizeof...(RestT) > 1) &&
2036+
detail::AreAllButLastReductions<RestT...>::value &&
20362037
ext::oneapi::experimental::is_property_list<PropertiesT>::value>
2037-
parallel_for(range<Dims> Range, PropertiesT Properties, Reduction Redu,
2038-
_KERNELFUNCPARAM(KernelFunc)) {
2039-
detail::reduction_parallel_for<KernelName>(*this, Range, Properties, Redu,
2040-
std::move(KernelFunc));
2038+
parallel_for(range<Dims> Range, PropertiesT Properties, RestT &&...Rest) {
2039+
detail::reduction_parallel_for<KernelName>(*this, Range, Properties,
2040+
std::forward<RestT>(Rest)...);
20412041
}
20422042

2043-
template <typename KernelName = detail::auto_name, typename KernelType,
2044-
int Dims, typename Reduction>
2045-
std::enable_if_t<detail::IsReduction<Reduction>::value>
2046-
parallel_for(range<Dims> Range, Reduction Redu,
2047-
_KERNELFUNCPARAM(KernelFunc)) {
2043+
template <typename KernelName = detail::auto_name, int Dims,
2044+
typename... RestT>
2045+
std::enable_if_t<detail::AreAllButLastReductions<RestT...>::value>
2046+
parallel_for(range<Dims> Range, RestT &&...Rest) {
20482047
parallel_for<KernelName>(
2049-
Range, ext::oneapi::experimental::detail::empty_properties_t{}, Redu,
2050-
std::move(KernelFunc));
2048+
Range, ext::oneapi::experimental::detail::empty_properties_t{},
2049+
std::forward<RestT>(Rest)...);
20512050
}
20522051

20532052
template <typename KernelName = detail::auto_name, int Dims,
@@ -2520,11 +2519,10 @@ class __SYCL_EXPORT handler {
25202519
friend void detail::reduction::withAuxHandler(handler &CGH, FunctorTy Func);
25212520

25222521
template <typename KernelName, detail::reduction::strategy Strategy, int Dims,
2523-
typename PropertiesT, typename KernelType, typename Reduction>
2524-
friend void detail::reduction_parallel_for(handler &CGH, range<Dims> Range,
2522+
typename PropertiesT, typename... RestT>
2523+
friend void detail::reduction_parallel_for(handler &CGH, range<Dims> NDRange,
25252524
PropertiesT Properties,
2526-
Reduction Redu,
2527-
KernelType KernelFunc);
2525+
RestT... Rest);
25282526

25292527
template <typename KernelName, detail::reduction::strategy Strategy, int Dims,
25302528
typename PropertiesT, typename... RestT>

sycl/include/sycl/reduction.hpp

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2302,16 +2302,29 @@ __SYCL_EXPORT uint32_t
23022302
reduGetMaxNumConcurrentWorkGroups(std::shared_ptr<queue_impl> Queue);
23032303

23042304
template <typename KernelName, reduction::strategy Strategy, int Dims,
2305-
typename PropertiesT, typename KernelType, typename Reduction>
2305+
typename PropertiesT, typename... RestT>
23062306
void reduction_parallel_for(handler &CGH, range<Dims> Range,
2307-
PropertiesT Properties, Reduction Redu,
2308-
KernelType KernelFunc) {
2307+
PropertiesT Properties, RestT... Rest) {
2308+
std::tuple<RestT...> ArgsTuple(Rest...);
2309+
constexpr size_t NumArgs = sizeof...(RestT);
2310+
static_assert(NumArgs > 1, "No reduction!");
2311+
auto KernelFunc = std::get<NumArgs - 1>(ArgsTuple);
2312+
auto ReduIndices = std::make_index_sequence<NumArgs - 1>();
2313+
auto ReduTuple = detail::tuple_select_elements(ArgsTuple, ReduIndices);
2314+
23092315
// Before running the kernels, check that the device has enough local memory
23102316
// to hold local arrays required for the tree-reduction algorithm.
2311-
constexpr bool IsTreeReduction =
2312-
!Reduction::has_fast_reduce && !Reduction::has_fast_atomics;
2313-
size_t OneElemSize =
2314-
IsTreeReduction ? sizeof(typename Reduction::result_type) : 0;
2317+
size_t OneElemSize = [&]() {
2318+
if constexpr (NumArgs == 2) {
2319+
using Reduction = std::tuple_element_t<0, decltype(ReduTuple)>;
2320+
constexpr bool IsTreeReduction =
2321+
!Reduction::has_fast_reduce && !Reduction::has_fast_atomics;
2322+
return IsTreeReduction ? sizeof(typename Reduction::result_type) : 0;
2323+
} else {
2324+
return reduGetMemPerWorkItem(ReduTuple, ReduIndices);
2325+
}
2326+
}();
2327+
23152328
uint32_t NumConcurrentWorkGroups =
23162329
#ifdef __SYCL_REDUCTION_NUM_CONCURRENT_WORKGROUPS
23172330
__SYCL_REDUCTION_NUM_CONCURRENT_WORKGROUPS;
@@ -2341,7 +2354,7 @@ void reduction_parallel_for(handler &CGH, range<Dims> Range,
23412354
// stride equal to 1. For each such index the original KernelFunc is called,
23422355
// and the reduction value held in \p Reducer is accumulated across those
23432356
// calls.
2344-
auto UpdatedKernelFunc = [=](auto NDId, auto &Reducer) {
2357+
auto UpdatedKernelFunc = [=](auto NDId, auto &...Reducers) {
23452358
// Divide into contiguous chunks and assign each chunk to a Group
23462359
// Rely on precomputed division to avoid repeating expensive operations
23472360
// TODO: Some devices may prefer alternative remainder handling
@@ -2357,23 +2370,34 @@ void reduction_parallel_for(handler &CGH, range<Dims> Range,
23572370
size_t End = GroupEnd;
23582371
size_t Stride = NDId.get_local_range(0);
23592372
for (size_t I = Start; I < End; I += Stride)
2360-
KernelFunc(getDelinearizedId(Range, I), Reducer);
2373+
KernelFunc(getDelinearizedId(Range, I), Reducers...);
23612374
};
2375+
if constexpr (NumArgs == 2) {
2376+
using Reduction = std::tuple_element_t<0, decltype(ReduTuple)>;
2377+
auto &Redu = std::get<0>(ReduTuple);
23622378

2363-
constexpr auto StrategyToUse = [&]() {
2364-
if constexpr (Strategy != reduction::strategy::auto_select)
2365-
return Strategy;
2379+
constexpr auto StrategyToUse = [&]() {
2380+
if constexpr (Strategy != reduction::strategy::auto_select)
2381+
return Strategy;
23662382

2367-
if constexpr (Reduction::has_fast_reduce)
2368-
return reduction::strategy::group_reduce_and_last_wg_detection;
2369-
else if constexpr (Reduction::has_fast_atomics)
2370-
return reduction::strategy::local_atomic_and_atomic_cross_wg;
2371-
else
2372-
return reduction::strategy::range_basic;
2373-
}();
2383+
if constexpr (Reduction::has_fast_reduce)
2384+
return reduction::strategy::group_reduce_and_last_wg_detection;
2385+
else if constexpr (Reduction::has_fast_atomics)
2386+
return reduction::strategy::local_atomic_and_atomic_cross_wg;
2387+
else
2388+
return reduction::strategy::range_basic;
2389+
}();
23742390

2375-
reduction_parallel_for<KernelName, StrategyToUse>(CGH, NDRange, Properties,
2376-
Redu, UpdatedKernelFunc);
2391+
reduction_parallel_for<KernelName, StrategyToUse>(CGH, NDRange, Properties,
2392+
Redu, UpdatedKernelFunc);
2393+
} else {
2394+
return std::apply(
2395+
[&](auto &...Reds) {
2396+
return reduction_parallel_for<KernelName, Strategy>(
2397+
CGH, NDRange, Properties, Reds..., UpdatedKernelFunc);
2398+
},
2399+
ReduTuple);
2400+
}
23772401
}
23782402
} // namespace detail
23792403

sycl/include/sycl/reduction_forward.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,9 @@ template <class FunctorTy> void withAuxHandler(handler &CGH, FunctorTy Func);
4646

4747
template <typename KernelName,
4848
reduction::strategy Strategy = reduction::strategy::auto_select,
49-
int Dims, typename PropertiesT, typename KernelType,
50-
typename Reduction>
51-
void reduction_parallel_for(handler &CGH, range<Dims> Range,
52-
PropertiesT Properties, Reduction Redu,
53-
KernelType KernelFunc);
49+
int Dims, typename PropertiesT, typename... RestT>
50+
void reduction_parallel_for(handler &CGH, range<Dims> NDRange,
51+
PropertiesT Properties, RestT... Rest);
5452

5553
template <typename KernelName,
5654
reduction::strategy Strategy = reduction::strategy::auto_select,

0 commit comments

Comments
 (0)