@@ -178,21 +178,25 @@ class block_reduce_warp_reduce
178
178
input, output, num_valid, reduce_op
179
179
);
180
180
181
- // i-th warp will have its partial stored in storage_.warp_partials[i-1]
182
- if (lane_id == 0 )
181
+ // Final reduction across warps is only required if there is more than 1 warp
182
+ if ROCPRIM_IF_CONSTEXPR (warps_no_ > 1 )
183
183
{
184
- storage_.warp_partials [warp_id] = output;
185
- }
186
- ::rocprim::syncthreads ();
187
-
188
- if (flat_tid < warps_no_)
189
- {
190
- // Use warp partial to calculate the final reduce results for every thread
191
- auto warp_partial = storage_.warp_partials [lane_id];
192
-
193
- warp_reduce<!warps_no_is_pow_of_two_, warp_reduce_output_type>(
194
- warp_partial, output, warps_no_, reduce_op
195
- );
184
+ // i-th warp will have its partial stored in storage_.warp_partials[i-1]
185
+ if (lane_id == 0 )
186
+ {
187
+ storage_.warp_partials [warp_id] = output;
188
+ }
189
+ ::rocprim::syncthreads ();
190
+
191
+ if (flat_tid < warps_no_)
192
+ {
193
+ // Use warp partial to calculate the final reduce results for every thread
194
+ auto warp_partial = storage_.warp_partials [lane_id];
195
+
196
+ warp_reduce<!warps_no_is_pow_of_two_, warp_reduce_output_type>(
197
+ warp_partial, output, warps_no_, reduce_op
198
+ );
199
+ }
196
200
}
197
201
}
198
202
@@ -244,22 +248,26 @@ class block_reduce_warp_reduce
244
248
input, output, num_valid, reduce_op
245
249
);
246
250
247
- // i-th warp will have its partial stored in storage_.warp_partials[i-1]
248
- if (lane_id == 0 )
251
+ // Final reduction across warps is only required if there is more than 1 warp
252
+ if ROCPRIM_IF_CONSTEXPR (warps_no_ > 1 )
249
253
{
250
- storage_.warp_partials [warp_id] = output;
251
- }
252
- ::rocprim::syncthreads ();
253
-
254
- if (flat_tid < warps_no_)
255
- {
256
- // Use warp partial to calculate the final reduce results for every thread
257
- auto warp_partial = storage_.warp_partials [lane_id];
258
-
259
- unsigned int valid_warps_no = (valid_items + warp_size_ - 1 ) / warp_size_;
260
- warp_reduce_output_type ().reduce (
261
- warp_partial, output, valid_warps_no, reduce_op
262
- );
254
+ // i-th warp will have its partial stored in storage_.warp_partials[i-1]
255
+ if (lane_id == 0 )
256
+ {
257
+ storage_.warp_partials [warp_id] = output;
258
+ }
259
+ ::rocprim::syncthreads ();
260
+
261
+ if (flat_tid < warps_no_)
262
+ {
263
+ // Use warp partial to calculate the final reduce results for every thread
264
+ auto warp_partial = storage_.warp_partials [lane_id];
265
+
266
+ unsigned int valid_warps_no = (valid_items + warp_size_ - 1 ) / warp_size_;
267
+ warp_reduce_output_type ().reduce (
268
+ warp_partial, output, valid_warps_no, reduce_op
269
+ );
270
+ }
263
271
}
264
272
}
265
273
};
0 commit comments