Skip to content

Commit 0be08a2

Browse files
author
Iwan Kawrakow
committed
iq1_kt: very slightly faster convert/repack to q8_0_r8 on NEON
1 parent 1c59e79 commit 0be08a2

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

ggml/src/iqk/iqk_gemm_ktquants.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,6 +2267,7 @@ void iqk_dequantize_iq1_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
22672267
const block_iq1_kt * x8[8];
22682268
float dkt[8];
22692269
float ls[8], ls_all[64];
2270+
uint16_t all_idx[256];
22702271
uint32_t idx[8];
22712272

22722273
for (int ix = 0; ix < nrc_x; ix += 8) {
@@ -2283,6 +2284,24 @@ void iqk_dequantize_iq1_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
22832284
auto s16 = vmovl_s8(vqtbl1_s8(values, vand_u8(sh, vdup_n_u8(0xf))));
22842285
vst1q_f32(ls_all + 8*k + 0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16))));
22852286
vst1q_f32(ls_all + 8*k + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
2287+
auto ql = vld1q_u8_x2(x8[k][i].ql);
2288+
auto qh = vld1q_u8(x8[k][i].qh);
2289+
auto qhl = vmovl_u8(vget_low_u8(qh));
2290+
auto qhh = vmovl_u8(vget_high_u8(qh));
2291+
uint16x8x4_t idx;
2292+
idx.val[0] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 8)));
2293+
idx.val[1] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 8)));
2294+
idx.val[2] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 4)));
2295+
idx.val[3] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 4)));
2296+
for (int k = 0; k < 4; ++k) idx.val[k] = vaddq_u16(idx.val[k], vdupq_n_u16(4096));
2297+
auto sh16 = vandq_u16(vmovl_u8(sh), vdupq_n_u16(0xf0));
2298+
auto sh32l = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_low_u16 (sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
2299+
auto sh32h = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_high_u16(sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
2300+
idx.val[0] = vaddq_u16(idx.val[0], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32l)), 5));
2301+
idx.val[1] = vaddq_u16(idx.val[1], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32l)), 5));
2302+
idx.val[2] = vaddq_u16(idx.val[2], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32h)), 5));
2303+
idx.val[3] = vaddq_u16(idx.val[3], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32h)), 5));
2304+
vst1q_u16_x4(all_idx + 32*k, idx);
22862305
}
22872306
for (int ib = 0; ib < QK_K/32; ++ib) {
22882307
for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
@@ -2291,10 +2310,7 @@ void iqk_dequantize_iq1_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
22912310
vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1));
22922311
vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2));
22932312
for (int j = 0; j < 4; ++j) {
2294-
int jj = 4*ib + j;
2295-
for (int k = 0; k < 8; ++k) {
2296-
idx[k] = (x8[k][i].ql[jj] | ((x8[k][i].qh[4*(ib%4)+j] << (8 - 4*(ib/4))) & 0xf00) | ((x8[k][i].sh[ib] << (8 - j)) & 0x1000)) + 4096;
2297-
}
2313+
for (int k = 0; k < 8; ++k) idx[k] = all_idx[32*k + 4*ib + j];
22982314
vst1q_s8_x4(y[ib].qs+64*j, trellis.next64(idx));
22992315
}
23002316
}

0 commit comments

Comments
 (0)