@@ -771,19 +771,24 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
771
771
#elif __ARM_NEON
772
772
for (int i = 0 ; i < nb ; i ++ ) {
773
773
float32x4_t srcv [8 ];
774
- float32x4_t asrcv [8 ];
775
- float32x4_t amaxv [8 ];
774
+ float32x4_t maxv [8 ];
775
+ float32x4_t minv [8 ];
776
776
777
777
for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = vld1q_f32 (x + i * 32 + 4 * l );
778
- for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = vabsq_f32 (srcv [l ]);
779
778
780
- for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = vmaxq_f32 (asrcv [2 * l ], asrcv [2 * l + 1 ]);
781
- for (int l = 0 ; l < 2 ; l ++ ) amaxv [4 * l ] = vmaxq_f32 (amaxv [4 * l ], amaxv [4 * l + 2 ]);
782
- for (int l = 0 ; l < 1 ; l ++ ) amaxv [8 * l ] = vmaxq_f32 (amaxv [8 * l ], amaxv [8 * l + 4 ]);
779
+ for (int l = 0 ; l < 4 ; l ++ ) maxv [2 * l ] = vmaxq_f32 (srcv [2 * l ], srcv [2 * l + 1 ]);
780
+ for (int l = 0 ; l < 2 ; l ++ ) maxv [4 * l ] = vmaxq_f32 (maxv [4 * l ], maxv [4 * l + 2 ]);
781
+ for (int l = 0 ; l < 1 ; l ++ ) maxv [8 * l ] = vmaxq_f32 (maxv [8 * l ], maxv [8 * l + 4 ]);
783
782
784
- const float amax = vmaxvq_f32 (amaxv [0 ]);
783
+ for (int l = 0 ; l < 4 ; l ++ ) minv [2 * l ] = vminq_f32 (srcv [2 * l ], srcv [2 * l + 1 ]);
784
+ for (int l = 0 ; l < 2 ; l ++ ) minv [4 * l ] = vminq_f32 (minv [4 * l ], minv [4 * l + 2 ]);
785
+ for (int l = 0 ; l < 1 ; l ++ ) minv [8 * l ] = vminq_f32 (minv [8 * l ], minv [8 * l + 4 ]);
785
786
786
- const float d = amax / ((1 << 3 ) - 1 );
787
+ const float max = vmaxvq_f32 (maxv [0 ]);
788
+ const float min = vminvq_f32 (minv [0 ]);
789
+
790
+ const float magnitude = max >= fabsf (min ) ? max : min ;
791
+ const float d = magnitude / -8 ;
787
792
const float id = d ? 1.0f /d : 0.0f ;
788
793
789
794
y [i ].d = d ;
@@ -792,9 +797,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
792
797
const float32x4_t v = vmulq_n_f32 (srcv [l ], id );
793
798
const float32x4_t vf = vaddq_f32 (v , vdupq_n_f32 (8.5f ));
794
799
const int32x4_t vi = vcvtq_s32_f32 (vf );
800
+ const int32x4_t vc = vminq_s32 (vi , vdupq_n_s32 (15 ));
795
801
796
- y [i ].qs [2 * l + 0 ] = vgetq_lane_s32 (vi , 0 ) | (vgetq_lane_s32 (vi , 1 ) << 4 );
797
- y [i ].qs [2 * l + 1 ] = vgetq_lane_s32 (vi , 2 ) | (vgetq_lane_s32 (vi , 3 ) << 4 );
802
+ y [i ].qs [2 * l + 0 ] = vgetq_lane_s32 (vc , 0 ) | (vgetq_lane_s32 (vc , 1 ) << 4 );
803
+ y [i ].qs [2 * l + 1 ] = vgetq_lane_s32 (vc , 2 ) | (vgetq_lane_s32 (vc , 3 ) << 4 );
798
804
}
799
805
}
800
806
#elif defined(__AVX2__ )
0 commit comments