@@ -3176,57 +3176,79 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void *
     const block_q5_0 * restrict x = vx;
     const block_q8_1 * restrict y = vy;
 
-#if defined(__ARM_NEON_XXX)
-    float32x4_t sumv0 = vdupq_n_f32(0.0f);
-    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+#if defined(__ARM_NEON)
+    float32x4_t sumv = vdupq_n_f32(0.0f);
 
-    float summs0 = 0.0f;
-    float summs1 = 0.0f;
+    float summs = 0.0f;
+
+    uint32_t tmp[8];
+
+    static const uint32_t k_mask[16] = {
+        0x00000000, 0x00000010, 0x00001000, 0x00001010,
+        0x00100000, 0x00100010, 0x00101000, 0x00101010,
+        0x10000000, 0x10000010, 0x10001000, 0x10001010,
+        0x10100000, 0x10100010, 0x10101000, 0x10101010,
+    };
 
     for (int i = 0; i < nb; ++i) {
-        const block_q5_0 * restrict x0_0 = &x[2*(i + 0) + 0];
-        const block_q5_0 * restrict x0_1 = &x[2*(i + 0) + 1];
+        const block_q5_0 * restrict x0 = &x[i];
+        const block_q8_1 * restrict y0 = &y[i];
 
-        const block_q8_1 * restrict y0 = &y[i + 0];
+        summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
 
-        summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
-        summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
+        // extract the 5th bit
+        const uint32_t qh = x0->qh;
 
-        const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
+        tmp[0] = k_mask[(qh >>  0) & 0x0F];
+        tmp[1] = k_mask[(qh >>  4) & 0x0F];
+        tmp[2] = k_mask[(qh >>  8) & 0x0F];
+        tmp[3] = k_mask[(qh >> 12) & 0x0F];
+        tmp[4] = k_mask[(qh >> 16) & 0x0F];
+        tmp[5] = k_mask[(qh >> 20) & 0x0F];
+        tmp[6] = k_mask[(qh >> 24) & 0x0F];
+        tmp[7] = k_mask[(qh >> 28)];
+
+        const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
+        const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 4));
+
+        const uint8x16_t v0 = vld1q_u8(x0->qs);
 
         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, vdupq_n_u8(0x0F)));
-        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8  (v0, vdupq_n_u8(0x0F)));
+        const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
 
         // interleave
-        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
-        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+        const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
+        const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
+
+        // add
+        const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
+        const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
 
         // load y
-        const int8x16_t v1_0l = vld1q_s8(y0->qs);
-        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1l = vld1q_s8(y0->qs);
+        const int8x16_t v1h = vld1q_s8(y0->qs + 16);
 
-        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
-        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
+        const float x0d = GGML_FP16_TO_FP32(x0->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
+                        vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
 
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
 #endif
     }
 
-    *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
+    *s = vaddvq_f32(sumv) + summs;
 
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
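A note on the lookup-table trick in the new NEON path: each 4-bit nibble of qh indexes k_mask, whose entries place 0x10 in exactly the bytes whose corresponding bit is set, so the 32 packed "5th bits" expand into 32 bytes of 0x00/0x10 that vorrq_s8 can merge straight into the interleaved low/high nibbles. A minimal standalone sketch of that expansion (expand_qh and the main driver are hypothetical, for illustration only; the byte mapping assumes a little-endian target, as on AArch64):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

// same 16-entry table as in the kernel: index bit b -> byte b set to 0x10
static const uint32_t k_mask[16] = {
    0x00000000, 0x00000010, 0x00001000, 0x00001010,
    0x00100000, 0x00100010, 0x00101000, 0x00101010,
    0x10000000, 0x10000010, 0x10001000, 0x10001010,
    0x10100000, 0x10100010, 0x10101000, 0x10101010,
};

// expand bit j of qh into out[j] == ((qh >> j) & 1) << 4, for j = 0..31
static void expand_qh(uint32_t qh, uint8_t out[32]) {
    uint32_t tmp[8];
    for (int j = 0; j < 8; ++j) {
        tmp[j] = k_mask[(qh >> (4*j)) & 0x0F];
    }
    memcpy(out, tmp, sizeof(tmp)); // little-endian byte order assumed
}

int main(void) {
    uint8_t out[32];
    expand_qh(0x80000001u, out);                          // bits 0 and 31 set
    printf("%02x %02x %02x\n", out[0], out[1], out[31]);  // -> 10 00 10
    return 0;
}

Eight nibble lookups fill the two 16-byte vectors that qhl/qhh then load, instead of shifting and masking each of the 32 bits individually.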
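For reference, a scalar sketch of what the single-block loop computes per iteration. This is hedged: q5_0_q8_1_block_ref is a hypothetical helper that takes the field values directly, since the exact block_q5_0/block_q8_1 layouts of this revision aren't shown here, and it assumes y0->s0/s1 are the ready-to-use sums of the two 16-quant halves, as the summs line suggests:

#include <stdint.h>

// Hypothetical scalar reference for one q5_0 x q8_1 block, mirroring
// the NEON kernel above.
static float q5_0_q8_1_block_ref(
        float xd, float xm, uint32_t qh, const uint8_t xqs[16],
        float yd, float ys0, float ys1, const int8_t yqs[32]) {
    int sumi = 0;
    for (int j = 0; j < 32; ++j) {
        // 4-bit quant in vzip1q/vzip2q order: even j -> low nibble,
        // odd j -> high nibble of the same packed byte
        const int nib = (xqs[j/2] >> (4*(j & 1))) & 0x0F;
        // OR in the 5th bit from qh, as vorrq_s8(v0lz, qhl) does
        const int q5  = nib | (((qh >> j) & 1) << 4);
        sumi += q5 * yqs[j];
    }
    // sumv accumulates d*d*sumi; summs adds m*(s0 + s1)
    return xd*yd*(float)sumi + xm*(ys0 + ys1);
}

If this reading is right, folding the two dot products into one vaddq_s32 before the vmlaq_n_f32 is exact: now that each iteration touches a single q5_0 block, both halves share the same x0d*y0->d scale, which is why the separate sumv0/sumv1 accumulators could collapse into one.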