Skip to content

Commit a41caf6

Browse files
committed
Added a branch compatible with the variadic tuple version in Thrust (since CUDA 12.4).
1 parent e46ba34 commit a41caf6

File tree

2 files changed

+81
-9
lines changed

2 files changed

+81
-9
lines changed

modules/cudev/include/opencv2/cudev/block/detail/reduce.hpp

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,17 @@ namespace block_reduce_detail
154154
val = smem[tid];
155155
}
156156

157+
158+
// merge
159+
160+
// Single-value merge step.
// Reads the element stored `delta` slots after `tid` in the smem buffer, combines it
// with the caller's running value through `op`, and stores the combined result both
// in `val` and back into smem[tid].
template <typename T, class Op>
__device__ __forceinline__ void merge(volatile T* smem, T& val, uint tid, uint delta, const Op& op)
{
    // Copy the partner element out of the volatile buffer before combining.
    T partner = smem[tid + delta];

    // Update the caller's value first, then publish it to this thread's slot.
    val = op(val, partner);
    smem[tid] = val;
}
166+
167+
#if (CUDART_VERSION < 12040)
157168
template <typename P0, typename P1, typename P2, typename P3, typename P4, typename P5, typename P6, typename P7, typename P8, typename P9,
158169
typename R0, typename R1, typename R2, typename R3, typename R4, typename R5, typename R6, typename R7, typename R8, typename R9>
159170
__device__ __forceinline__ void loadToSmem(const tuple<P0, P1, P2, P3, P4, P5, P6, P7, P8, P9>& smem,
@@ -172,15 +183,6 @@ namespace block_reduce_detail
172183
For<0, tuple_size<tuple<P0, P1, P2, P3, P4, P5, P6, P7, P8, P9> >::value>::loadFromSmem(smem, val, tid);
173184
}
174185

175-
// merge
176-
177-
template <typename T, class Op>
178-
__device__ __forceinline__ void merge(volatile T* smem, T& val, uint tid, uint delta, const Op& op)
179-
{
180-
T reg = smem[tid + delta];
181-
smem[tid] = val = op(val, reg);
182-
}
183-
184186
template <typename P0, typename P1, typename P2, typename P3, typename P4, typename P5, typename P6, typename P7, typename P8, typename P9,
185187
typename R0, typename R1, typename R2, typename R3, typename R4, typename R5, typename R6, typename R7, typename R8, typename R9,
186188
class Op0, class Op1, class Op2, class Op3, class Op4, class Op5, class Op6, class Op7, class Op8, class Op9>
@@ -214,6 +216,41 @@ namespace block_reduce_detail
214216
}
215217
#endif
216218

219+
#else
220+
// Variadic-tuple overload of loadToSmem (used when CUDART_VERSION >= 12040).
// Forwards the tuple of smem pointers and the tuple of values to
// For<0, N>::loadToSmem, where N is tuple_size of the smem tuple type.
template <typename... P, typename... R>
__device__ __forceinline__ void loadToSmem(const tuple<P...>& smem, const tuple<R...>& val, uint tid)
{
    typedef For<0, tuple_size<tuple<P...> >::value> ElementLoop;

    ElementLoop::loadToSmem(smem, val, tid);
}
225+
226+
// Variadic-tuple overload of loadFromSmem (used when CUDART_VERSION >= 12040).
// Forwards the tuple of smem pointers and the tuple of values to
// For<0, N>::loadFromSmem, where N is tuple_size of the smem tuple type.
template <typename... P, typename... R>
__device__ __forceinline__ void loadFromSmem(const tuple<P...>& smem, const tuple<R...>& val, uint tid)
{
    typedef For<0, tuple_size<tuple<P...> >::value> ElementLoop;

    ElementLoop::loadFromSmem(smem, val, tid);
}
231+
232+
// Variadic-tuple overload of merge (used when CUDART_VERSION >= 12040).
//
// smem  - tuple of smem pointers, one per reduced value
// val   - tuple of per-thread values, passed through to the per-element merge
// tid   - index of the calling thread's slot
// delta - offset of the partner slot relative to tid
// op    - tuple of binary operators, one per reduced value
//
// Forwards everything to For<0, N>::merge, where N is tuple_size of the smem tuple type.
//
// Fix: a template parameter pack is declared as `typename... P`; the previous
// spelling `typename P...` is ill-formed and broke compilation of this branch.
template <typename... P, typename... R, class... Op>
__device__ __forceinline__ void merge(const tuple<P...>& smem, const tuple<R...>& val, uint tid, uint delta, const tuple<Op...>& op)
{
    For<0, tuple_size<tuple<P...> >::value>::merge(smem, val, tid, delta, op);
}
237+
238+
// mergeShfl
239+
240+
// Single-value shuffle merge step.
// Obtains a partner value via shfl_down(val, delta, width) and folds it into
// `val` using `op`; the result is written back to `val`.
template <typename T, class Op>
__device__ __forceinline__ void mergeShfl(T& val, uint delta, uint width, const Op& op)
{
    // Partner value fetched through the shuffle helper.
    T partner = shfl_down(val, delta, width);

    val = op(val, partner);
}
246+
247+
// Variadic-tuple overload of mergeShfl (used when CUDART_VERSION >= 12040).
// Forwards the tuple of values and the tuple of operators to
// For<0, N>::mergeShfl, where N is tuple_size of the value tuple type.
template <typename... R, class... Op>
__device__ __forceinline__ void mergeShfl(const tuple<R...>& val, uint delta, uint width, const tuple<Op...>& op)
{
    typedef For<0, tuple_size<tuple<R...> >::value> ElementLoop;

    ElementLoop::mergeShfl(val, delta, width, op);
}
252+
#endif
253+
217254
// Generic
218255

219256
template <int N> struct Generic

modules/cudev/include/opencv2/cudev/block/reduce.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#include "../warp/reduce.hpp"
5252
#include "detail/reduce.hpp"
5353
#include "detail/reduce_key_val.hpp"
54+
#include <cuda_runtime_api.h>
5455

5556
namespace cv { namespace cudev {
5657

@@ -65,6 +66,7 @@ __device__ __forceinline__ void blockReduce(volatile T* smem, T& val, uint tid,
6566
block_reduce_detail::Dispatcher<N>::reductor::template reduce<volatile T*, T&, const Op&>(smem, val, tid, op);
6667
}
6768

69+
#if (CUDART_VERSION < 12040)
6870
template <int N,
6971
typename P0, typename P1, typename P2, typename P3, typename P4, typename P5, typename P6, typename P7, typename P8, typename P9,
7072
typename R0, typename R1, typename R2, typename R3, typename R4, typename R5, typename R6, typename R7, typename R8, typename R9,
@@ -126,6 +128,39 @@ __device__ __forceinline__ void blockReduceKeyVal(const tuple<KP0, KP1, KP2, KP3
126128
>(skeys, key, svals, val, tid, cmp);
127129
}
128130

131+
#else
132+
133+
// Variadic-tuple overload of blockReduce (used when CUDART_VERSION >= 12040).
//
// N    - compile-time template argument used to select the reductor via
//        block_reduce_detail::Dispatcher<N>
// smem - tuple of smem pointers, one per reduced value
// val  - tuple of per-thread values
// tid  - index of the calling thread
// op   - tuple of binary operators, one per reduced value
//
// The tuple types are passed explicitly to reduce<> as const references,
// mirroring the argument types of this function.
//
// Fix: the pack expansion in the last parameter was written `Op..` (two dots),
// which is a syntax error; it must be `Op...`.
template <int N, typename... P, typename... R, typename... Op>
__device__ __forceinline__ void blockReduce(const tuple<P...>& smem,
                                            const tuple<R...>& val,
                                            uint tid,
                                            const tuple<Op...>& op)
{
    block_reduce_detail::Dispatcher<N>::reductor::template reduce<const tuple<P...>&, const tuple<R...>&, const tuple<Op...>&>(smem, val, tid, op);
}
141+
142+
// blockReduceKeyVal
143+
144+
// Key/value block reduction for a single key and a single value.
// Selects the reductor through block_reduce_key_val_detail::Dispatcher<N> and calls
// its reduce<> with the pointer/reference types of the arguments spelled out explicitly.
template <int N, typename K, typename V, class Cmp>
__device__ __forceinline__ void blockReduceKeyVal(volatile K* skeys, K& key, volatile V* svals, V& val, uint tid, const Cmp& cmp)
{
    typedef typename block_reduce_key_val_detail::Dispatcher<N>::reductor Reductor;

    Reductor::template reduce<volatile K*, K&, volatile V*, V&, const Cmp&>(
        skeys, key, svals, val, tid, cmp);
}
149+
150+
// Key/value block reduction for a single key and a tuple of values
// (variadic-tuple branch, used when CUDART_VERSION >= 12040).
// Selects the reductor through block_reduce_key_val_detail::Dispatcher<N> and calls
// its reduce<> with the argument types spelled out explicitly.
template <int N, typename K, typename... VP, typename... VR, class Cmp>
__device__ __forceinline__ void blockReduceKeyVal(volatile K* skeys, K& key, const tuple<VP...>& svals, const tuple<VR...>& val, uint tid, const Cmp& cmp)
{
    typedef typename block_reduce_key_val_detail::Dispatcher<N>::reductor Reductor;

    Reductor::template reduce<volatile K*, K&, const tuple<VP...>&, const tuple<VR...>&, const Cmp&>(
        skeys, key, svals, val, tid, cmp);
}
155+
156+
// Key/value block reduction for a tuple of keys, a tuple of values and a tuple of comparators
// (variadic-tuple branch, used when CUDART_VERSION >= 12040).
// Selects the reductor through block_reduce_key_val_detail::Dispatcher<N> and calls
// its reduce<> with the tuple argument types spelled out explicitly as const references.
template <int N, typename... KP, typename... KR, typename... VP, typename... VR, class... Cmp>
__device__ __forceinline__ void blockReduceKeyVal(const tuple<KP...>& skeys, const tuple<KR...>& key, const tuple<VP...>& svals, const tuple<VR...>& val, uint tid, const tuple<Cmp...>& cmp)
{
    typedef typename block_reduce_key_val_detail::Dispatcher<N>::reductor Reductor;

    Reductor::template reduce<const tuple<KP...>&, const tuple<KR...>&, const tuple<VP...>&, const tuple<VR...>&, const tuple<Cmp...>&>(
        skeys, key, svals, val, tid, cmp);
}
161+
162+
#endif
163+
129164
//! @}
130165

131166
}}

0 commit comments

Comments
 (0)