Skip to content

Commit dcdd65e

Browse files
committed
ggml : optimize ggml_vec_dot_q4_0_q8_0() using vectorized accumulators
1 parent 5ecff35 commit dcdd65e

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

ggml.c

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2766,8 +2766,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
27662766
float sumf = 0.0;
27672767

27682768
#if defined(__ARM_NEON)
2769-
float sum0 = 0.0f;
2770-
float sum1 = 0.0f;
2769+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
2770+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
27712771

27722772
for (int i = 0; i < nb; i += 2) {
27732773
const block_q4_0 * restrict x0 = &x[i + 0];
@@ -2807,14 +2807,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
28072807

28082808
#if defined(__ARM_FEATURE_DOTPROD)
28092809
// dot product into int32x4_t
2810-
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2811-
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2812-
2813-
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2814-
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
2810+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2811+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
28152812

2816-
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2817-
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2813+
#if 0
2814+
// note: this is faster for 4-6 threads by slower for more threads
2815+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2816+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2817+
#else
2818+
sumv0 = vaddq_f32(sumv0, vmulq_f32(vcvtq_f32_s32(p_0), vdupq_n_f32(x0->d*y0->d)));
2819+
sumv1 = vaddq_f32(sumv1, vmulq_f32(vcvtq_f32_s32(p_1), vdupq_n_f32(x1->d*y1->d)));
2820+
#endif
28182821
#else
28192822
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
28202823
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -2826,21 +2829,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
28262829
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
28272830
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
28282831

2829-
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2830-
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2831-
2832-
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2833-
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2832+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2833+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2834+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2835+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
28342836

2835-
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2836-
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2837-
2838-
sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2839-
sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2837+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2838+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
28402839
#endif
28412840
}
28422841

2843-
sumf = sum0 + sum1;
2842+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
28442843
#elif defined(__AVX2__)
28452844
// Initialize accumulator with zeros
28462845
__m256 acc = _mm256_setzero_ps();

0 commit comments

Comments
 (0)