Skip to content

Commit 2475844

Browse files
[NFC][SYCL][Reduction] Minor refactoring (#6318)
1 parent 5352b42 commit 2475844

File tree

2 files changed

+41
-59
lines changed

2 files changed

+41
-59
lines changed

sycl/include/CL/sycl/handler.hpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2688,7 +2688,8 @@ class __SYCL_EXPORT handler {
26882688
class Algorithm>
26892689
friend class ext::oneapi::detail::reduction_impl_algo;
26902690

2691-
// This method needs to call the method finalize().
2691+
// This method needs to call the method finalize() and also access the private
2692+
// ctor/dtor.
26922693
template <typename Reduction, typename... RestT>
26932694
std::enable_if_t<!Reduction::is_usm> friend ext::oneapi::detail::
26942695
reduSaveFinalResultToUserMemHelper(

sycl/include/sycl/ext/oneapi/reduction.hpp

Lines changed: 39 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -214,57 +214,56 @@ template <class Reducer> class combiner {
214214
: memory_scope::device;
215215
}
216216

217+
template <access::address_space Space, class T, class AtomicFunctor>
218+
void atomic_combine_impl(T *ReduVarPtr, AtomicFunctor Functor) const {
219+
auto reducer = static_cast<const Reducer *>(this);
220+
for (size_t E = 0; E < Extent; ++E) {
221+
auto AtomicRef =
222+
atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
223+
multi_ptr<T, Space>(ReduVarPtr)[E]);
224+
Functor(AtomicRef, reducer->getElement(E));
225+
}
226+
}
227+
228+
template <class _T, access::address_space Space, class BinaryOperation>
229+
static inline constexpr bool BasicCheck =
230+
std::is_same<typename remove_AS<_T>::type, T>::value &&
231+
(Space == access::address_space::global_space ||
232+
Space == access::address_space::local_space);
233+
217234
public:
218235
/// Atomic ADD operation: *ReduVarPtr += MValue;
219236
template <access::address_space Space = access::address_space::global_space,
220237
typename _T = T, class _BinaryOperation = BinaryOperation>
221-
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
238+
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
222239
(IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value ||
223240
IsReduOptForAtomic64Add<T, _BinaryOperation>::value) &&
224-
sycl::detail::IsPlus<T, _BinaryOperation>::value &&
225-
(Space == access::address_space::global_space ||
226-
Space == access::address_space::local_space)>
241+
sycl::detail::IsPlus<T, _BinaryOperation>::value>
227242
atomic_combine(_T *ReduVarPtr) const {
228-
auto reducer = static_cast<const Reducer *>(this);
229-
for (size_t E = 0; E < Extent; ++E) {
230-
atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
231-
multi_ptr<T, Space>(ReduVarPtr)[E])
232-
.fetch_add(reducer->getElement(E));
233-
}
243+
atomic_combine_impl<Space>(
244+
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_add(Val); });
234245
}
235246

236247
/// Atomic BITWISE OR operation: *ReduVarPtr |= MValue;
237248
template <access::address_space Space = access::address_space::global_space,
238249
typename _T = T, class _BinaryOperation = BinaryOperation>
239-
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
250+
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
240251
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
241-
sycl::detail::IsBitOR<T, _BinaryOperation>::value &&
242-
(Space == access::address_space::global_space ||
243-
Space == access::address_space::local_space)>
252+
sycl::detail::IsBitOR<T, _BinaryOperation>::value>
244253
atomic_combine(_T *ReduVarPtr) const {
245-
auto reducer = static_cast<const Reducer *>(this);
246-
for (size_t E = 0; E < Extent; ++E) {
247-
atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
248-
multi_ptr<T, Space>(ReduVarPtr)[E])
249-
.fetch_or(reducer->getElement(E));
250-
}
254+
atomic_combine_impl<Space>(
255+
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_or(Val); });
251256
}
252257

253258
/// Atomic BITWISE XOR operation: *ReduVarPtr ^= MValue;
254259
template <access::address_space Space = access::address_space::global_space,
255260
typename _T = T, class _BinaryOperation = BinaryOperation>
256-
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
261+
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
257262
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
258-
sycl::detail::IsBitXOR<T, _BinaryOperation>::value &&
259-
(Space == access::address_space::global_space ||
260-
Space == access::address_space::local_space)>
263+
sycl::detail::IsBitXOR<T, _BinaryOperation>::value>
261264
atomic_combine(_T *ReduVarPtr) const {
262-
auto reducer = static_cast<const Reducer *>(this);
263-
for (size_t E = 0; E < Extent; ++E) {
264-
atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
265-
multi_ptr<T, Space>(ReduVarPtr)[E])
266-
.fetch_xor(reducer->getElement(E));
267-
}
265+
atomic_combine_impl<Space>(
266+
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_xor(Val); });
268267
}
269268

270269
/// Atomic BITWISE AND operation: *ReduVarPtr &= MValue;
@@ -276,46 +275,30 @@ template <class Reducer> class combiner {
276275
(Space == access::address_space::global_space ||
277276
Space == access::address_space::local_space)>
278277
atomic_combine(_T *ReduVarPtr) const {
279-
auto reducer = static_cast<const Reducer *>(this);
280-
for (size_t E = 0; E < Extent; ++E) {
281-
atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
282-
multi_ptr<T, Space>(ReduVarPtr)[E])
283-
.fetch_and(reducer->getElement(E));
284-
}
278+
atomic_combine_impl<Space>(
279+
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_and(Val); });
285280
}
286281

287282
/// Atomic MIN operation: *ReduVarPtr = sycl::minimum(*ReduVarPtr, MValue);
288283
template <access::address_space Space = access::address_space::global_space,
289284
typename _T = T, class _BinaryOperation = BinaryOperation>
290-
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
285+
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
291286
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
292-
sycl::detail::IsMinimum<T, _BinaryOperation>::value &&
293-
(Space == access::address_space::global_space ||
294-
Space == access::address_space::local_space)>
287+
sycl::detail::IsMinimum<T, _BinaryOperation>::value>
295288
atomic_combine(_T *ReduVarPtr) const {
296-
auto reducer = static_cast<const Reducer *>(this);
297-
for (size_t E = 0; E < Extent; ++E) {
298-
atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
299-
multi_ptr<T, Space>(ReduVarPtr)[E])
300-
.fetch_min(reducer->getElement(E));
301-
}
289+
atomic_combine_impl<Space>(
290+
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_min(Val); });
302291
}
303292

304293
/// Atomic MAX operation: *ReduVarPtr = sycl::maximum(*ReduVarPtr, MValue);
305294
template <access::address_space Space = access::address_space::global_space,
306295
typename _T = T, class _BinaryOperation = BinaryOperation>
307-
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
296+
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
308297
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
309-
sycl::detail::IsMaximum<T, _BinaryOperation>::value &&
310-
(Space == access::address_space::global_space ||
311-
Space == access::address_space::local_space)>
298+
sycl::detail::IsMaximum<T, _BinaryOperation>::value>
312299
atomic_combine(_T *ReduVarPtr) const {
313-
auto reducer = static_cast<const Reducer *>(this);
314-
for (size_t E = 0; E < Extent; ++E) {
315-
atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
316-
multi_ptr<T, Space>(ReduVarPtr)[E])
317-
.fetch_max(reducer->getElement(E));
318-
}
300+
atomic_combine_impl<Space>(
301+
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_max(Val); });
319302
}
320303
};
321304

@@ -415,8 +398,6 @@ class reducer<T, BinaryOperation, Dims, Extent, Algorithm, View,
415398
reducer(const T &Identity, BinaryOperation BOp)
416399
: MValue(Identity), MIdentity(Identity), MBinaryOp(BOp) {}
417400

418-
// SYCL 2020 revision 4 says this should be const, but this is a bug
419-
// see https://github.com/KhronosGroup/SYCL-Docs/pull/252
420401
reducer<T, BinaryOperation, Dims - 1, Extent, Algorithm, true>
421402
operator[](size_t Index) {
422403
return {MValue[Index], MBinaryOp};

0 commit comments

Comments
 (0)