Skip to content

Commit 4bf196e

Browse files
committed
ggml : q5_0 ARM NEON dot
1 parent 381c031 commit 4bf196e

File tree

1 file changed

+50
-28
lines changed

1 file changed

+50
-28
lines changed

ggml.c

Lines changed: 50 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -3176,57 +3176,79 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void *
31763176
const block_q5_0 * restrict x = vx;
31773177
const block_q8_1 * restrict y = vy;
31783178

3179-
#if defined(__ARM_NEON_XXX)
3180-
float32x4_t sumv0 = vdupq_n_f32(0.0f);
3181-
float32x4_t sumv1 = vdupq_n_f32(0.0f);
3179+
#if defined(__ARM_NEON)
3180+
float32x4_t sumv = vdupq_n_f32(0.0f);
31823181

3183-
float summs0 = 0.0f;
3184-
float summs1 = 0.0f;
3182+
float summs = 0.0f;
3183+
3184+
uint32_t tmp[8];
3185+
3186+
static const uint32_t k_mask[16] = {
3187+
0x00000000, 0x00000010, 0x00001000, 0x00001010,
3188+
0x00100000, 0x00100010, 0x00101000, 0x00101010,
3189+
0x10000000, 0x10000010, 0x10001000, 0x10001010,
3190+
0x10100000, 0x10100010, 0x10101000, 0x10101010,
3191+
};
31853192

31863193
for (int i = 0; i < nb; ++i) {
3187-
const block_q5_0 * restrict x0_0 = &x[2*(i + 0) + 0];
3188-
const block_q5_0 * restrict x0_1 = &x[2*(i + 0) + 1];
3194+
const block_q5_0 * restrict x0 = &x[i];
3195+
const block_q8_1 * restrict y0 = &y[i];
31893196

3190-
const block_q8_1 * restrict y0 = &y[i + 0];
3197+
summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
31913198

3192-
summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
3193-
summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
3199+
// extract the 5th bit
3200+
const uint32_t qh = x0->qh;
31943201

3195-
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
3202+
tmp[0] = k_mask[(qh >> 0) & 0x0F];
3203+
tmp[1] = k_mask[(qh >> 4) & 0x0F];
3204+
tmp[2] = k_mask[(qh >> 8) & 0x0F];
3205+
tmp[3] = k_mask[(qh >> 12) & 0x0F];
3206+
tmp[4] = k_mask[(qh >> 16) & 0x0F];
3207+
tmp[5] = k_mask[(qh >> 20) & 0x0F];
3208+
tmp[6] = k_mask[(qh >> 24) & 0x0F];
3209+
tmp[7] = k_mask[(qh >> 28)];
3210+
3211+
const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3212+
const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 4));
3213+
3214+
const uint8x16_t v0 = vld1q_u8(x0->qs);
31963215

31973216
// 4-bit -> 8-bit
3198-
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F)));
3199-
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3217+
const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
3218+
const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
32003219

32013220
// interleave
3202-
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
3203-
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
3221+
const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3222+
const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
3223+
3224+
// add
3225+
const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
3226+
const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
32043227

32053228
// load y
3206-
const int8x16_t v1_0l = vld1q_s8(y0->qs);
3207-
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3229+
const int8x16_t v1l = vld1q_s8(y0->qs);
3230+
const int8x16_t v1h = vld1q_s8(y0->qs + 16);
32083231

3209-
const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
3210-
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
3232+
const float x0d = GGML_FP16_TO_FP32(x0->d);
32113233

32123234
#if defined(__ARM_FEATURE_DOTPROD)
3213-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
3214-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
3235+
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3236+
vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3237+
vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
32153238
#else
3216-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3217-
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3218-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3219-
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3239+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3240+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3241+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3242+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
32203243

32213244
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
32223245
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
32233246

3224-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
3225-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
3247+
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
32263248
#endif
32273249
}
32283250

3229-
*s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
3251+
*s = vaddvq_f32(sumv) + summs;
32303252
#elif defined(__AVX2__)
32313253
// Initialize accumulator with zeros
32323254
__m256 acc = _mm256_setzero_ps();

0 commit comments

Comments
 (0)