@@ -2390,10 +2390,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 
     // TODO: add AVX / WASM SIMD / etc
 #if defined(__ARM_NEON)
-    float sum00 = 0.0f;
-    float sum01 = 0.0f;
-    float sum10 = 0.0f;
-    float sum11 = 0.0f;
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
@@ -2424,20 +2422,24 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
-        // Note: cannot use vaddvq_s8 because it overflows for 8-bit values
-        // TODO: is there a better way to do this?
-        sum00 += (x0->m*y0->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_0ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0ls))) +
-                                vaddvq_s16(vmovl_s8(vget_low_s8(v1_0hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0hs))));
-        sum01 += (x1->m*y1->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_1ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1ls))) +
-                                vaddvq_s16(vmovl_s8(vget_low_s8(v1_1hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1hs))));
+        const int16x8_t s0i = vaddq_s16(
+                vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
+                vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
+
+        const int16x8_t s1i = vaddq_s16(
+                vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
+                vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
 
-        sum10 += (x0->d*y0->d)*vaddvq_s32(p_0);
-        sum11 += (x1->d*y1->d)*vaddvq_s32(p_1);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8(v0_0l), vget_low_s8(v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
@@ -2449,21 +2451,17 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int16x8_t ph1l = vmull_s8(vget_low_s8(v0_1h), vget_low_s8(v1_1hs));
         const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
 
-        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
-        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
-
-        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
-        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
-
-        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
-        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sum10 += x0->d*y0->d*vaddvq_s16(p_0);
-        sum11 += x1->d*y1->d*vaddvq_s16(p_1);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
 #endif
     }
 
-    sumf = sum00 + sum01 + sum10 + sum11;
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
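For context, not part of the commit: the change replaces the four scalar accumulators (sum00..sum11) and their per-iteration horizontal reductions (vaddvq_s16 / vaddvq_s32 inside the loop) with two float32x4_t vector accumulators that are folded in with vmlaq_n_f32 each iteration and reduced across lanes only once, after the loop. Below is a minimal standalone sketch of that pattern, assuming an AArch64 compiler with <arm_neon.h>; the helper name dot_i8_scaled and its signature are hypothetical, and n is assumed to be a multiple of 16.

#include <arm_neon.h>
#include <stdint.h>

// Hypothetical helper: returns scale * dot(a, b) over n signed 8-bit values.
// Mirrors the non-dotprod NEON path above: widen i8 products to i16 with
// vmull_s8, pairwise-add to i32 with vpaddlq_s16, fold the scale into a
// float32x4_t accumulator with vmlaq_n_f32, and reduce once at the end.
static float dot_i8_scaled(const int8_t * a, const int8_t * b, int n, float scale) {
    float32x4_t sumv = vdupq_n_f32(0.0f);
    for (int i = 0; i < n; i += 16) {
        const int8x16_t va = vld1q_s8(a + i);
        const int8x16_t vb = vld1q_s8(b + i);
        const int16x8_t pl = vmull_s8(vget_low_s8 (va), vget_low_s8 (vb)); // 8 low  lane products, i16
        const int16x8_t ph = vmull_s8(vget_high_s8(va), vget_high_s8(vb)); // 8 high lane products, i16
        const int32x4_t p  = vaddq_s32(vpaddlq_s16(pl), vpaddlq_s16(ph)); // widen and combine, i32
        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(p), scale);                // sumv += (float)p * scale
    }
    return vaddvq_f32(sumv); // single across-lane reduction, outside the hot loop
}

Keeping the running sum in a vector register moves the comparatively expensive across-lane reduction out of the inner loop; both the ARM_FEATURE_DOTPROD path and the vmull_s8 fallback in the diff make the same design choice.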