Skip to content

Commit 93c95fc

Browse files
unbounded authored and ggerganov committed
Update quantize_row_q4_0 for Arm NEON
Untested
1 parent b7e7046 commit 93c95fc

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

ggml.c

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -759,19 +759,24 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
759759
#elif __ARM_NEON
760760
for (int i = 0; i < nb; i++) {
761761
float32x4_t srcv [8];
762-
float32x4_t asrcv[8];
763-
float32x4_t amaxv[8];
762+
float32x4_t maxv[8];
763+
float32x4_t minv[8];
764764

765765
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
766-
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
767766

768-
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
769-
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
770-
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
767+
for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
768+
for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
769+
for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
771770

772-
const float amax = vmaxvq_f32(amaxv[0]);
771+
for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
772+
for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
773+
for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
773774

774-
const float d = amax / ((1 << 3) - 1);
775+
const float max = vmaxvq_f32(maxv[0]);
776+
const float min = vminvq_f32(minv[0]);
777+
778+
const float magnitude = max >= fabsf(min) ? max : min;
779+
const float d = magnitude / -8;
775780
const float id = d ? 1.0f/d : 0.0f;
776781

777782
y[i].d = d;
@@ -780,9 +785,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
780785
const float32x4_t v = vmulq_n_f32(srcv[l], id);
781786
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
782787
const int32x4_t vi = vcvtq_s32_f32(vf);
788+
const int32x4 vc = vminq_u32(vi, vdupq_n_u32(15));
783789

784-
y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
785-
y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
790+
y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
791+
y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
786792
}
787793
}
788794
#elif defined(__AVX2__)

0 commit comments

Comments
 (0)