Skip to content

Commit 4262305

Browse files
committed
ggml : optimize ggml_vec_dot_q4_1_q8_0() via vmalq_n_f32
56 ms/token with Q4_1 !
1 parent e9c07f7 commit 4262305

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

ggml.c

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2390,10 +2390,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
23902390

23912391
// TODO: add AVX / WASM SIMD / etc
23922392
#if defined(__ARM_NEON)
2393-
float sum00 = 0.0f;
2394-
float sum01 = 0.0f;
2395-
float sum10 = 0.0f;
2396-
float sum11 = 0.0f;
2393+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
2394+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
23972395

23982396
for (int i = 0; i < nb; i += 2) {
23992397
const block_q4_1 * restrict x0 = &x[i + 0];
@@ -2424,20 +2422,24 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
24242422
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
24252423
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
24262424

2427-
// Note: cannot use vaddvq_s8 because it overflows for 8-bit values
2428-
// TODO: is there a better way to do this?
2429-
sum00 += (x0->m*y0->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_0ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0ls))) +
2430-
vaddvq_s16(vmovl_s8(vget_low_s8(v1_0hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0hs))));
2431-
sum01 += (x1->m*y1->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_1ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1ls))) +
2432-
vaddvq_s16(vmovl_s8(vget_low_s8(v1_1hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1hs))));
2425+
const int16x8_t s0i = vaddq_s16(
2426+
vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
2427+
vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
2428+
2429+
const int16x8_t s1i = vaddq_s16(
2430+
vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
2431+
vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
2432+
2433+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
2434+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
24332435

24342436
#if defined(__ARM_FEATURE_DOTPROD)
24352437
// dot product into int32x4_t
24362438
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
24372439
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
24382440

2439-
sum10 += (x0->d*y0->d)*vaddvq_s32(p_0);
2440-
sum11 += (x1->d*y1->d)*vaddvq_s32(p_1);
2441+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2442+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
24412443
#else
24422444
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
24432445
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
@@ -2449,21 +2451,17 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
24492451
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
24502452
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
24512453

2452-
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2453-
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2454-
2455-
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2456-
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2457-
2458-
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2459-
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2454+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2455+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2456+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2457+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
24602458

2461-
sum10 += x0->d*y0->d*vaddvq_s16(p_0);
2462-
sum11 += x1->d*y1->d*vaddvq_s16(p_1);
2459+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2460+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
24632461
#endif
24642462
}
24652463

2466-
sumf = sum00 + sum01 + sum10 + sum11;
2464+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
24672465
#else
24682466
// scalar
24692467
for (int i = 0; i < nb; i++) {

0 commit comments

Comments
 (0)