Skip to content

Commit fea8d10

Browse files
committed
Update quantize_row_q4_0 for Arm NEON
1 parent 73a92d2 commit fea8d10

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

ggml.c

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -771,19 +771,24 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
771771
#elif __ARM_NEON
772772
for (int i = 0; i < nb; i++) {
773773
float32x4_t srcv [8];
774-
float32x4_t asrcv[8];
775-
float32x4_t amaxv[8];
774+
float32x4_t maxv[8];
775+
float32x4_t minv[8];
776776

777777
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
778-
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
779778

780-
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
781-
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
782-
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
779+
for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
780+
for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
781+
for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
783782

784-
const float amax = vmaxvq_f32(amaxv[0]);
783+
for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
784+
for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
785+
for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
785786

786-
const float d = amax / ((1 << 3) - 1);
787+
const float max = vmaxvq_f32(maxv[0]);
788+
const float min = vminvq_f32(minv[0]);
789+
790+
const float magnitude = max >= fabsf(min) ? max : min;
791+
const float d = magnitude / -8;
787792
const float id = d ? 1.0f/d : 0.0f;
788793

789794
y[i].d = d;
@@ -792,9 +797,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
792797
const float32x4_t v = vmulq_n_f32(srcv[l], id);
793798
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
794799
const int32x4_t vi = vcvtq_s32_f32(vf);
800+
const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));
795801

796-
y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
797-
y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
802+
y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
803+
y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
798804
}
799805
}
800806
#elif defined(__AVX2__)

0 commit comments

Comments
 (0)