Skip to content

Commit 93501cf

Browse files
Cherry-pick: Optimize block_reduce_warp_reduce when block size is the same as warp size (#599)
* Optimize block_reduce_warp_reduce when block size == warp size
* Make conditional constexpr
1 parent eab1eed commit 93501cf

File tree

2 files changed

+43
-30
lines changed

2 files changed

+43
-30
lines changed

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
Documentation for rocPRIM is available at
44
[https://rocm.docs.amd.com/projects/rocPRIM/en/latest/](https://rocm.docs.amd.com/projects/rocPRIM/en/latest/).
55

6-
## Unreleased rocPRIM-3.2.0 for ROCm 6.2.0
6+
## rocPRIM-3.2.1 for ROCm 6.2.1
7+
8+
### Optimizations
9+
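* Improved performance of `block_reduce_warp_reduce` when the block size equals the warp size: with only one warp per block, the final cross-warp reduction (including its shared-memory write and block-wide synchronization) is skipped at compile time.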
10+
11+
## rocPRIM-3.2.0 for ROCm 6.2.0
712

813
### Additions
914

rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -178,21 +178,25 @@ class block_reduce_warp_reduce
178178
input, output, num_valid, reduce_op
179179
);
180180

181-
// i-th warp will have its partial stored in storage_.warp_partials[i-1]
182-
if(lane_id == 0)
181+
// Final reduction across warps is only required if there is more than 1 warp
182+
if ROCPRIM_IF_CONSTEXPR (warps_no_ > 1)
183183
{
184-
storage_.warp_partials[warp_id] = output;
185-
}
186-
::rocprim::syncthreads();
187-
188-
if(flat_tid < warps_no_)
189-
{
190-
// Use warp partial to calculate the final reduce results for every thread
191-
auto warp_partial = storage_.warp_partials[lane_id];
192-
193-
warp_reduce<!warps_no_is_pow_of_two_, warp_reduce_output_type>(
194-
warp_partial, output, warps_no_, reduce_op
195-
);
184+
// i-th warp will have its partial stored in storage_.warp_partials[i-1]
185+
if(lane_id == 0)
186+
{
187+
storage_.warp_partials[warp_id] = output;
188+
}
189+
::rocprim::syncthreads();
190+
191+
if(flat_tid < warps_no_)
192+
{
193+
// Use warp partial to calculate the final reduce results for every thread
194+
auto warp_partial = storage_.warp_partials[lane_id];
195+
196+
warp_reduce<!warps_no_is_pow_of_two_, warp_reduce_output_type>(
197+
warp_partial, output, warps_no_, reduce_op
198+
);
199+
}
196200
}
197201
}
198202

@@ -244,22 +248,26 @@ class block_reduce_warp_reduce
244248
input, output, num_valid, reduce_op
245249
);
246250

247-
// i-th warp will have its partial stored in storage_.warp_partials[i-1]
248-
if(lane_id == 0)
251+
// Final reduction across warps is only required if there is more than 1 warp
252+
if ROCPRIM_IF_CONSTEXPR (warps_no_ > 1)
249253
{
250-
storage_.warp_partials[warp_id] = output;
251-
}
252-
::rocprim::syncthreads();
253-
254-
if(flat_tid < warps_no_)
255-
{
256-
// Use warp partial to calculate the final reduce results for every thread
257-
auto warp_partial = storage_.warp_partials[lane_id];
258-
259-
unsigned int valid_warps_no = (valid_items + warp_size_ - 1) / warp_size_;
260-
warp_reduce_output_type().reduce(
261-
warp_partial, output, valid_warps_no, reduce_op
262-
);
254+
// i-th warp will have its partial stored in storage_.warp_partials[i-1]
255+
if(lane_id == 0)
256+
{
257+
storage_.warp_partials[warp_id] = output;
258+
}
259+
::rocprim::syncthreads();
260+
261+
if(flat_tid < warps_no_)
262+
{
263+
// Use warp partial to calculate the final reduce results for every thread
264+
auto warp_partial = storage_.warp_partials[lane_id];
265+
266+
unsigned int valid_warps_no = (valid_items + warp_size_ - 1) / warp_size_;
267+
warp_reduce_output_type().reduce(
268+
warp_partial, output, valid_warps_no, reduce_op
269+
);
270+
}
263271
}
264272
}
265273
};

0 commit comments

Comments
 (0)