Skip to content

Commit 44a26c1

Browse files
[SYCL][Reduction] Fix return type of reduction combine (#8125)
The combine member of reducers and the operator shortcuts should return a reference to the reducer. This commit fixes this return type. Signed-off-by: Larsen, Steffen <steffen.larsen@intel.com>
1 parent 4481ab2 commit 44a26c1

File tree

2 files changed

+68
-17
lines changed

2 files changed

+68
-17
lines changed

sycl/include/sycl/reduction.hpp

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -179,46 +179,48 @@ template <class Reducer> class combiner {
179179
public:
180180
template <typename _T = Ty, int _Dims = Dims>
181181
enable_if_t<(_Dims == 0) && IsPlus<_T, BinaryOp>::value &&
182-
is_geninteger<_T>::value>
182+
is_geninteger<_T>::value,
183+
Reducer &>
183184
operator++() {
184-
static_cast<Reducer *>(this)->combine(static_cast<_T>(1));
185+
return static_cast<Reducer *>(this)->combine(static_cast<_T>(1));
185186
}
186187

187188
template <typename _T = Ty, int _Dims = Dims>
188189
enable_if_t<(_Dims == 0) && IsPlus<_T, BinaryOp>::value &&
189-
is_geninteger<_T>::value>
190+
is_geninteger<_T>::value,
191+
Reducer &>
190192
operator++(int) {
191-
static_cast<Reducer *>(this)->combine(static_cast<_T>(1));
193+
return static_cast<Reducer *>(this)->combine(static_cast<_T>(1));
192194
}
193195

194196
template <typename _T = Ty, int _Dims = Dims>
195-
enable_if_t<(_Dims == 0) && IsPlus<_T, BinaryOp>::value>
197+
enable_if_t<(_Dims == 0) && IsPlus<_T, BinaryOp>::value, Reducer &>
196198
operator+=(const _T &Partial) {
197-
static_cast<Reducer *>(this)->combine(Partial);
199+
return static_cast<Reducer *>(this)->combine(Partial);
198200
}
199201

200202
template <typename _T = Ty, int _Dims = Dims>
201-
enable_if_t<(_Dims == 0) && IsMultiplies<_T, BinaryOp>::value>
203+
enable_if_t<(_Dims == 0) && IsMultiplies<_T, BinaryOp>::value, Reducer &>
202204
operator*=(const _T &Partial) {
203-
static_cast<Reducer *>(this)->combine(Partial);
205+
return static_cast<Reducer *>(this)->combine(Partial);
204206
}
205207

206208
template <typename _T = Ty, int _Dims = Dims>
207-
enable_if_t<(_Dims == 0) && IsBitOR<_T, BinaryOp>::value>
209+
enable_if_t<(_Dims == 0) && IsBitOR<_T, BinaryOp>::value, Reducer &>
208210
operator|=(const _T &Partial) {
209-
static_cast<Reducer *>(this)->combine(Partial);
211+
return static_cast<Reducer *>(this)->combine(Partial);
210212
}
211213

212214
template <typename _T = Ty, int _Dims = Dims>
213-
enable_if_t<(_Dims == 0) && IsBitXOR<_T, BinaryOp>::value>
215+
enable_if_t<(_Dims == 0) && IsBitXOR<_T, BinaryOp>::value, Reducer &>
214216
operator^=(const _T &Partial) {
215-
static_cast<Reducer *>(this)->combine(Partial);
217+
return static_cast<Reducer *>(this)->combine(Partial);
216218
}
217219

218220
template <typename _T = Ty, int _Dims = Dims>
219-
enable_if_t<(_Dims == 0) && IsBitAND<_T, BinaryOp>::value>
221+
enable_if_t<(_Dims == 0) && IsBitAND<_T, BinaryOp>::value, Reducer &>
220222
operator&=(const _T &Partial) {
221-
static_cast<Reducer *>(this)->combine(Partial);
223+
return static_cast<Reducer *>(this)->combine(Partial);
222224
}
223225

224226
private:
@@ -339,7 +341,10 @@ class reducer<
339341
reducer(const T &Identity, BinaryOperation BOp)
340342
: MValue(Identity), MIdentity(Identity), MBinaryOp(BOp) {}
341343

342-
void combine(const T &Partial) { MValue = MBinaryOp(MValue, Partial); }
344+
reducer &combine(const T &Partial) {
345+
MValue = MBinaryOp(MValue, Partial);
346+
return *this;
347+
}
343348

344349
T getIdentity() const { return MIdentity; }
345350

@@ -371,9 +376,10 @@ class reducer<
371376
reducer() : MValue(getIdentity()) {}
372377
reducer(const T & /* Identity */, BinaryOperation) : MValue(getIdentity()) {}
373378

374-
void combine(const T &Partial) {
379+
reducer &combine(const T &Partial) {
375380
BinaryOperation BOp;
376381
MValue = BOp(MValue, Partial);
382+
return *this;
377383
}
378384

379385
static T getIdentity() {
@@ -396,7 +402,10 @@ class reducer<T, BinaryOperation, Dims, Extent, View,
396402
public:
397403
reducer(T &Ref, BinaryOperation BOp) : MElement(Ref), MBinaryOp(BOp) {}
398404

399-
void combine(const T &Partial) { MElement = MBinaryOp(MElement, Partial); }
405+
reducer &combine(const T &Partial) {
406+
MElement = MBinaryOp(MElement, Partial);
407+
return *this;
408+
}
400409

401410
private:
402411
T &MElement;
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: %clangxx -fsycl -fsyntax-only -sycl-std=2020 %s
2+
3+
// Tests the return type of combiner operations on reducers.
4+
5+
#include <sycl/sycl.hpp>
6+
7+
#include <type_traits>
8+
9+
int main() {
10+
sycl::queue Q;
11+
12+
int *PlusMem = sycl::malloc_device<int>(1, Q);
13+
int *MultMem = sycl::malloc_device<int>(1, Q);
14+
int *BitAndMem = sycl::malloc_device<int>(1, Q);
15+
int *BitOrMem = sycl::malloc_device<int>(1, Q);
16+
int *BitXorMem = sycl::malloc_device<int>(1, Q);
17+
Q.submit([&](sycl::handler &CGH) {
18+
auto PlusReduction = sycl::reduction(PlusMem, sycl::plus<int>());
19+
auto MultReduction = sycl::reduction(MultMem, sycl::multiplies<int>());
20+
auto BitAndReduction = sycl::reduction(BitAndMem, sycl::bit_and<int>());
21+
auto BitOrReduction = sycl::reduction(BitOrMem, sycl::bit_or<int>());
22+
auto BitXorReduction = sycl::reduction(PlusMem, sycl::bit_xor<int>());
23+
CGH.parallel_for(sycl::range<1>(10), PlusReduction, MultReduction,
24+
BitAndReduction, BitOrReduction, BitXorReduction,
25+
[=](sycl::id<1>, auto &Plus, auto &Mult, auto &BitAnd,
26+
auto &BitOr, auto &BitXor) {
27+
(Plus.combine(1) += 1).combine(1);
28+
(Plus.combine(1)++).combine(1);
29+
(++Plus.combine(1)).combine(1);
30+
(Mult.combine(1) *= 1).combine(1);
31+
(BitAnd.combine(1) &= 1).combine(1);
32+
(BitOr.combine(1) |= 1).combine(1);
33+
(BitXor.combine(1) ^= 1).combine(1);
34+
});
35+
});
36+
sycl::free(PlusMem, Q);
37+
sycl::free(MultMem, Q);
38+
sycl::free(BitAndMem, Q);
39+
sycl::free(BitOrMem, Q);
40+
sycl::free(BitXorMem, Q);
41+
return 0;
42+
}

0 commit comments

Comments
 (0)