Skip to content

Commit c7753a9

Browse files
[Hardware][CPU] Vllm int8 quantization enablement for ARM CPU (#14129)
Signed-off-by: nishith-fujitsu <nishith.jaiswal@fujitsu.com>
1 parent 4b9a943 commit c7753a9

File tree

5 files changed

+347
-30
lines changed

5 files changed

+347
-30
lines changed

cmake/cpu_extension.cmake

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,32 @@ else()
165165
endif()
166166

167167
#
168-
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms)
169-
#
170-
if (AVX512_FOUND AND NOT AVX512_DISABLED)
168+
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 / ARM platforms).
#
# VLLM_BUILD_ACL enables the Arm Compute Library (ACL) kernels on AArch64.
# Any CMake true constant (ON, 1, TRUE, YES, Y) enables it; the normalized
# result is stored in USE_ACL for the checks below.
if(VLLM_BUILD_ACL)
  set(USE_ACL ON)
else()
  set(USE_ACL OFF)
endif()
175+
176+
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
171177
FetchContent_Declare(
172178
oneDNN
173179
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
174-
GIT_TAG v3.7.1
180+
GIT_TAG v3.8.1
175181
GIT_PROGRESS TRUE
176182
GIT_SHALLOW TRUE
177183
)
178184

185+
if(USE_ACL)
  # Look for the Arm Compute Library under $ACL_ROOT_DIR/build. The
  # environment variable is read once, at configure time.
  find_library(ARM_COMPUTE_LIBRARY
    NAMES arm_compute
    PATHS "$ENV{ACL_ROOT_DIR}/build/")
  if(NOT ARM_COMPUTE_LIBRARY)
    message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR")
  endif()

  # Tell the oneDNN build fetched above to use ACL for its AArch64 kernels.
  set(ONEDNN_AARCH64_USE_ACL "ON")

  # Embed the directory where the library was actually found as a runtime
  # search path. Deriving it from the find_library() result (instead of from
  # the environment variable) keeps the rpath correct when the library was
  # located through a default search path or a cached ARM_COMPUTE_LIBRARY.
  get_filename_component(ARM_COMPUTE_LIBRARY_DIR "${ARM_COMPUTE_LIBRARY}" DIRECTORY)
  string(APPEND CMAKE_CXX_FLAGS " -Wl,-rpath,${ARM_COMPUTE_LIBRARY_DIR}")
endif()
193+
179194
set(ONEDNN_LIBRARY_TYPE "STATIC")
180195
set(ONEDNN_BUILD_DOC "OFF")
181196
set(ONEDNN_BUILD_EXAMPLES "OFF")
@@ -264,6 +279,11 @@ elseif(POWER10_FOUND)
264279
"csrc/cpu/quant.cpp"
265280
${VLLM_EXT_SRC})
266281
endif()
282+
# On ASIMD (AArch64 NEON) builds, put the quantization kernels at the front
# of the CPU extension source list.
if(ASIMD_FOUND)
  set(VLLM_EXT_SRC "csrc/cpu/quant.cpp" ${VLLM_EXT_SRC})
endif()
267287

268288
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
269289

csrc/cpu/cpu_types_arm.hpp

Lines changed: 264 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ namespace vec_op {
3333
#endif
3434

3535
#define FORCE_INLINE __attribute__((always_inline)) inline
36+
// Number of elements in a single ASIMD vector (or array) `vec`.
// The count is cast to int so that expressions mixing it with a signed
// element count, e.g. `elem_num / NUM_ELEMENTS_REG(v)`, stay in signed
// arithmetic instead of being promoted to size_t.
#define NUM_ELEMENTS_REG(vec) (static_cast<int>(sizeof(vec) / sizeof(vec[0])))
3638

3739
namespace {
3840
template <typename T, T... indexes, typename F>
@@ -86,8 +88,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
8688
}
8789

8890
void save(void* ptr, const int elem_num) const {
89-
int full_blocks = elem_num / 8;
90-
int remainder = elem_num % 8;
91+
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
92+
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
9193

9294
if (full_blocks > 0) {
9395
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
@@ -197,6 +199,25 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
197199
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
198200

199201
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
202+
// Stores the first `elem_num` bf16 elements of this vector to `ptr`.
// `elem_num` is clamped to [0, number of elements held in `reg`], so an
// out-of-range count can never index past `reg.val`.
void save(void* ptr, const int elem_num) const {
  constexpr int kLanes =
      static_cast<int>(sizeof(reg.val[0]) / sizeof(reg.val[0][0]));
  constexpr int kBlocks =
      static_cast<int>(sizeof(reg.val) / sizeof(reg.val[0]));
  constexpr int kCapacity = kLanes * kBlocks;

  int count = elem_num;
  if (count < 0) count = 0;
  if (count > kCapacity) count = kCapacity;

  const int full_blocks = count / kLanes;
  const int remainder = count % kLanes;
  bfloat16_t* dst = reinterpret_cast<bfloat16_t*>(ptr);

  // Whole 128-bit registers are written with a single vector store each.
  for (int i = 0; i < full_blocks; i++)
    vst1q_bf16(dst + kLanes * i, reg.val[i]);

  if (remainder > 0) {
    // Spill the partially used register to the stack and copy only the
    // valid lanes, so nothing past dst[count - 1] is written.
    bfloat16_t tail[kLanes];
    vst1q_bf16(tail, reg.val[full_blocks]);
    bfloat16_t* base = dst + full_blocks * kLanes;
    for (int i = 0; i < remainder; i++) base[i] = tail[i];
  }
};
200221
};
201222

202223
struct BF16Vec32 : public Vec<BF16Vec32> {
@@ -213,6 +234,25 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
213234
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
214235

215236
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
237+
// Stores the first `elem_num` bf16 elements of this vector to `ptr`.
// `elem_num` is clamped to [0, number of elements held in `reg`], so an
// out-of-range count can never index past `reg.val`.
void save(void* ptr, const int elem_num) const {
  constexpr int kLanes =
      static_cast<int>(sizeof(reg.val[0]) / sizeof(reg.val[0][0]));
  constexpr int kBlocks =
      static_cast<int>(sizeof(reg.val) / sizeof(reg.val[0]));
  constexpr int kCapacity = kLanes * kBlocks;

  int count = elem_num;
  if (count < 0) count = 0;
  if (count > kCapacity) count = kCapacity;

  const int full_blocks = count / kLanes;
  const int remainder = count % kLanes;
  bfloat16_t* dst = reinterpret_cast<bfloat16_t*>(ptr);

  // Whole 128-bit registers are written with a single vector store each.
  for (int i = 0; i < full_blocks; i++)
    vst1q_bf16(dst + kLanes * i, reg.val[i]);

  if (remainder > 0) {
    // Spill the partially used register to the stack and copy only the
    // valid lanes, so nothing past dst[count - 1] is written.
    bfloat16_t tail[kLanes];
    vst1q_bf16(tail, reg.val[full_blocks]);
    bfloat16_t* base = dst + full_blocks * kLanes;
    for (int i = 0; i < remainder; i++) base[i] = tail[i];
  }
};
216256
};
217257
#endif
218258

@@ -372,6 +412,48 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
372412
}
373413
};
374414

415+
// Sixteen int32 values held in four 128-bit ASIMD registers.
struct INT32Vec16 : public Vec<INT32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  // Lets the four registers be viewed as a flat array of 16 scalars.
  union AliasReg {
    int32x4x4_t reg;
    int32_t values[VEC_ELEM_NUM];
  };
  int32x4x4_t reg;

  // Loads 16 consecutive int32 values starting at `ptr`.
  explicit INT32Vec16(const void* ptr) {
    const int32_t* src = reinterpret_cast<const int32_t*>(ptr);
    reg.val[0] = vld1q_s32(src);
    reg.val[1] = vld1q_s32(src + 4);
    reg.val[2] = vld1q_s32(src + 8);
    reg.val[3] = vld1q_s32(src + 12);
  }

  // Stores all 16 values to `ptr`.
  void save(int32_t* ptr) const {
    vst1q_s32(ptr, reg.val[0]);
    vst1q_s32(ptr + 4, reg.val[1]);
    vst1q_s32(ptr + 8, reg.val[2]);
    vst1q_s32(ptr + 12, reg.val[3]);
  };

  // Stores only the first `elem_num` values to `ptr`. `elem_num` is clamped
  // to [0, VEC_ELEM_NUM] so `reg.val` is never indexed out of bounds.
  void save(int32_t* ptr, const int elem_num) const {
    constexpr int kLanes =
        static_cast<int>(sizeof(reg.val[0]) / sizeof(reg.val[0][0]));

    int count = elem_num;
    if (count < 0) count = 0;
    if (count > VEC_ELEM_NUM) count = VEC_ELEM_NUM;

    const int full_blocks = count / kLanes;
    const int remainder = count % kLanes;

    // Whole registers are written with one vector store each.
    for (int i = 0; i < full_blocks; i++)
      vst1q_s32(ptr + kLanes * i, reg.val[i]);

    if (remainder > 0) {
      // Spill the partially used register and copy only its valid lanes.
      int32_t tail[kLanes];
      vst1q_s32(tail, reg.val[full_blocks]);
      int32_t* base = ptr + full_blocks * kLanes;
      for (int i = 0; i < remainder; i++) base[i] = tail[i];
    }
  }
};
456+
375457
struct FP32Vec16 : public Vec<FP32Vec16> {
376458
constexpr static int VEC_ELEM_NUM = 16;
377459
union AliasReg {
@@ -434,7 +516,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
434516
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
435517
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
436518
};
437-
519+
// Builds a float vector by converting each int32 lane of `v` to float32,
// one 128-bit register at a time.
explicit FP32Vec16(const INT32Vec16& v) {
  for (int blk = 0; blk < 4; ++blk) {
    reg.val[blk] = vcvtq_f32_s32(v.reg.val[blk]);
  }
};
438525
FP32Vec16 operator+(const FP32Vec16& b) const {
439526
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
440527
vaddq_f32(reg.val[1], b.reg.val[1]),
@@ -463,6 +550,85 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
463550
vdivq_f32(reg.val[3], b.reg.val[3])}));
464551
};
465552

553+
// Lane-wise clamp: each lane is first raised to at least the matching lane
// of `min`, then lowered to at most the matching lane of `max`.
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
  float32x4x4_t bounded;
  for (int blk = 0; blk < 4; ++blk) {
    const float32x4_t raised = vmaxq_f32(min.reg.val[blk], reg.val[blk]);
    bounded.val[blk] = vminq_f32(max.reg.val[blk], raised);
  }
  return FP32Vec16(bounded);
};
560+
561+
// Lane-wise maximum of this vector and `b` over all four registers.
FP32Vec16 max(const FP32Vec16& b) const {
  float32x4x4_t larger;
  for (int blk = 0; blk < 4; ++blk) {
    larger.val[blk] = vmaxq_f32(b.reg.val[blk], reg.val[blk]);
  }
  return FP32Vec16(larger);
};
567+
568+
// Lane-wise maximum of this vector and `b`, applied to the first `elem_num`
// lanes only. Lanes at index >= elem_num keep this vector's value.
// `elem_num` is clamped to [0, number of lanes held in `reg`].
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
  constexpr int kLanes =
      static_cast<int>(sizeof(reg.val[0]) / sizeof(reg.val[0][0]));
  constexpr int kBlocks =
      static_cast<int>(sizeof(reg.val) / sizeof(reg.val[0]));

  int count = elem_num;
  if (count < 0) count = 0;
  if (count > kLanes * kBlocks) count = kLanes * kBlocks;

  const int full_blocks = count / kLanes;
  const int remainder = count % kLanes;

  // Start from a copy of this vector so every lane of the result is
  // initialised before any lane is read or replaced.
  float32x4x4_t temp = reg;

  for (int i = 0; i < full_blocks; i++)
    temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]);

  // Lane indices of the get/set intrinsics must be compile-time constants,
  // hence one explicit step per lane of the partially used register.
  if (remainder > 0) {
    float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
                           vgetq_lane_f32(b.reg.val[full_blocks], 0));
    temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0);
  }
  if (remainder > 1) {
    float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
                           vgetq_lane_f32(b.reg.val[full_blocks], 1));
    temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1);
  }
  if (remainder > 2) {
    float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
                           vgetq_lane_f32(b.reg.val[full_blocks], 2));
    temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 2);
  }
  return FP32Vec16(temp);
};
593+
594+
// Lane-wise minimum of this vector and `b` over all four registers.
FP32Vec16 min(const FP32Vec16& b) const {
  float32x4x4_t smaller;
  for (int blk = 0; blk < 4; ++blk) {
    smaller.val[blk] = vminq_f32(b.reg.val[blk], reg.val[blk]);
  }
  return FP32Vec16(smaller);
};
602+
// Lane-wise minimum of this vector and `b`, applied to the first `elem_num`
// lanes only. Lanes at index >= elem_num keep this vector's value.
// `elem_num` is clamped to [0, number of lanes held in `reg`].
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
  constexpr int kLanes =
      static_cast<int>(sizeof(reg.val[0]) / sizeof(reg.val[0][0]));
  constexpr int kBlocks =
      static_cast<int>(sizeof(reg.val) / sizeof(reg.val[0]));

  int count = elem_num;
  if (count < 0) count = 0;
  if (count > kLanes * kBlocks) count = kLanes * kBlocks;

  const int full_blocks = count / kLanes;
  const int remainder = count % kLanes;

  // Start from a copy of this vector so every lane of the result is
  // initialised before any lane is read or replaced.
  float32x4x4_t temp = reg;

  for (int i = 0; i < full_blocks; i++)
    temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]);

  // Lane indices of the get/set intrinsics must be compile-time constants,
  // hence one explicit step per lane of the partially used register.
  if (remainder > 0) {
    float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
                           vgetq_lane_f32(b.reg.val[full_blocks], 0));
    temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0);
  }
  if (remainder > 1) {
    float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
                           vgetq_lane_f32(b.reg.val[full_blocks], 1));
    temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1);
  }
  if (remainder > 2) {
    float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
                           vgetq_lane_f32(b.reg.val[full_blocks], 2));
    temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2);
  }

  return FP32Vec16(temp);
};
627+
// Lane-wise absolute value over all four registers.
FP32Vec16 abs() const {
  float32x4x4_t magnitude;
  for (int blk = 0; blk < 4; ++blk) {
    magnitude.val[blk] = vabsq_f32(reg.val[blk]);
  }
  return FP32Vec16(magnitude);
}
466632
float reduce_sum() const {
467633
AliasReg ar;
468634
ar.reg = reg;
@@ -473,6 +639,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
473639
return answer;
474640
};
475641

642+
// Returns the largest of the VEC_ELEM_NUM lanes. The registers are viewed
// as a scalar array through AliasReg and folded with std::max, starting
// from the lowest finite float.
float reduce_max() const {
  AliasReg lanes;
  lanes.reg = reg;
  float best = std::numeric_limits<float>::lowest();
  unroll_loop<int, VEC_ELEM_NUM>([&](int idx) {
    best = std::max(best, lanes.values[idx]);
  });
  return best;
}
650+
651+
// Returns the smallest of the VEC_ELEM_NUM lanes. The registers are viewed
// as a scalar array through AliasReg and folded with std::min, starting
// from the largest finite float.
float reduce_min() const {
  AliasReg lanes;
  lanes.reg = reg;
  float best = std::numeric_limits<float>::max();
  unroll_loop<int, VEC_ELEM_NUM>([&](int idx) {
    best = std::min(best, lanes.values[idx]);
  });
  return best;
}
659+
476660
template <int group_size>
477661
float reduce_sub_sum(int idx) {
478662
static_assert(VEC_ELEM_NUM % group_size == 0);
@@ -493,6 +677,83 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
493677
vst1q_f32(ptr + 8, reg.val[2]);
494678
vst1q_f32(ptr + 12, reg.val[3]);
495679
};
680+
681+
// Stores only the first `elem_num` floats of this vector to `ptr`.
// `elem_num` is clamped to [0, number of lanes held in `reg`], so an
// out-of-range count can never index past `reg.val`.
void save(float* ptr, const int elem_num) const {
  constexpr int kLanes =
      static_cast<int>(sizeof(reg.val[0]) / sizeof(reg.val[0][0]));
  constexpr int kBlocks =
      static_cast<int>(sizeof(reg.val) / sizeof(reg.val[0]));

  int count = elem_num;
  if (count < 0) count = 0;
  if (count > kLanes * kBlocks) count = kLanes * kBlocks;

  const int full_blocks = count / kLanes;
  const int remainder = count % kLanes;

  // Whole registers are written with one vector store each.
  for (int i = 0; i < full_blocks; i++)
    vst1q_f32(ptr + kLanes * i, reg.val[i]);

  if (remainder > 0) {
    // Spill the partially used register and copy only its valid lanes.
    float tail[kLanes];
    vst1q_f32(tail, reg.val[full_blocks]);
    float* base = ptr + full_blocks * kLanes;
    for (int i = 0; i < remainder; i++) base[i] = tail[i];
  }
}
699+
};
700+
701+
// Sixteen int8 values held in a single 128-bit ASIMD register.
struct INT8Vec16 : public Vec<INT8Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  // Lets the register be viewed as a flat array of 16 scalars.
  union AliasReg {
    int8x16_t reg;
    int8_t values[VEC_ELEM_NUM];
  };
  int8x16_t reg;

  // Converts 16 floats to int8: float32 -> int32, then two saturating
  // narrowing steps (32 -> 16 -> 8 bits).
  // NOTE(review): vcvtq_s32_f32 rounds toward zero; confirm that truncation
  // (rather than round-to-nearest) is what the quantization kernels expect.
  explicit INT8Vec16(const FP32Vec16& vec) {
    // Convert each 128-bit float32 block to int32.
    int32x4_t part0 = vcvtq_s32_f32(vec.reg.val[0]);
    int32x4_t part1 = vcvtq_s32_f32(vec.reg.val[1]);
    int32x4_t part2 = vcvtq_s32_f32(vec.reg.val[2]);
    int32x4_t part3 = vcvtq_s32_f32(vec.reg.val[3]);

    // Narrow each 32-bit vector to 8 bits and combine.
    int8x8_t lower =
        vqmovn_s16(vcombine_s16(vqmovn_s32(part0), vqmovn_s32(part1)));
    int8x8_t upper =
        vqmovn_s16(vcombine_s16(vqmovn_s32(part2), vqmovn_s32(part3)));
    reg = vcombine_s8(lower, upper);  // Single 128-bit result vector.
  }

  // Stores all 16 values to `ptr`.
  void save(int8_t* ptr) const { vst1q_s8(ptr, reg); };

  // Stores only the first `elem_num` values to `ptr`. The vector holds a
  // single register, so `elem_num` is clamped to [0, VEC_ELEM_NUM]: at most
  // 16 bytes are ever written and no lane is written twice.
  void save(int8_t* ptr, const int elem_num) const {
    if (elem_num >= VEC_ELEM_NUM) {
      vst1q_s8(ptr, reg);
      return;
    }
    if (elem_num <= 0) return;

    // Spill the register and copy only the requested leading lanes.
    int8_t tail[VEC_ELEM_NUM];
    vst1q_s8(tail, reg);
    for (int i = 0; i < elem_num; i++) ptr[i] = tail[i];
  };
};
497758

498759
template <typename T>

0 commit comments

Comments
 (0)