@@ -154,6 +154,17 @@ namespace block_reduce_detail
154
154
val = smem[tid];
155
155
}
156
156
157
+
158
// merge: one step of a shared-memory tree reduction. Folds the partner
// value `delta` slots to the right into this thread's running value and
// publishes the result to smem[tid] as well as back into `val`.
// NOTE(review): assumes tid + delta is within the shared array and that the
// caller inserts the appropriate barrier between reduction steps — confirm
// against the calling reduction loop.
template <typename T, class Op >
__device__ __forceinline__ void merge (volatile T* smem, T& val, uint tid, uint delta, const Op& op)
{
    // Copy out of volatile shared memory before handing to op.
    T partner = smem[tid + delta];
    val = op (val, partner);
    smem[tid] = val;
}
166
+
167
+ #if (CUDART_VERSION < 12040)
157
168
template <typename P0, typename P1, typename P2, typename P3, typename P4, typename P5, typename P6, typename P7, typename P8, typename P9,
158
169
typename R0, typename R1, typename R2, typename R3, typename R4, typename R5, typename R6, typename R7, typename R8, typename R9>
159
170
__device__ __forceinline__ void loadToSmem (const tuple<P0, P1, P2, P3, P4, P5, P6, P7, P8, P9>& smem,
@@ -172,15 +183,6 @@ namespace block_reduce_detail
172
183
For<0 , tuple_size<tuple<P0, P1, P2, P3, P4, P5, P6, P7, P8, P9> >::value>::loadFromSmem (smem, val, tid);
173
184
}
174
185
175
- // merge
176
-
177
- template <typename T, class Op >
178
- __device__ __forceinline__ void merge (volatile T* smem, T& val, uint tid, uint delta, const Op& op)
179
- {
180
- T reg = smem[tid + delta];
181
- smem[tid] = val = op (val, reg);
182
- }
183
-
184
186
template <typename P0, typename P1, typename P2, typename P3, typename P4, typename P5, typename P6, typename P7, typename P8, typename P9,
185
187
typename R0, typename R1, typename R2, typename R3, typename R4, typename R5, typename R6, typename R7, typename R8, typename R9,
186
188
class Op0 , class Op1 , class Op2 , class Op3 , class Op4 , class Op5 , class Op6 , class Op7 , class Op8 , class Op9 >
@@ -214,6 +216,41 @@ namespace block_reduce_detail
214
216
}
215
217
#endif
216
218
219
+ #else
220
// loadToSmem (tuple overload): element-wise store of each value in `val`
// into the matching shared-memory pointer of `smem` at index `tid`,
// unrolled at compile time via the For<> helper.
template <typename ... P, typename ... R>
__device__ __forceinline__ void loadToSmem (const tuple<P...>& smem, const tuple<R...>& val, uint tid)
{
    typedef tuple<P...> smem_tuple;
    For<0, tuple_size<smem_tuple>::value>::loadToSmem (smem, val, tid);
}
225
+
226
// loadFromSmem (tuple overload): element-wise read of smem[i][tid] back
// into each value of `val`, unrolled at compile time via the For<> helper.
template <typename ... P, typename ... R>
__device__ __forceinline__ void loadFromSmem (const tuple<P...>& smem, const tuple<R...>& val, uint tid)
{
    typedef tuple<P...> smem_tuple;
    For<0, tuple_size<smem_tuple>::value>::loadFromSmem (smem, val, tid);
}
231
+
232
// merge (tuple overload): applies one shared-memory reduction step to every
// element of the tuple, delegating element-wise to For<>::merge with the
// matching operator from `op`.
// Fix: the parameter pack was declared `typename P...`, which is ill-formed
// C++ and fails to compile; the correct spelling is `typename ... P`
// (ellipsis precedes the pack name), matching the sibling overloads above.
template <typename ... P, typename ... R, class ... Op>
__device__ __forceinline__ void merge (const tuple<P...>& smem, const tuple<R...>& val, uint tid, uint delta, const tuple<Op...>& op)
{
    For<0, tuple_size<tuple<P...> >::value>::merge (smem, val, tid, delta, op);
}
237
+
238
// mergeShfl

// One step of a register-only warp reduction: pulls the value held by the
// lane `delta` positions down (within a shuffle group of `width` lanes) and
// folds it into this lane's value with `op`.
// NOTE(review): shfl_down is a project wrapper — confirm its lane-mask /
// compute-capability requirements at the definition site.
template <typename T, class Op >
__device__ __forceinline__ void mergeShfl (T& val, uint delta, uint width, const Op& op)
{
    // shfl_down reads val before the assignment updates it.
    val = op (val, shfl_down (val, delta, width));
}
246
+
247
// mergeShfl (tuple overload): applies one shuffle-reduction step to every
// element of the tuple with its matching operator, unrolled via For<>.
template <typename ... R, class ... Op>
__device__ __forceinline__ void mergeShfl (const tuple<R...>& val, uint delta, uint width, const tuple<Op...>& op)
{
    typedef tuple<R...> val_tuple;
    For<0, tuple_size<val_tuple>::value>::mergeShfl (val, delta, width, op);
}
252
+ #endif
253
+
217
254
// Generic
218
255
219
256
template <int N> struct Generic
0 commit comments