@@ -759,19 +759,24 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
759
759
#elif __ARM_NEON
760
760
for (int i = 0 ; i < nb ; i ++ ) {
761
761
float32x4_t srcv [8 ];
762
- float32x4_t asrcv [8 ];
763
- float32x4_t amaxv [8 ];
762
+ float32x4_t maxv [8 ];
763
+ float32x4_t minv [8 ];
764
764
765
765
for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = vld1q_f32 (x + i * 32 + 4 * l );
766
- for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = vabsq_f32 (srcv [l ]);
767
766
768
- for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = vmaxq_f32 (asrcv [2 * l ], asrcv [2 * l + 1 ]);
769
- for (int l = 0 ; l < 2 ; l ++ ) amaxv [4 * l ] = vmaxq_f32 (amaxv [4 * l ], amaxv [4 * l + 2 ]);
770
- for (int l = 0 ; l < 1 ; l ++ ) amaxv [8 * l ] = vmaxq_f32 (amaxv [8 * l ], amaxv [8 * l + 4 ]);
767
+ for (int l = 0 ; l < 4 ; l ++ ) maxv [2 * l ] = vmaxq_f32 (srcv [2 * l ], srcv [2 * l + 1 ]);
768
+ for (int l = 0 ; l < 2 ; l ++ ) maxv [4 * l ] = vmaxq_f32 (maxv [4 * l ], maxv [4 * l + 2 ]);
769
+ for (int l = 0 ; l < 1 ; l ++ ) maxv [8 * l ] = vmaxq_f32 (maxv [8 * l ], maxv [8 * l + 4 ]);
771
770
772
- const float amax = vmaxvq_f32 (amaxv [0 ]);
771
+ for (int l = 0 ; l < 4 ; l ++ ) minv [2 * l ] = vminq_f32 (srcv [2 * l ], srcv [2 * l + 1 ]);
772
+ for (int l = 0 ; l < 2 ; l ++ ) minv [4 * l ] = vminq_f32 (minv [4 * l ], minv [4 * l + 2 ]);
773
+ for (int l = 0 ; l < 1 ; l ++ ) minv [8 * l ] = vminq_f32 (minv [8 * l ], minv [8 * l + 4 ]);
773
774
774
- const float d = amax / ((1 << 3 ) - 1 );
775
+ const float max = vmaxvq_f32 (maxv [0 ]);
776
+ const float min = vminvq_f32 (minv [0 ]);
777
+
778
+ const float magnitude = max >= fabsf (min ) ? max : min ;
779
+ const float d = magnitude / -8 ;
775
780
const float id = d ? 1.0f /d : 0.0f ;
776
781
777
782
y [i ].d = d ;
@@ -780,9 +785,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
780
785
const float32x4_t v = vmulq_n_f32 (srcv [l ], id );
781
786
const float32x4_t vf = vaddq_f32 (v , vdupq_n_f32 (8.5f ));
782
787
const int32x4_t vi = vcvtq_s32_f32 (vf );
788
+ const int32x4 vc = vminq_u32 (vi , vdupq_n_u32 (15 ));
783
789
784
- y [i ].qs [2 * l + 0 ] = vgetq_lane_s32 (vi , 0 ) | (vgetq_lane_s32 (vi , 1 ) << 4 );
785
- y [i ].qs [2 * l + 1 ] = vgetq_lane_s32 (vi , 2 ) | (vgetq_lane_s32 (vi , 3 ) << 4 );
790
+ y [i ].qs [2 * l + 0 ] = vgetq_lane_s32 (vc , 0 ) | (vgetq_lane_s32 (vc , 1 ) << 4 );
791
+ y [i ].qs [2 * l + 1 ] = vgetq_lane_s32 (vc , 2 ) | (vgetq_lane_s32 (vc , 3 ) << 4 );
786
792
}
787
793
}
788
794
#elif defined(__AVX2__ )
0 commit comments