Commit 783921d
[Perf] Optimize Vectorization Utils for Int 8 Quantization Kernels (#20331)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 4a98edf commit 783921d

2 files changed: +106 -7 lines changed

csrc/quantization/compressed_tensors/int8_quant_kernels.cu

Lines changed: 9 additions & 7 deletions
@@ -162,10 +162,11 @@ __global__ void dynamic_scaled_int8_quant_kernel(
 
   // calculate for absmax
   float thread_max = 0.f;
-  for (int i = tid; i < hidden_size; i += stride) {
-    const auto v = fabsf(static_cast<float>(row_in[i]));
-    thread_max = fmaxf(thread_max, v);
-  }
+  vectorize_read_with_alignment<16>(
+      row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
+        const float v = fabsf(static_cast<float>(src));
+        thread_max = fmaxf(thread_max, v);
+      });
   using BlockReduce = cub::BlockReduce<float, 256>;
   __shared__ typename BlockReduce::TempStorage tmp;
   float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);

@@ -232,9 +233,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
 
   // 1. calculate min & max
   MinMax thread_mm;
-  for (int i = tid; i < hidden_size; i += stride) {
-    thread_mm += static_cast<float>(row_in[i]);
-  }
+  vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
+                                    [&] __device__(const scalar_t& src) {
+                                      thread_mm += static_cast<float>(src);
+                                    });
 
   using BlockReduce = cub::BlockReduce<MinMax, 256>;
   __shared__ typename BlockReduce::TempStorage tmp;
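
For reference, the caller-side pattern introduced here keeps each thread's strided-loop semantics and only moves the loop body into a per-element callback. Below is a minimal, hypothetical standalone sketch of the same absmax reduction; the kernel name, include path, and launch shape are illustrative, and it assumes the vectorize_read_with_alignment helper added in vectorization_utils.cuh below, CUB, and nvcc with --extended-lambda:

// Hypothetical example: per-row absmax using the new read-only helper.
// Mirrors the dynamic_scaled_int8_quant_kernel change above, minus the
// actual quantization step.
#include <cub/cub.cuh>
#include "quantization/vectorization_utils.cuh"  // illustrative include path

template <typename scalar_t>
__global__ void row_absmax_kernel(const scalar_t* __restrict__ input,
                                  float* __restrict__ row_max,
                                  int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const scalar_t* row_in =
      input + static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Each thread accumulates a local max; the helper issues vectorized loads
  // where alignment allows and falls back to scalar access otherwise.
  float thread_max = 0.f;
  vllm::vectorize_read_with_alignment<16>(
      row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
        thread_max = fmaxf(thread_max, fabsf(static_cast<float>(src)));
      });

  // Reduce the per-thread maxima to one value per block (one row per block).
  using BlockReduce = cub::BlockReduce<float, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;
  float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
  if (tid == 0) row_max[blockIdx.x] = block_max;
}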

csrc/quantization/vectorization_utils.cuh

Lines changed: 97 additions & 0 deletions
@@ -27,6 +27,26 @@ __device__ inline void vectorize_with_alignment(
   constexpr int WIDTH = VEC_SIZE * sizeof(InT);  // eg: 64 B
   uintptr_t addr = reinterpret_cast<uintptr_t>(in);
 
+  // fast path when the whole region is already aligned
+  // Note: currently the output is guaranteed to be same as the input, so we
+  // don't check it here, comments here just for future reference.
+  bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
+  if (can_vec) {
+    int num_vec = len / VEC_SIZE;
+
+    using vin_t = vec_n_t<InT, VEC_SIZE>;
+    using vout_t = vec_n_t<OutT, VEC_SIZE>;
+    auto* v_in = reinterpret_cast<const vin_t*>(in);
+    auto* v_out = reinterpret_cast<vout_t*>(out);
+
+    for (int i = tid; i < num_vec; i += stride) {
+      vout_t tmp;
+      vec_op(tmp, v_in[i]);
+      v_out[i] = tmp;
+    }
+    return;
+  }
+
   int misalignment_offset = addr & (WIDTH - 1);       // addr % 64
   int alignment_bytes = WIDTH - misalignment_offset;  // 64 - (addr % 64)
   int prefix_elems = alignment_bytes & (WIDTH - 1);   // handle 64

@@ -72,4 +92,81 @@ __device__ __forceinline__ void vectorize_with_alignment(const InT* in,
                                      std::forward<ScaOp>(scalar_op));
 }
 
+template <int VEC_SIZE, typename InT, typename ScaOp>
+struct DefaultReadVecOp {
+  ScaOp scalar_op;
+
+  __device__ __forceinline__ void operator()(
+      const vec_n_t<InT, VEC_SIZE>& src) const {
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      scalar_op(src.val[i]);
+    }
+  }
+};
+
+// read-only version: iterate over the input with alignment guarantees
+template <int VEC_SIZE, typename InT, typename VecOp, typename ScaOp>
+__device__ inline void vectorize_read_with_alignment(const InT* in, int len,
+                                                      int tid, int stride,
+                                                      VecOp&& vec_op,
+                                                      ScaOp&& scalar_op) {
+  static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
+                "VEC_SIZE must be a positive power-of-two");
+  constexpr int WIDTH = VEC_SIZE * sizeof(InT);
+  uintptr_t addr = reinterpret_cast<uintptr_t>(in);
+
+  // fast path when the whole region is already aligned
+  bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
+  if (can_vec) {
+    int num_vec = len / VEC_SIZE;
+
+    using vin_t = vec_n_t<InT, VEC_SIZE>;
+    auto* v_in = reinterpret_cast<const vin_t*>(in);
+
+    for (int i = tid; i < num_vec; i += stride) {
+      vec_op(v_in[i]);
+    }
+    return;
+  }
+
+  int misalignment_offset = addr & (WIDTH - 1);
+  int alignment_bytes = WIDTH - misalignment_offset;
+  int prefix_elems = alignment_bytes & (WIDTH - 1);
+  prefix_elems /= sizeof(InT);
+  prefix_elems = min(prefix_elems, len);
+
+  // 1. handle the possibly unaligned prefix with scalar access.
+  for (int i = tid; i < prefix_elems; i += stride) {
+    scalar_op(in[i]);
+  }
+
+  in += prefix_elems;
+  len -= prefix_elems;
+
+  int num_vec = len / VEC_SIZE;
+  using vin_t = vec_n_t<InT, VEC_SIZE>;
+  auto* v_in = reinterpret_cast<const vin_t*>(in);
+
+  // 2. vectorized traversal of the main aligned region.
+  for (int i = tid; i < num_vec; i += stride) {
+    vec_op(v_in[i]);
+  }
+
+  // 3. handle remaining tail elements.
+  int tail_start = num_vec * VEC_SIZE;
+  for (int i = tid + tail_start; i < len; i += stride) {
+    scalar_op(in[i]);
+  }
+}
+
+// overload that requires only a scalar_op
+template <int VEC_SIZE, typename InT, typename ScaOp>
+__device__ __forceinline__ void vectorize_read_with_alignment(
+    const InT* in, int len, int tid, int stride, ScaOp&& scalar_op) {
+  using Vec = DefaultReadVecOp<VEC_SIZE, InT, std::decay_t<ScaOp>>;
+  vectorize_read_with_alignment<VEC_SIZE>(in, len, tid, stride, Vec{scalar_op},
+                                          std::forward<ScaOp>(scalar_op));
+}
+
 }  // namespace vllm
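
To make the unaligned path's prefix/vector/tail split concrete, here is a small host-side sketch of the same index arithmetic with made-up inputs (int8 elements and VEC_SIZE = 16, so WIDTH = 16 bytes); the pointer value and length are illustrative and not taken from the commit:

// Host-side illustration of the alignment arithmetic used above.
// Assumes VEC_SIZE = 16 and int8 elements, so WIDTH = 16 bytes.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int VEC_SIZE = 16;
  constexpr int WIDTH = VEC_SIZE * 1;  // VEC_SIZE * sizeof(int8_t)

  std::uintptr_t addr = 0x1004;  // hypothetical pointer, 4 B past a 16 B boundary
  int len = 100;                 // hypothetical element count

  // Fast path only when both the address and the length line up.
  bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
  std::printf("fast path: %d\n", can_vec);  // 0 -> take the prefix/body/tail path

  int misalignment_offset = addr & (WIDTH - 1);       // 4
  int alignment_bytes = WIDTH - misalignment_offset;  // 12
  // Division by sizeof(int8_t) == 1 is omitted here.
  int prefix_elems = std::min(alignment_bytes & (WIDTH - 1), len);  // 12 scalars
  int body = len - prefix_elems;                      // 88 elements remain
  int num_vec = body / VEC_SIZE;                      // 5 full 16-element vectors
  int tail = body - num_vec * VEC_SIZE;               // 8 scalar tail elements
  std::printf("prefix=%d vectors=%d tail=%d\n", prefix_elems, num_vec, tail);
  return 0;  // 12 + 5*16 + 8 == 100, every element visited exactly once
}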
