@@ -2267,6 +2267,7 @@ void iqk_dequantize_iq1_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
2267
2267
const block_iq1_kt * x8[8 ];
2268
2268
float dkt[8 ];
2269
2269
float ls[8 ], ls_all[64 ];
2270
+ uint16_t all_idx[256 ];
2270
2271
uint32_t idx[8 ];
2271
2272
2272
2273
for (int ix = 0 ; ix < nrc_x; ix += 8 ) {
@@ -2283,6 +2284,24 @@ void iqk_dequantize_iq1_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
2283
2284
auto s16 = vmovl_s8 (vqtbl1_s8 (values, vand_u8 (sh, vdup_n_u8 (0xf ))));
2284
2285
vst1q_f32 (ls_all + 8 *k + 0 , vcvtq_f32_s32 (vmovl_s16 (vget_low_s16 (s16))));
2285
2286
vst1q_f32 (ls_all + 8 *k + 4 , vcvtq_f32_s32 (vmovl_s16 (vget_high_s16 (s16))));
2287
+ auto ql = vld1q_u8_x2 (x8[k][i].ql );
2288
+ auto qh = vld1q_u8 (x8[k][i].qh );
2289
+ auto qhl = vmovl_u8 (vget_low_u8 (qh));
2290
+ auto qhh = vmovl_u8 (vget_high_u8 (qh));
2291
+ uint16x8x4_t idx;
2292
+ idx.val [0 ] = vaddq_u16 (vmovl_u8 (vget_low_u8 (ql.val [0 ])), vandq_u16 (vdupq_n_u16 (0xf00 ), vshlq_n_u16 (qhl, 8 )));
2293
+ idx.val [1 ] = vaddq_u16 (vmovl_u8 (vget_high_u8 (ql.val [0 ])), vandq_u16 (vdupq_n_u16 (0xf00 ), vshlq_n_u16 (qhh, 8 )));
2294
+ idx.val [2 ] = vaddq_u16 (vmovl_u8 (vget_low_u8 (ql.val [1 ])), vandq_u16 (vdupq_n_u16 (0xf00 ), vshlq_n_u16 (qhl, 4 )));
2295
+ idx.val [3 ] = vaddq_u16 (vmovl_u8 (vget_high_u8 (ql.val [1 ])), vandq_u16 (vdupq_n_u16 (0xf00 ), vshlq_n_u16 (qhh, 4 )));
2296
+ for (int k = 0 ; k < 4 ; ++k) idx.val [k] = vaddq_u16 (idx.val [k], vdupq_n_u16 (4096 ));
2297
+ auto sh16 = vandq_u16 (vmovl_u8 (sh), vdupq_n_u16 (0xf0 ));
2298
+ auto sh32l = vandq_u8 (vreinterpretq_u8_u32 (vmulq_u32 (vmovl_u16 (vget_low_u16 (sh16)), vdupq_n_u32 (0x01020408 ))), vdupq_n_u8 (0x80 ));
2299
+ auto sh32h = vandq_u8 (vreinterpretq_u8_u32 (vmulq_u32 (vmovl_u16 (vget_high_u16 (sh16)), vdupq_n_u32 (0x01020408 ))), vdupq_n_u8 (0x80 ));
2300
+ idx.val [0 ] = vaddq_u16 (idx.val [0 ], vshlq_n_u16 (vmovl_u8 (vget_low_u8 (sh32l)), 5 ));
2301
+ idx.val [1 ] = vaddq_u16 (idx.val [1 ], vshlq_n_u16 (vmovl_u8 (vget_high_u8 (sh32l)), 5 ));
2302
+ idx.val [2 ] = vaddq_u16 (idx.val [2 ], vshlq_n_u16 (vmovl_u8 (vget_low_u8 (sh32h)), 5 ));
2303
+ idx.val [3 ] = vaddq_u16 (idx.val [3 ], vshlq_n_u16 (vmovl_u8 (vget_high_u8 (sh32h)), 5 ));
2304
+ vst1q_u16_x4 (all_idx + 32 *k, idx);
2286
2305
}
2287
2306
for (int ib = 0 ; ib < QK_K/32 ; ++ib) {
2288
2307
for (int k = 0 ; k < 8 ; ++k) ls[k] = ls_all[8 *k+ib];
@@ -2291,10 +2310,7 @@ void iqk_dequantize_iq1_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
2291
2310
vst1_f16 ((float16_t *)y[ib].d +0 , vcvt_f16_f32 (scales1));
2292
2311
vst1_f16 ((float16_t *)y[ib].d +4 , vcvt_f16_f32 (scales2));
2293
2312
for (int j = 0 ; j < 4 ; ++j) {
2294
- int jj = 4 *ib + j;
2295
- for (int k = 0 ; k < 8 ; ++k) {
2296
- idx[k] = (x8[k][i].ql [jj] | ((x8[k][i].qh [4 *(ib%4 )+j] << (8 - 4 *(ib/4 ))) & 0xf00 ) | ((x8[k][i].sh [ib] << (8 - j)) & 0x1000 )) + 4096 ;
2297
- }
2313
+ for (int k = 0 ; k < 8 ; ++k) idx[k] = all_idx[32 *k + 4 *ib + j];
2298
2314
vst1q_s8_x4 (y[ib].qs +64 *j, trellis.next64 (idx));
2299
2315
}
2300
2316
}
0 commit comments