Skip to content

Commit ac5f860

Browse files
committed
use shfl_xor_sync to reduce redundant shfl broadcast
1 parent 90a5b18 commit ac5f860

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

custom_ops/gpu_ops/per_token_quant_fp8.cu

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,11 @@ __global__ void quant_per_token_per_block(const T *input,
5050
max_value_thread = max(abs(load_vec_float[vid]), max_value_thread);
5151
}
5252
// get max value per warp
53-
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), max_value_thread);
54-
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), max_value_thread);
55-
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), max_value_thread);
56-
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), max_value_thread);
57-
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), max_value_thread);
58-
// broadcast max_value
59-
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
53+
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 16), max_value_thread);
54+
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 8), max_value_thread);
55+
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 4), max_value_thread);
56+
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 2), max_value_thread);
57+
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 1), max_value_thread);
6058
max_value_thread = max(max_value_thread, epsilon);
6159
float scale_to_store = max_value_thread / MAX_VALUE;
6260
// quant

0 commit comments

Comments
 (0)