@@ -2302,16 +2302,29 @@ __SYCL_EXPORT uint32_t
 reduGetMaxNumConcurrentWorkGroups(std::shared_ptr<queue_impl> Queue);

 template <typename KernelName, reduction::strategy Strategy, int Dims,
-          typename PropertiesT, typename KernelType, typename Reduction>
+          typename PropertiesT, typename... RestT>
 void reduction_parallel_for(handler &CGH, range<Dims> Range,
-                            PropertiesT Properties, Reduction Redu,
-                            KernelType KernelFunc) {
+                            PropertiesT Properties, RestT... Rest) {
+  std::tuple<RestT...> ArgsTuple(Rest...);
+  constexpr size_t NumArgs = sizeof...(RestT);
+  static_assert(NumArgs > 1, "No reduction!");
+  auto KernelFunc = std::get<NumArgs - 1>(ArgsTuple);
+  auto ReduIndices = std::make_index_sequence<NumArgs - 1>();
+  auto ReduTuple = detail::tuple_select_elements(ArgsTuple, ReduIndices);
+
   // Before running the kernels, check that device has enough local memory
   // to hold local arrays required for the tree-reduction algorithm.
-  constexpr bool IsTreeReduction =
-      !Reduction::has_fast_reduce && !Reduction::has_fast_atomics;
-  size_t OneElemSize =
-      IsTreeReduction ? sizeof(typename Reduction::result_type) : 0;
+  size_t OneElemSize = [&]() {
+    if constexpr (NumArgs == 2) {
+      using Reduction = std::tuple_element_t<0, decltype(ReduTuple)>;
+      constexpr bool IsTreeReduction =
+          !Reduction::has_fast_reduce && !Reduction::has_fast_atomics;
+      return IsTreeReduction ? sizeof(typename Reduction::result_type) : 0;
+    } else {
+      return reduGetMemPerWorkItem(ReduTuple, ReduIndices);
+    }
+  }();
+
   uint32_t NumConcurrentWorkGroups =
 #ifdef __SYCL_REDUCTION_NUM_CONCURRENT_WORKGROUPS
       __SYCL_REDUCTION_NUM_CONCURRENT_WORKGROUPS;
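This hunk replaces the fixed (Redu, KernelFunc) argument pair with a variadic pack and then splits the pack: the trailing element is taken as the kernel functor, and everything before it is treated as a reduction. The standalone sketch below (C++17) illustrates that splitting technique; selectElements is a hypothetical stand-in for detail::tuple_select_elements, whose real signature is internal to DPC++.

// Minimal sketch of splitting a trailing functor off a parameter pack.
// "selectElements" is hypothetical; it only approximates what
// detail::tuple_select_elements is assumed to do (pick elements by index).
#include <cstddef>
#include <iostream>
#include <tuple>
#include <utility>

template <typename Tuple, std::size_t... Is>
auto selectElements(const Tuple &T, std::index_sequence<Is...>) {
  return std::make_tuple(std::get<Is>(T)...); // keep only the indexed elements
}

template <typename... RestT> void splitArgs(RestT... Rest) {
  std::tuple<RestT...> ArgsTuple(Rest...);
  constexpr std::size_t NumArgs = sizeof...(RestT);
  static_assert(NumArgs > 1, "No reduction!");
  auto KernelFunc = std::get<NumArgs - 1>(ArgsTuple); // last arg: the functor
  auto ReduTuple =
      selectElements(ArgsTuple, std::make_index_sequence<NumArgs - 1>());
  std::cout << "reductions: " << std::tuple_size_v<decltype(ReduTuple)> << '\n';
  KernelFunc(); // invoke the peeled-off functor
}

int main() {
  splitArgs(1, 2.0, [] { std::cout << "kernel body\n"; }); // prints 2, then the body
}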
@@ -2341,7 +2354,7 @@ void reduction_parallel_for(handler &CGH, range<Dims> Range,
   // stride equal to 1. For each index, the original KernelFunc is called
   // and the reduction value held in \p Reducer is accumulated across
   // those calls.
-  auto UpdatedKernelFunc = [=](auto NDId, auto &Reducer) {
+  auto UpdatedKernelFunc = [=](auto NDId, auto &...Reducers) {
     // Divide into contiguous chunks and assign each chunk to a Group
     // Rely on precomputed division to avoid repeating expensive operations
     // TODO: Some devices may prefer alternative remainder handling
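This hunk generalizes UpdatedKernelFunc from one reducer reference to a pack of them. The sketch below (C++17, not DPC++ internals) shows the same pattern with a toy Counter type standing in for a real SYCL reducer; all names here are illustrative only.

// A generic variadic lambda forwards an arbitrary number of reducer-like
// objects to a user callback, once per strided index.
#include <cstddef>
#include <iostream>

struct Counter {
  int Value = 0;
  void combine(int V) { Value += V; }
};

int main() {
  auto UserFunc = [](std::size_t I, auto &...Rs) {
    (Rs.combine(static_cast<int>(I)), ...); // fold: update every reducer
  };
  // Mirrors UpdatedKernelFunc: wraps UserFunc and loops over a strided range.
  auto Wrapper = [=](std::size_t Start, std::size_t End, std::size_t Stride,
                     auto &...Reducers) {
    for (std::size_t I = Start; I < End; I += Stride)
      UserFunc(I, Reducers...); // expand the whole pack per iteration
  };
  Counter A, B;
  Wrapper(0, 10, 2, A, B);                        // I = 0, 2, 4, 6, 8
  std::cout << A.Value << ' ' << B.Value << '\n'; // prints "20 20"
}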
@@ -2357,23 +2370,34 @@ void reduction_parallel_for(handler &CGH, range<Dims> Range,
     size_t End = GroupEnd;
     size_t Stride = NDId.get_local_range(0);
     for (size_t I = Start; I < End; I += Stride)
-      KernelFunc(getDelinearizedId(Range, I), Reducer);
+      KernelFunc(getDelinearizedId(Range, I), Reducers...);
   };
+  if constexpr (NumArgs == 2) {
+    using Reduction = std::tuple_element_t<0, decltype(ReduTuple)>;
+    auto &Redu = std::get<0>(ReduTuple);

-  constexpr auto StrategyToUse = [&]() {
-    if constexpr (Strategy != reduction::strategy::auto_select)
-      return Strategy;
+    constexpr auto StrategyToUse = [&]() {
+      if constexpr (Strategy != reduction::strategy::auto_select)
+        return Strategy;

-    if constexpr (Reduction::has_fast_reduce)
-      return reduction::strategy::group_reduce_and_last_wg_detection;
-    else if constexpr (Reduction::has_fast_atomics)
-      return reduction::strategy::local_atomic_and_atomic_cross_wg;
-    else
-      return reduction::strategy::range_basic;
-  }();
+      if constexpr (Reduction::has_fast_reduce)
+        return reduction::strategy::group_reduce_and_last_wg_detection;
+      else if constexpr (Reduction::has_fast_atomics)
+        return reduction::strategy::local_atomic_and_atomic_cross_wg;
+      else
+        return reduction::strategy::range_basic;
+    }();

-  reduction_parallel_for<KernelName, StrategyToUse>(CGH, NDRange, Properties,
-                                                    Redu, UpdatedKernelFunc);
+    reduction_parallel_for<KernelName, StrategyToUse>(CGH, NDRange, Properties,
+                                                      Redu, UpdatedKernelFunc);
+  } else {
+    return std::apply(
+        [&](auto &...Reds) {
+          return reduction_parallel_for<KernelName, Strategy>(
+              CGH, NDRange, Properties, Reds..., UpdatedKernelFunc);
+        },
+        ReduTuple);
+  }
 }
 } // namespace detail
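The else-branch above re-expands ReduTuple into a parameter pack via std::apply and appends the kernel functor as the trailing argument, reversing the pack-splitting done at entry. A self-contained sketch of that dispatch pattern (C++17) follows; "dispatch" is a hypothetical stand-in for the real reduction_parallel_for overload, not DPC++ code.

// Expand a tuple back into a variadic call, appending a functor last.
#include <iostream>
#include <tuple>

template <typename... RestT> void dispatch(RestT... Rest) {
  std::cout << "dispatched with " << sizeof...(RestT) << " args\n";
}

int main() {
  auto KernelFunc = [] {};
  std::tuple<int, double, char> ReduTuple{1, 2.0, 'c'};
  std::apply(
      [&](auto &...Reds) {
        // Re-expand the tuple and append the functor as the trailing argument.
        dispatch(Reds..., KernelFunc);
      },
      ReduTuple);
  // prints: dispatched with 4 args
}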