@@ -50,13 +50,11 @@ __global__ void quant_per_token_per_block(const T *input,
50
50
max_value_thread = max (abs (load_vec_float[vid]), max_value_thread);
51
51
}
52
52
// get max value per warp
53
- max_value_thread = max (__shfl_down_sync (0xffffffff , max_value_thread, 16 ), max_value_thread);
54
- max_value_thread = max (__shfl_down_sync (0xffffffff , max_value_thread, 8 ), max_value_thread);
55
- max_value_thread = max (__shfl_down_sync (0xffffffff , max_value_thread, 4 ), max_value_thread);
56
- max_value_thread = max (__shfl_down_sync (0xffffffff , max_value_thread, 2 ), max_value_thread);
57
- max_value_thread = max (__shfl_down_sync (0xffffffff , max_value_thread, 1 ), max_value_thread);
58
- // broadcast max_value
59
- max_value_thread = __shfl_sync (0xFFFFFFFF , max_value_thread, 0 );
53
+ max_value_thread = max (__shfl_xor_sync (0xffffffff , max_value_thread, 16 ), max_value_thread);
54
+ max_value_thread = max (__shfl_xor_sync (0xffffffff , max_value_thread, 8 ), max_value_thread);
55
+ max_value_thread = max (__shfl_xor_sync (0xffffffff , max_value_thread, 4 ), max_value_thread);
56
+ max_value_thread = max (__shfl_xor_sync (0xffffffff , max_value_thread, 2 ), max_value_thread);
57
+ max_value_thread = max (__shfl_xor_sync (0xffffffff , max_value_thread, 1 ), max_value_thread);
60
58
max_value_thread = max (max_value_thread, epsilon);
61
59
float scale_to_store = max_value_thread / MAX_VALUE;
62
60
// quant
0 commit comments