Skip to content

Commit cc66eb6

Browse files
Change streaming algorithms to use operator+= from using operator+ (NVIDIA#4428)
1 parent 18cebdf commit cc66eb6

File tree

4 files changed

+133
-58
lines changed

4 files changed

+133
-58
lines changed
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#pragma once
5+
6+
#include <cub/config.cuh>
7+
8+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
9+
# pragma GCC system_header
10+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
11+
# pragma clang system_header
12+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
13+
# pragma system_header
14+
#endif // no system header
15+
16+
#include <cuda/std/type_traits>
17+
18+
CUB_NAMESPACE_BEGIN
19+
20+
namespace detail
21+
{
22+
// Detection trait for binary `+`. This primary template is the fallback that is selected whenever the
// expression `declval<LhsT>() + declval<RhsT>()` is ill-formed; it reports false.
template <typename LhsT, typename RhsT, typename = void>
struct has_plus_operator : ::cuda::std::false_type
{};

// Partial specialization that only participates when `declval<LhsT>() + declval<RhsT>()` is a well-formed
// expression. In that case void_t yields void, which matches the defaulted third argument of the primary
// template, so this specialization is picked and the trait reports true.
template <typename LhsT, typename RhsT>
struct has_plus_operator<LhsT,
                         RhsT,
                         ::cuda::std::void_t<decltype(::cuda::std::declval<LhsT>() + ::cuda::std::declval<RhsT>())>>
    : ::cuda::std::true_type
{};

// Shorthand for has_plus_operator<LhsT, RhsT>::value.
template <typename LhsT, typename RhsT>
constexpr bool has_plus_operator_v = has_plus_operator<LhsT, RhsT>::value;
33+
34+
// Returns `iter + offset` when that expression is well-formed for the given types (as reported by
// has_plus_operator_v); otherwise returns `iter` as-is. The iterator is taken by value, so the caller's
// object is never modified.
template <typename IteratorT, typename OffsetT>
CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE IteratorT
advance_iterators_if_supported(IteratorT iter, [[maybe_unused]] OffsetT offset)
{
  if constexpr (!has_plus_operator_v<IteratorT, OffsetT>)
  {
    // No usable operator+: hand the iterator back untouched. `offset` is unused on this path, hence the
    // [[maybe_unused]] on the parameter.
    return iter;
  }
  else
  {
    // operator+ is available: return the advanced iterator.
    return iter + offset;
  }
}
50+
51+
// Detection trait for compound assignment `+=`. This primary template is the fallback that is selected
// whenever the expression `declval<TargetT&>() += declval<AddendT>()` is ill-formed; it reports false.
template <typename TargetT, typename AddendT, typename = void>
struct has_add_assign_operator : ::cuda::std::false_type
{};

// Partial specialization that only participates when an lvalue of TargetT can be the left-hand side of
// `+=` with an AddendT on the right. void_t then yields void, matching the defaulted third argument of
// the primary template, so this specialization is picked and the trait reports true.
template <typename TargetT, typename AddendT>
struct has_add_assign_operator<
  TargetT,
  AddendT,
  ::cuda::std::void_t<decltype(::cuda::std::declval<TargetT&>() += ::cuda::std::declval<AddendT>())>>
    : ::cuda::std::true_type
{};

// Shorthand for has_add_assign_operator<TargetT, AddendT>::value.
template <typename TargetT, typename AddendT>
constexpr bool has_add_assign_operator_v = has_add_assign_operator<TargetT, AddendT>::value;
64+
65+
// Helper function that advances the given iterator *in place* by the given offset, but only if the iterator
// supports `iter += offset` (as reported by has_add_assign_operator_v). If operator+= is not available for
// these types, the call is a no-op and `iter` is left unmodified.
//
// iter:   iterator to advance; taken by reference and mutated on the supported path
// offset: number of positions to advance by; unused when operator+= is not available
template <typename IteratorT, typename OffsetT>
CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE void
advance_iterators_inplace_if_supported(IteratorT& iter, [[maybe_unused]] OffsetT offset)
{
  if constexpr (has_add_assign_operator_v<IteratorT, OffsetT>)
  {
    // If operator+= is valid, advance the caller's iterator in place.
    iter += offset;
  }
  // Otherwise there is nothing to do: iter stays as it was.
}
76+
77+
// Reports whether every type in Iterators... can be combined with an OffsetT through operator+
// (i.e. has_plus_operator_v holds for each of them). The arguments exist purely so the template
// parameters can be deduced at the call site; their values are never read. With no iterators passed,
// the result is true.
template <typename OffsetT, typename... Iterators>
CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE bool
all_iterators_support_plus_operator(OffsetT /*offset*/, Iterators... /*iters*/)
{
  // Fold the per-iterator trait over logical AND; the result is a compile-time constant.
  return (has_plus_operator_v<Iterators, OffsetT> && ...);
}
91+
92+
// Helper function that checks whether all of the given iterators support the += operator with the given
// offset type, i.e. whether has_add_assign_operator_v holds for every type in Iterators... .
// The arguments are only used to deduce the template parameters; their values are never read.
// With no iterators passed, the result is true.
template <typename OffsetT, typename... Iterators>
CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE bool
all_iterators_support_add_assign_operator(OffsetT /*offset*/, Iterators... /*iters*/)
{
  // Fold the per-iterator trait over logical AND; the result is a compile-time constant.
  return (has_add_assign_operator_v<Iterators, OffsetT> && ...);
}
106+
107+
} // namespace detail
108+
109+
CUB_NAMESPACE_END

cub/cub/device/dispatch/dispatch_common.cuh

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -42,50 +42,4 @@ enum class SelectImpl
4242
Partition
4343
};
4444

45-
namespace detail
46-
{
47-
template <typename T, typename U, typename = void>
48-
struct has_plus_operator : ::cuda::std::false_type
49-
{};
50-
51-
template <typename T, typename U>
52-
struct has_plus_operator<T, U, ::cuda::std::void_t<decltype(::cuda::std::declval<T>() + ::cuda::std::declval<U>())>>
53-
: ::cuda::std::true_type
54-
{};
55-
56-
template <typename T, typename U>
57-
constexpr bool has_plus_operator_v = has_plus_operator<T, U>::value;
58-
59-
// Helper function that advances a given iterator only if it supports being advanced by the given offset
60-
template <typename IteratorT, typename OffsetT>
61-
_CCCL_HOST_DEVICE IteratorT advance_iterators_if_supported(IteratorT iter, [[maybe_unused]] OffsetT offset)
62-
{
63-
if constexpr (has_plus_operator_v<IteratorT, OffsetT>)
64-
{
65-
// If operator+ is valid, advance the iterator.
66-
return iter + offset;
67-
}
68-
else
69-
{
70-
// Otherwise, return iter unmodified.
71-
return iter;
72-
}
73-
}
74-
75-
// Helper function that checks whether all of the given iterators support the + operator with the given offset
76-
template <typename OffsetT, typename... Iterators>
77-
_CCCL_HOST_DEVICE bool all_iterators_support_plus_operator(OffsetT /*offset*/, Iterators... /*iters*/)
78-
{
79-
if constexpr ((has_plus_operator_v<Iterators, OffsetT> && ...))
80-
{
81-
return true;
82-
}
83-
else
84-
{
85-
return false;
86-
}
87-
}
88-
89-
} // namespace detail
90-
9145
CUB_NAMESPACE_END

cub/cub/device/dispatch/dispatch_radix_sort.cuh

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
# pragma system_header
4545
#endif // no system header
4646

47+
#include <cub/device/dispatch/dispatch_advance_iterators.cuh>
4748
#include <cub/device/dispatch/kernels/radix_sort.cuh>
4849
#include <cub/device/dispatch/tuning/tuning_radix_sort.cuh>
4950
#include <cub/util_debug.cuh>
@@ -1405,11 +1406,14 @@ struct DispatchSegmentedRadixSort
14051406
// If d_begin_offsets and d_end_offsets do not support operator+= then we can't have more than
14061407
// max_num_segments_per_invocation segments per invocation
14071408
if (num_invocations > 1
1408-
&& !detail::all_iterators_support_plus_operator(::cuda::std::int64_t{}, d_begin_offsets, d_end_offsets))
1409+
&& !detail::all_iterators_support_add_assign_operator(::cuda::std::int64_t{}, d_begin_offsets, d_end_offsets))
14091410
{
14101411
return cudaErrorInvalidValue;
14111412
}
14121413

1414+
BeginOffsetIteratorT begin_offsets_current_it = d_begin_offsets;
1415+
EndOffsetIteratorT end_offsets_current_it = d_end_offsets;
1416+
14131417
// Iterate over chunks of segments
14141418
for (::cuda::std::int64_t invocation_index = 0; invocation_index < num_invocations; invocation_index++)
14151419
{
@@ -1440,8 +1444,8 @@ struct DispatchSegmentedRadixSort
14401444
d_keys_out,
14411445
d_values_in,
14421446
d_values_out,
1443-
detail::advance_iterators_if_supported(d_begin_offsets, current_segment_offset),
1444-
detail::advance_iterators_if_supported(d_end_offsets, current_segment_offset),
1447+
begin_offsets_current_it,
1448+
end_offsets_current_it,
14451449
current_bit,
14461450
pass_bits,
14471451
decomposer);
@@ -1453,6 +1457,12 @@ struct DispatchSegmentedRadixSort
14531457
return error;
14541458
}
14551459

1460+
if (invocation_index + 1 < num_invocations)
1461+
{
1462+
detail::advance_iterators_inplace_if_supported(begin_offsets_current_it, num_current_segments);
1463+
detail::advance_iterators_inplace_if_supported(end_offsets_current_it, num_current_segments);
1464+
}
1465+
14561466
// Sync the stream if specified to flush runtime errors
14571467
error = CubDebug(detail::DebugSyncStream(stream));
14581468
if (cudaSuccess != error)

cub/cub/device/dispatch/dispatch_reduce.cuh

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
#include <cub/detail/launcher/cuda_runtime.cuh>
4848
#include <cub/detail/type_traits.cuh> // for cub::detail::invoke_result_t
49-
#include <cub/device/dispatch/dispatch_common.cuh>
49+
#include <cub/device/dispatch/dispatch_advance_iterators.cuh>
5050
#include <cub/device/dispatch/kernels/reduce.cuh>
5151
#include <cub/device/dispatch/kernels/segmented_reduce.cuh>
5252
#include <cub/device/dispatch/tuning/tuning_reduce.cuh>
@@ -823,7 +823,8 @@ struct DispatchSegmentedReduce
823823
// indirect_arg_t as the iterator type, which does not support the + operator.
824824
// TODO (elstehle): Remove this check once https://github.com/NVIDIA/cccl/issues/4148 is resolved.
825825
if (num_invocations > 1
826-
&& !detail::all_iterators_support_plus_operator(::cuda::std::int64_t{}, d_out, d_begin_offsets, d_end_offsets))
826+
&& !detail::all_iterators_support_add_assign_operator(
827+
::cuda::std::int64_t{}, d_out, d_begin_offsets, d_end_offsets))
827828
{
828829
return cudaErrorInvalidValue;
829830
}
@@ -848,13 +849,7 @@ struct DispatchSegmentedReduce
848849
// Invoke DeviceReduceKernel
849850
launcher_factory(
850851
static_cast<::cuda::std::uint32_t>(num_current_segments), policy.SegmentedReduce().BlockThreads(), 0, stream)
851-
.doit(segmented_reduce_kernel,
852-
d_in,
853-
detail::advance_iterators_if_supported(d_out, current_seg_offset),
854-
detail::advance_iterators_if_supported(d_begin_offsets, current_seg_offset),
855-
detail::advance_iterators_if_supported(d_end_offsets, current_seg_offset),
856-
reduction_op,
857-
init);
852+
.doit(segmented_reduce_kernel, d_in, d_out, d_begin_offsets, d_end_offsets, reduction_op, init);
858853

859854
// Check for failure to launch
860855
error = CubDebug(cudaPeekAtLastError());
@@ -863,6 +858,13 @@ struct DispatchSegmentedReduce
863858
break;
864859
}
865860

861+
if (invocation_index + 1 < num_invocations)
862+
{
863+
detail::advance_iterators_inplace_if_supported(d_out, num_current_segments);
864+
detail::advance_iterators_inplace_if_supported(d_begin_offsets, num_current_segments);
865+
detail::advance_iterators_inplace_if_supported(d_end_offsets, num_current_segments);
866+
}
867+
866868
// Sync the stream if specified to flush runtime errors
867869
error = CubDebug(detail::DebugSyncStream(stream));
868870
if (cudaSuccess != error)

0 commit comments

Comments
 (0)