Skip to content

Commit b7e7046

Browse files
unbounded authored and ggerganov committed
Update quantize_row_q4_0 for WASM
Untested
1 parent 5d5f2b2 commit b7e7046

File tree

1 file changed

+22
-13
lines changed

1 file changed

+22
-13
lines changed

ggml.c

Lines changed: 22 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -949,24 +949,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
949949
}
950950
#elif defined(__wasm_simd128__)
951951
for (int i = 0; i < nb; i++) {
952-
float amax = 0.0f; // absolute max
952+
float max = 0.0f;
953+
float min = 0.0f;
953954

954955
v128_t srcv [8];
955-
v128_t asrcv[8];
956-
v128_t amaxv[8];
956+
v128_t maxv[8];
957+
v128_t minv[8];
957958

958959
for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
959-
for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
960960

961-
for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
962-
for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
963-
for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
961+
for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
962+
for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
963+
for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
964964

965-
amax = MAX(
966-
MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
967-
MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
965+
for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
966+
for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
967+
for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
968968

969-
const float d = amax / ((1 << 3) - 1);
969+
max = MAX(
970+
MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
971+
MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
972+
min = MIN(
973+
MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
974+
MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
975+
976+
const float magnitude = max >= fabsf(min) ? max : min;
977+
const float d = magnitude / -8;
970978
const float id = d ? 1.0/d : 0.0;
971979

972980
y[i].d = d;
@@ -975,9 +983,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
975983
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
976984
const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
977985
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
986+
const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15));
978987

979-
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
980-
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
988+
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
989+
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
981990
}
982991
}
983992
#else

0 commit comments

Comments
 (0)