@@ -160,6 +160,7 @@ namespace block_reduce_key_val_detail
160
160
data = smem[tid];
161
161
}
162
162
163
+ #if (CUDART_VERSION < 12040)
163
164
template <typename VP0, typename VP1, typename VP2, typename VP3, typename VP4, typename VP5, typename VP6, typename VP7, typename VP8, typename VP9,
164
165
typename VR0, typename VR1, typename VR2, typename VR3, typename VR4, typename VR5, typename VR6, typename VR7, typename VR8, typename VR9>
165
166
__device__ __forceinline__ void loadToSmem (const tuple<VP0, VP1, VP2, VP3, VP4, VP5, VP6, VP7, VP8, VP9>& smem,
@@ -241,6 +242,67 @@ namespace block_reduce_key_val_detail
241
242
{
242
243
For<0 , tuple_size<tuple<VP0, VP1, VP2, VP3, VP4, VP5, VP6, VP7, VP8, VP9> >::value>::merge (skeys, key, svals, val, cmp, tid, delta);
243
244
}
245
+ #else
246
+ template <typename ... VP, typename ... VR>
247
+ __device__ __forceinline__ void loadToSmem (const tuple<VP...>& smem, const tuple<VR...>& data, uint tid)
248
+ {
249
+ For<0 , tuple_size<tuple<VP...> >::value>::loadToSmem (smem, data, tid);
250
+ }
251
+
252
+ template <typename ... VP, typename ... VR>
253
+ __device__ __forceinline__ void loadFromSmem (const tuple<VP...>& smem, const tuple<VR...>& data, uint tid)
254
+ {
255
+ For<0 , tuple_size<tuple<VP...> >::value>::loadFromSmem (smem, data, tid);
256
+ }
257
+
258
+ // copyVals
259
+
260
+ template <typename V>
261
+ __device__ __forceinline__ void copyVals (volatile V* svals, V& val, uint tid, uint delta)
262
+ {
263
+ svals[tid] = val = svals[tid + delta];
264
+ }
265
+
266
+ template <typename ... VP, typename ... VR>
267
+ __device__ __forceinline__ void copyVals (const tuple<VP...>& svals, const tuple<VR...>& val, uint tid, uint delta)
268
+ {
269
+ For<0 , tuple_size<tuple<VP...> >::value>::copy (svals, val, tid, delta);
270
+ }
271
+
272
+ // merge
273
+
274
+ template <typename K, typename V, class Cmp >
275
+ __device__ void merge (volatile K* skeys, K& key, volatile V* svals, V& val, const Cmp& cmp, uint tid, uint delta)
276
+ {
277
+ K reg = skeys[tid + delta];
278
+
279
+ if (cmp (reg, key))
280
+ {
281
+ skeys[tid] = key = reg;
282
+ copyVals (svals, val, tid, delta);
283
+ }
284
+ }
285
+
286
+ template <typename K, typename ... VP, typename ... VR, class Cmp >
287
+ __device__ void merge (volatile K* skeys, K& key, const tuple<VP...>& svals, const tuple<VR...>& val, const Cmp& cmp, uint tid, uint delta)
288
+ {
289
+ K reg = skeys[tid + delta];
290
+
291
+ if (cmp (reg, key))
292
+ {
293
+ skeys[tid] = key = reg;
294
+ copyVals (svals, val, tid, delta);
295
+ }
296
+ }
297
+
298
+ template <typename ... KP, typename ... KR, typename ... VP, typename ... VR, class ... Cmp>
299
+ __device__ __forceinline__ void merge (const tuple<KP...>& skeys, const tuple<KR...>& key,
300
+ const tuple<VP...>& svals, const tuple<VR...>& val,
301
+ const tuple<Cmp...>& cmp, uint tid, uint delta)
302
+ {
303
+ For<0 , tuple_size<tuple<VP...> >::value>::merge (skeys, key, svals, val, cmp, tid, delta);
304
+ }
305
+ #endif
244
306
245
307
// Generic
246
308
0 commit comments