@@ -2766,8 +2766,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     float sumf = 0.0;
 
 #if defined(__ARM_NEON)
-    float sum0 = 0.0f;
-    float sum1 = 0.0f;
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
@@ -2807,14 +2807,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
-        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
-
-        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
-        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
 
-        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
+#if 0
+        // note: this is faster for 4-6 threads but slower for more threads
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+#else
+        sumv0 = vaddq_f32(sumv0, vmulq_f32(vcvtq_f32_s32(p_0), vdupq_n_f32(x0->d*y0->d)));
+        sumv1 = vaddq_f32(sumv1, vmulq_f32(vcvtq_f32_s32(p_1), vdupq_n_f32(x1->d*y1->d)));
+#endif
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -2826,21 +2829,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
         const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
 
-        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
-        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
-
-        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
-        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
-        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
-
-        sum0 += x0->d*y0->d*vaddvq_s16(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s16(p_1);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
 #endif
     }
 
-    sumf = sum0 + sum1;
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
 
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
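
For reference, a minimal standalone sketch (not part of the commit) of the accumulation pattern the NEON path adopts above: partial dot products stay in float32x4_t vector accumulators across the whole loop and are reduced horizontally only once at the end with vaddvq_f32 (an AArch64-only intrinsic), rather than paying a scalar vaddvq_s32/vaddvq_s16 reduction on every iteration. The function and variable names here (scaled_dot_example, partial, scale) are illustrative, not from ggml.c.

#include <arm_neon.h>

// Sketch: accumulate scaled int32 partial sums in vector registers.
// Assumes n is even and the int32x4_t partial sums are already computed,
// e.g. by vdotq_s32 as in the patch above.
static float scaled_dot_example(const int32x4_t * partial, const float * scale, int n) {
    // two independent accumulators, as in the patch, to hide FP latency
    float32x4_t acc0 = vdupq_n_f32(0.0f);
    float32x4_t acc1 = vdupq_n_f32(0.0f);

    for (int i = 0; i < n; i += 2) {
        // convert int32 lanes to float and multiply-accumulate by the
        // per-block scale without leaving vector registers
        acc0 = vmlaq_n_f32(acc0, vcvtq_f32_s32(partial[i + 0]), scale[i + 0]);
        acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(partial[i + 1]), scale[i + 1]);
    }

    // one horizontal add per accumulator, outside the hot loop
    return vaddvq_f32(acc0) + vaddvq_f32(acc1);
}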