@@ -3167,18 +3167,16 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void *
3167
3167
}
3168
3168
3169
3169
static void ggml_vec_dot_q5_0_q8_1 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
3170
- GGML_ASSERT (false); // TODO xxxxxxxxx
3171
-
3172
3170
const int nb = n / QK8_1 ;
3173
3171
3174
3172
assert (n % QK8_1 == 0 );
3175
3173
assert (nb % 2 == 0 );
3176
- assert (QK8_1 == 2 * QK5_0 );
3174
+ assert (QK8_1 == QK5_0 );
3177
3175
3178
3176
const block_q5_0 * restrict x = vx ;
3179
3177
const block_q8_1 * restrict y = vy ;
3180
3178
3181
- #if defined(__ARM_NEON )
3179
+ #if defined(__ARM_NEON_XXX )
3182
3180
float32x4_t sumv0 = vdupq_n_f32 (0.0f );
3183
3181
float32x4_t sumv1 = vdupq_n_f32 (0.0f );
3184
3182
@@ -3257,43 +3255,37 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void *
3257
3255
3258
3256
* s = hsum_float_8 (acc ) + summs ;
3259
3257
#else
3260
- // scalar
3261
3258
float sumf = 0.0 ;
3259
+
3262
3260
for (int i = 0 ; i < nb ; i ++ ) {
3263
- const uint8_t * restrict x0 = x [2 * i + 0 ].qs ;
3264
- const uint8_t * restrict x1 = x [2 * i + 1 ].qs ;
3261
+ const uint8_t * restrict x0 = x [i ].qs ;
3265
3262
const int8_t * restrict y0 = y [i ].qs ;
3266
3263
3267
- const float d0 = GGML_FP16_TO_FP32 (x [2 * i + 0 ].d );
3268
- const float m0 = GGML_FP16_TO_FP32 (x [2 * i + 0 ].m );
3269
- const float d1 = GGML_FP16_TO_FP32 (x [2 * i + 1 ].d );
3270
- const float m1 = GGML_FP16_TO_FP32 (x [2 * i + 1 ].m );
3264
+ const uint32_t qh = x [i ].qh ;
3271
3265
3272
- int sxy_0 = 0 ;
3273
- int sxy_1 = 0 ;
3266
+ const float d = GGML_FP16_TO_FP32 ( x [ i ]. d ) ;
3267
+ const float m = GGML_FP16_TO_FP32 ( x [ i ]. m ) ;
3274
3268
3275
- for (int j = 0 ; j < QK8_1 /4 ; j ++ ) {
3269
+ int sxy = 0 ;
3270
+
3271
+ for (int j = 0 ; j < QK8_1 /2 ; j ++ ) {
3276
3272
const uint8_t v0 = x0 [j ];
3277
- const uint8_t v1 = x1 [j ];
3278
3273
3279
- const int x0_0 = v0 & 0x0F ;
3280
- const int x1_0 = v0 >> 4 ;
3274
+ const int x0_0h = (( qh & ( 1 << ( 2 * j + 0 ))) >> ( 2 * j + 0 )) << 4 ;
3275
+ const int x1_0h = (( qh & ( 1 << ( 2 * j + 1 ))) >> ( 2 * j + 1 )) << 4 ;
3281
3276
3282
- const int x0_1 = v1 & 0x0F ;
3283
- const int x1_1 = v1 >> 4 ;
3277
+ const int x0_0 = ( v0 & 0x0F ) | x0_0h ;
3278
+ const int x1_0 = ( v0 >> 4 ) | x1_0h ;
3284
3279
3285
3280
const int y0_0 = y0 [2 * j + 0 ];
3286
3281
const int y1_0 = y0 [2 * j + 1 ];
3287
3282
3288
- const int y0_1 = y0 [2 * (j + QK8_1 /4 ) + 0 ];
3289
- const int y1_1 = y0 [2 * (j + QK8_1 /4 ) + 1 ];
3290
-
3291
- sxy_0 += x0_0 * y0_0 + x1_0 * y1_0 ;
3292
- sxy_1 += x0_1 * y0_1 + x1_1 * y1_1 ;
3283
+ sxy += x0_0 * y0_0 + x1_0 * y1_0 ;
3293
3284
}
3294
3285
3295
- sumf += (d0 * sxy_0 + d1 * sxy_1 )* y [i ].d + m0 * y [i ].s0 + m1 * y [i ].s1 ;
3286
+ sumf += (d * sxy )* y [i ].d + m * ( y [i ].s0 + y [i ].s1 ) ;
3296
3287
}
3288
+
3297
3289
* s = sumf ;
3298
3290
#endif
3299
3291
}
0 commit comments