Skip to content

Commit 94c4b80

Browse files
[SYCL] Make reduction compatible with MSVC host compiler (#6601)
This PR addresses two problems:

1) MSVC has a bug handling the pattern below in its default mode (it is handled fine under /permissive-, though). The issue affected compilation when using MSVC as a host compiler. Simplified description of the problem:

    template <class Derived> class Base {
      using T = int;
    };
    template <class T> class A : public Base<A<T>> {
      // That's what we had in the codebase prior to this change. MSVC
      // complains here by default, accepts in "/permissive-".
      using T2 = T;
    };

    class Base2 {
      using T = int;
    };
    template <class T> class A2 : public Base2 {
      using T2 = T; // That's where the error has to be emitted.
    };

    int main() {
      A<int> a;
      A2<int> a2;
      return 0;
    }

2) constexpr variables are part of the lambda capture and result in incompatibilities between the clang device compiler and the MSVC host compiler. As such, don't use them when they appear inside kernel lambdas. Instead, make them regular variables and pay the price of an increased number of arguments to the kernel.
1 parent 36e7587 commit 94c4b80

File tree

1 file changed

+53
-58
lines changed

1 file changed

+53
-58
lines changed

sycl/include/sycl/ext/oneapi/reduction.hpp

Lines changed: 53 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -192,57 +192,52 @@ struct ReducerTraits<reducer<T, BinaryOperation, Dims, Extent, View, Subst>> {
192192
/// Also, for int32/64 types the atomic_combine() is lowered to
193193
/// sycl::atomic::fetch_add().
194194
template <class Reducer> class combiner {
195-
using T = typename ReducerTraits<Reducer>::type;
196-
using BinaryOperation = typename ReducerTraits<Reducer>::op;
195+
using Ty = typename ReducerTraits<Reducer>::type;
196+
using BinaryOp = typename ReducerTraits<Reducer>::op;
197197
static constexpr int Dims = ReducerTraits<Reducer>::dims;
198198
static constexpr size_t Extent = ReducerTraits<Reducer>::extent;
199199

200200
public:
201-
template <typename _T = T, int _Dims = Dims>
202-
enable_if_t<(_Dims == 0) &&
203-
sycl::detail::IsPlus<_T, BinaryOperation>::value &&
201+
template <typename _T = Ty, int _Dims = Dims>
202+
enable_if_t<(_Dims == 0) && sycl::detail::IsPlus<_T, BinaryOp>::value &&
204203
sycl::detail::is_geninteger<_T>::value>
205204
operator++() {
206-
static_cast<Reducer *>(this)->combine(static_cast<T>(1));
205+
static_cast<Reducer *>(this)->combine(static_cast<_T>(1));
207206
}
208207

209-
template <typename _T = T, int _Dims = Dims>
210-
enable_if_t<(_Dims == 0) &&
211-
sycl::detail::IsPlus<_T, BinaryOperation>::value &&
208+
template <typename _T = Ty, int _Dims = Dims>
209+
enable_if_t<(_Dims == 0) && sycl::detail::IsPlus<_T, BinaryOp>::value &&
212210
sycl::detail::is_geninteger<_T>::value>
213211
operator++(int) {
214-
static_cast<Reducer *>(this)->combine(static_cast<T>(1));
212+
static_cast<Reducer *>(this)->combine(static_cast<_T>(1));
215213
}
216214

217-
template <typename _T = T, int _Dims = Dims>
218-
enable_if_t<(_Dims == 0) && sycl::detail::IsPlus<_T, BinaryOperation>::value>
215+
template <typename _T = Ty, int _Dims = Dims>
216+
enable_if_t<(_Dims == 0) && sycl::detail::IsPlus<_T, BinaryOp>::value>
219217
operator+=(const _T &Partial) {
220218
static_cast<Reducer *>(this)->combine(Partial);
221219
}
222220

223-
template <typename _T = T, int _Dims = Dims>
224-
enable_if_t<(_Dims == 0) &&
225-
sycl::detail::IsMultiplies<_T, BinaryOperation>::value>
221+
template <typename _T = Ty, int _Dims = Dims>
222+
enable_if_t<(_Dims == 0) && sycl::detail::IsMultiplies<_T, BinaryOp>::value>
226223
operator*=(const _T &Partial) {
227224
static_cast<Reducer *>(this)->combine(Partial);
228225
}
229226

230-
template <typename _T = T, int _Dims = Dims>
231-
enable_if_t<(_Dims == 0) && sycl::detail::IsBitOR<_T, BinaryOperation>::value>
227+
template <typename _T = Ty, int _Dims = Dims>
228+
enable_if_t<(_Dims == 0) && sycl::detail::IsBitOR<_T, BinaryOp>::value>
232229
operator|=(const _T &Partial) {
233230
static_cast<Reducer *>(this)->combine(Partial);
234231
}
235232

236-
template <typename _T = T, int _Dims = Dims>
237-
enable_if_t<(_Dims == 0) &&
238-
sycl::detail::IsBitXOR<_T, BinaryOperation>::value>
233+
template <typename _T = Ty, int _Dims = Dims>
234+
enable_if_t<(_Dims == 0) && sycl::detail::IsBitXOR<_T, BinaryOp>::value>
239235
operator^=(const _T &Partial) {
240236
static_cast<Reducer *>(this)->combine(Partial);
241237
}
242238

243-
template <typename _T = T, int _Dims = Dims>
244-
enable_if_t<(_Dims == 0) &&
245-
sycl::detail::IsBitAND<_T, BinaryOperation>::value>
239+
template <typename _T = Ty, int _Dims = Dims>
240+
enable_if_t<(_Dims == 0) && sycl::detail::IsBitAND<_T, BinaryOp>::value>
246241
operator&=(const _T &Partial) {
247242
static_cast<Reducer *>(this)->combine(Partial);
248243
}
@@ -266,53 +261,53 @@ template <class Reducer> class combiner {
266261
}
267262
}
268263

269-
template <class _T, access::address_space Space, class BinaryOperation>
264+
template <class _T, access::address_space Space, class BinaryOp>
270265
static constexpr bool BasicCheck =
271-
std::is_same<typename remove_AS<_T>::type, T>::value &&
266+
std::is_same<typename remove_AS<_T>::type, Ty>::value &&
272267
(Space == access::address_space::global_space ||
273268
Space == access::address_space::local_space);
274269

275270
public:
276271
/// Atomic ADD operation: *ReduVarPtr += MValue;
277272
template <access::address_space Space = access::address_space::global_space,
278-
typename _T = T, class _BinaryOperation = BinaryOperation>
273+
typename _T = Ty, class _BinaryOperation = BinaryOp>
279274
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
280-
(IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value ||
281-
IsReduOptForAtomic64Op<T, _BinaryOperation>::value) &&
282-
sycl::detail::IsPlus<T, _BinaryOperation>::value>
275+
(IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value ||
276+
IsReduOptForAtomic64Op<_T, _BinaryOperation>::value) &&
277+
sycl::detail::IsPlus<_T, _BinaryOperation>::value>
283278
atomic_combine(_T *ReduVarPtr) const {
284279
atomic_combine_impl<Space>(
285280
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_add(Val); });
286281
}
287282

288283
/// Atomic BITWISE OR operation: *ReduVarPtr |= MValue;
289284
template <access::address_space Space = access::address_space::global_space,
290-
typename _T = T, class _BinaryOperation = BinaryOperation>
285+
typename _T = Ty, class _BinaryOperation = BinaryOp>
291286
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
292-
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
293-
sycl::detail::IsBitOR<T, _BinaryOperation>::value>
287+
IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value &&
288+
sycl::detail::IsBitOR<_T, _BinaryOperation>::value>
294289
atomic_combine(_T *ReduVarPtr) const {
295290
atomic_combine_impl<Space>(
296291
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_or(Val); });
297292
}
298293

299294
/// Atomic BITWISE XOR operation: *ReduVarPtr ^= MValue;
300295
template <access::address_space Space = access::address_space::global_space,
301-
typename _T = T, class _BinaryOperation = BinaryOperation>
296+
typename _T = Ty, class _BinaryOperation = BinaryOp>
302297
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
303-
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
304-
sycl::detail::IsBitXOR<T, _BinaryOperation>::value>
298+
IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value &&
299+
sycl::detail::IsBitXOR<_T, _BinaryOperation>::value>
305300
atomic_combine(_T *ReduVarPtr) const {
306301
atomic_combine_impl<Space>(
307302
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_xor(Val); });
308303
}
309304

310305
/// Atomic BITWISE AND operation: *ReduVarPtr &= MValue;
311306
template <access::address_space Space = access::address_space::global_space,
312-
typename _T = T, class _BinaryOperation = BinaryOperation>
313-
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
314-
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
315-
sycl::detail::IsBitAND<T, _BinaryOperation>::value &&
307+
typename _T = Ty, class _BinaryOperation = BinaryOp>
308+
enable_if_t<std::is_same<typename remove_AS<_T>::type, _T>::value &&
309+
IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value &&
310+
sycl::detail::IsBitAND<_T, _BinaryOperation>::value &&
316311
(Space == access::address_space::global_space ||
317312
Space == access::address_space::local_space)>
318313
atomic_combine(_T *ReduVarPtr) const {
@@ -322,23 +317,23 @@ template <class Reducer> class combiner {
322317

323318
/// Atomic MIN operation: *ReduVarPtr = sycl::minimum(*ReduVarPtr, MValue);
324319
template <access::address_space Space = access::address_space::global_space,
325-
typename _T = T, class _BinaryOperation = BinaryOperation>
320+
typename _T = Ty, class _BinaryOperation = BinaryOp>
326321
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
327-
(IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value ||
328-
IsReduOptForAtomic64Op<T, _BinaryOperation>::value) &&
329-
sycl::detail::IsMinimum<T, _BinaryOperation>::value>
322+
(IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value ||
323+
IsReduOptForAtomic64Op<_T, _BinaryOperation>::value) &&
324+
sycl::detail::IsMinimum<_T, _BinaryOperation>::value>
330325
atomic_combine(_T *ReduVarPtr) const {
331326
atomic_combine_impl<Space>(
332327
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_min(Val); });
333328
}
334329

335330
/// Atomic MAX operation: *ReduVarPtr = sycl::maximum(*ReduVarPtr, MValue);
336331
template <access::address_space Space = access::address_space::global_space,
337-
typename _T = T, class _BinaryOperation = BinaryOperation>
332+
typename _T = Ty, class _BinaryOperation = BinaryOp>
338333
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
339-
(IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value ||
340-
IsReduOptForAtomic64Op<T, _BinaryOperation>::value) &&
341-
sycl::detail::IsMaximum<T, _BinaryOperation>::value>
334+
(IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value ||
335+
IsReduOptForAtomic64Op<_T, _BinaryOperation>::value) &&
336+
sycl::detail::IsMaximum<_T, _BinaryOperation>::value>
342337
atomic_combine(_T *ReduVarPtr) const {
343338
atomic_combine_impl<Space>(
344339
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_max(Val); });
@@ -928,7 +923,7 @@ bool reduCGFuncForRangeFastAtomics(handler &CGH, KernelType KernelFunc,
928923
const range<Dims> &Range,
929924
const nd_range<1> &NDRange,
930925
Reduction &Redu) {
931-
constexpr size_t NElements = Reduction::num_elements;
926+
size_t NElements = Reduction::num_elements;
932927
auto Out = Redu.getReadWriteAccessorToInitializedMem(CGH);
933928
auto GroupSum = Reduction::getReadWriteLocalAcc(NElements, CGH);
934929
using Name = __sycl_reduction_kernel<reduction::main_krn::RangeFastAtomics,
@@ -976,7 +971,7 @@ template <typename KernelName, typename KernelType, int Dims, class Reduction>
976971
bool reduCGFuncForRangeFastReduce(handler &CGH, KernelType KernelFunc,
977972
const range<Dims> &Range,
978973
const nd_range<1> &NDRange, Reduction &Redu) {
979-
constexpr size_t NElements = Reduction::num_elements;
974+
size_t NElements = Reduction::num_elements;
980975
size_t WGSize = NDRange.get_local_range().size();
981976
size_t NWorkGroups = NDRange.get_group_range().size();
982977

@@ -1078,7 +1073,7 @@ template <typename KernelName, typename KernelType, int Dims, class Reduction>
10781073
bool reduCGFuncForRangeBasic(handler &CGH, KernelType KernelFunc,
10791074
const range<Dims> &Range,
10801075
const nd_range<1> &NDRange, Reduction &Redu) {
1081-
constexpr size_t NElements = Reduction::num_elements;
1076+
size_t NElements = Reduction::num_elements;
10821077
size_t WGSize = NDRange.get_local_range().size();
10831078
size_t NWorkGroups = NDRange.get_group_range().size();
10841079

@@ -1230,7 +1225,7 @@ template <typename KernelName, typename KernelType, int Dims, class Reduction>
12301225
void reduCGFuncForNDRangeBothFastReduceAndAtomics(
12311226
handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
12321227
Reduction &, typename Reduction::rw_accessor_type Out) {
1233-
constexpr size_t NElements = Reduction::num_elements;
1228+
size_t NElements = Reduction::num_elements;
12341229
using Name = __sycl_reduction_kernel<
12351230
reduction::main_krn::NDRangeBothFastReduceAndAtomics, KernelName>;
12361231
CGH.parallel_for<Name>(Range, [=](nd_item<Dims> NDIt) {
@@ -1266,7 +1261,7 @@ void reduCGFuncForNDRangeFastAtomicsOnly(
12661261
handler &CGH, bool IsPow2WG, KernelType KernelFunc,
12671262
const nd_range<Dims> &Range, Reduction &,
12681263
typename Reduction::rw_accessor_type Out) {
1269-
constexpr size_t NElements = Reduction::num_elements;
1264+
size_t NElements = Reduction::num_elements;
12701265
size_t WGSize = Range.get_local_range().size();
12711266

12721267
// Use local memory to reduce elements in work-groups into zero-th element.
@@ -1345,7 +1340,7 @@ template <typename KernelName, typename KernelType, int Dims, class Reduction>
13451340
void reduCGFuncForNDRangeFastReduceOnly(
13461341
handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
13471342
Reduction &Redu, typename Reduction::rw_accessor_type Out) {
1348-
constexpr size_t NElements = Reduction::num_elements;
1343+
size_t NElements = Reduction::num_elements;
13491344
size_t NWorkGroups = Range.get_group_range().size();
13501345
bool IsUpdateOfUserVar =
13511346
!Reduction::is_usm && !Redu.initializeToIdentity() && NWorkGroups == 1;
@@ -1392,7 +1387,7 @@ void reduCGFuncForNDRangeBasic(handler &CGH, bool IsPow2WG,
13921387
KernelType KernelFunc,
13931388
const nd_range<Dims> &Range, Reduction &Redu,
13941389
typename Reduction::rw_accessor_type Out) {
1395-
constexpr size_t NElements = Reduction::num_elements;
1390+
size_t NElements = Reduction::num_elements;
13961391
size_t WGSize = Range.get_local_range().size();
13971392
size_t NWorkGroups = Range.get_group_range().size();
13981393

@@ -1477,7 +1472,7 @@ void reduAuxCGFuncFastReduceImpl(handler &CGH, bool UniformWG,
14771472
size_t NWorkItems, size_t NWorkGroups,
14781473
size_t WGSize, Reduction &Redu, InputT In,
14791474
OutputT Out) {
1480-
constexpr size_t NElements = Reduction::num_elements;
1475+
size_t NElements = Reduction::num_elements;
14811476
using Name =
14821477
__sycl_reduction_kernel<reduction::aux_krn::FastReduce, KernelName>;
14831478
bool IsUpdateOfUserVar =
@@ -1523,7 +1518,7 @@ void reduAuxCGFuncNoFastReduceNorAtomicImpl(handler &CGH, bool UniformPow2WG,
15231518
size_t NWorkGroups, size_t WGSize,
15241519
Reduction &Redu, InputT In,
15251520
OutputT Out) {
1526-
constexpr size_t NElements = Reduction::num_elements;
1521+
size_t NElements = Reduction::num_elements;
15271522
bool IsUpdateOfUserVar =
15281523
!Reduction::is_usm && !Redu.initializeToIdentity() && NWorkGroups == 1;
15291524

@@ -1642,7 +1637,7 @@ reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) {
16421637
template <typename KernelName, class Reduction>
16431638
std::enable_if_t<Reduction::is_usm>
16441639
reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) {
1645-
constexpr size_t NElements = Reduction::num_elements;
1640+
size_t NElements = Reduction::num_elements;
16461641
auto InAcc = Redu.getReadAccToPreviousPartialReds(CGH);
16471642
auto UserVarPtr = Redu.getUserRedVar();
16481643
bool IsUpdateOfUserVar = !Redu.initializeToIdentity();
@@ -2120,7 +2115,7 @@ void reduCGFuncAtomic64(handler &CGH, KernelType KernelFunc,
21202115
static_assert(
21212116
Reduction::has_float64_atomics,
21222117
"Only suitable for reductions that have FP64 atomic operations.");
2123-
constexpr size_t NElements = Reduction::num_elements;
2118+
size_t NElements = Reduction::num_elements;
21242119
using Name =
21252120
__sycl_reduction_kernel<reduction::main_krn::NDRangeAtomic64, KernelName>;
21262121
CGH.parallel_for<Name>(Range, [=](nd_item<Dims> NDIt) {

0 commit comments

Comments
 (0)