Skip to content

Commit 381c031

Browse files
committed
ggml : q5_0 scalar dot product
1 parent be32443 commit 381c031

File tree

1 file changed

+17
-25
lines changed

1 file changed

+17
-25
lines changed

ggml.c

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3167,18 +3167,16 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void *
31673167
}
31683168

31693169
static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3170-
GGML_ASSERT(false); // TODO xxxxxxxxx
3171-
31723170
const int nb = n / QK8_1;
31733171

31743172
assert(n % QK8_1 == 0);
31753173
assert(nb % 2 == 0);
3176-
assert(QK8_1 == 2*QK5_0);
3174+
assert(QK8_1 == QK5_0);
31773175

31783176
const block_q5_0 * restrict x = vx;
31793177
const block_q8_1 * restrict y = vy;
31803178

3181-
#if defined(__ARM_NEON)
3179+
#if defined(__ARM_NEON_XXX)
31823180
float32x4_t sumv0 = vdupq_n_f32(0.0f);
31833181
float32x4_t sumv1 = vdupq_n_f32(0.0f);
31843182

@@ -3257,43 +3255,37 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void *
32573255

32583256
*s = hsum_float_8(acc) + summs;
32593257
#else
3260-
// scalar
32613258
float sumf = 0.0;
3259+
32623260
for (int i = 0; i < nb; i++) {
3263-
const uint8_t * restrict x0 = x[2*i + 0].qs;
3264-
const uint8_t * restrict x1 = x[2*i + 1].qs;
3261+
const uint8_t * restrict x0 = x[i].qs;
32653262
const int8_t * restrict y0 = y[i].qs;
32663263

3267-
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3268-
const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3269-
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3270-
const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
3264+
const uint32_t qh = x[i].qh;
32713265

3272-
int sxy_0 = 0;
3273-
int sxy_1 = 0;
3266+
const float d = GGML_FP16_TO_FP32(x[i].d);
3267+
const float m = GGML_FP16_TO_FP32(x[i].m);
32743268

3275-
for (int j = 0; j < QK8_1/4; j++) {
3269+
int sxy = 0;
3270+
3271+
for (int j = 0; j < QK8_1/2; j++) {
32763272
const uint8_t v0 = x0[j];
3277-
const uint8_t v1 = x1[j];
32783273

3279-
const int x0_0 = v0 & 0x0F;
3280-
const int x1_0 = v0 >> 4;
3274+
const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
3275+
const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
32813276

3282-
const int x0_1 = v1 & 0x0F;
3283-
const int x1_1 = v1 >> 4;
3277+
const int x0_0 = (v0 & 0x0F) | x0_0h;
3278+
const int x1_0 = (v0 >> 4) | x1_0h;
32843279

32853280
const int y0_0 = y0[2*j + 0];
32863281
const int y1_0 = y0[2*j + 1];
32873282

3288-
const int y0_1 = y0[2*(j + QK8_1/4) + 0];
3289-
const int y1_1 = y0[2*(j + QK8_1/4) + 1];
3290-
3291-
sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3292-
sxy_1 += x0_1*y0_1 + x1_1*y1_1;
3283+
sxy += x0_0*y0_0 + x1_0*y1_0;
32933284
}
32943285

3295-
sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
3286+
sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
32963287
}
3288+
32973289
*s = sumf;
32983290
#endif
32993291
}

0 commit comments

Comments
 (0)