@@ -438,6 +438,30 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
     return res;
 }
 
+// NOTE: not tested
+inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
+    int8x16_t res;
+
+    res[ 0] = a[b[ 0]];
+    res[ 1] = a[b[ 1]];
+    res[ 2] = a[b[ 2]];
+    res[ 3] = a[b[ 3]];
+    res[ 4] = a[b[ 4]];
+    res[ 5] = a[b[ 5]];
+    res[ 6] = a[b[ 6]];
+    res[ 7] = a[b[ 7]];
+    res[ 8] = a[b[ 8]];
+    res[ 9] = a[b[ 9]];
+    res[10] = a[b[10]];
+    res[11] = a[b[11]];
+    res[12] = a[b[12]];
+    res[13] = a[b[13]];
+    res[14] = a[b[14]];
+    res[15] = a[b[15]];
+
+    return res;
+}
+
 #else
 
 #define ggml_int16x8x2_t int16x8x2_t
@@ -451,6 +475,7 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
 #define ggml_vld1q_u8_x4 vld1q_u8_x4
 #define ggml_vld1q_s8_x2 vld1q_s8_x2
 #define ggml_vld1q_s8_x4 vld1q_s8_x4
+#define ggml_vqtbl1q_s8  vqtbl1q_s8
 
 #endif
 
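The two hunks above give the NEON compat layer a portable stand-in for vqtbl1q_s8: on targets that have the intrinsic it is aliased directly, otherwise the scalar fallback performs the same per-byte table lookup, res[i] = a[b[i]]. Note that the fallback assumes every index in b is in 0..15 (the hardware intrinsic would instead write 0 for out-of-range indices), which holds for the iq4_nl caller because it masks or shifts each byte down to a nibble first. Below is a minimal plain-C sketch of the lookup semantics, not part of the patch; table[] and idx[] are illustrative values, not the real iq4_nl codebook.

// Hypothetical standalone check of the per-byte table-lookup semantics.
// table[] stands in for a 16-entry codebook; idx[] for the masked/shifted nibbles.
#include <stdint.h>
#include <stdio.h>

int main(void) {
    int8_t  table[16] = { -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7 };
    uint8_t idx[16]   = { 15, 0, 3, 7, 1, 14, 2, 9, 8, 4, 11, 6, 13, 5, 10, 12 };
    int8_t  res[16];

    for (int i = 0; i < 16; ++i) {
        res[i] = table[idx[i]];   // what ggml_vqtbl1q_s8 / vqtbl1q_s8 does for each byte
    }
    for (int i = 0; i < 16; ++i) {
        printf("%d ", res[i]);
    }
    printf("\n");
    return 0;
}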
@@ -9333,7 +9358,7 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
     uint16_t gindex[8];
     uint16x8x2_t vindex;
     int8x16x4_t q1b;
-    int8x16x4_t q8b;
+    ggml_int8x16x4_t q8b;
     uint16x8x4_t scales;
     int32x4x2_t sumi;
     int32x4x2_t dotq;
@@ -9506,10 +9531,10 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
         q8b.val[2] = vld1q_s8(y[ib + 1].qs);
         q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16);
 
-        q4b.val[0] = vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
-        q4b.val[1] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
-        q4b.val[2] = vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
-        q4b.val[3] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
+        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
+        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
 
         prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
         prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
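For reference, a scalar sketch of what one pass of the four ggml_vqtbl1q_s8 lines plus one of the ggml_vdotq_s32 reductions computes, assuming values[] is the 16-entry signed codebook, q4[] the 16 packed bytes of one block (two 4-bit indices per byte), and q8[] the matching 32 int8 activations; the helper name is hypothetical and not part of the patch. Judging by the q8b loads, prod_1 and prod_2 appear to cover blocks ib and ib + 1 respectively before the per-block deltas are applied.

// Hypothetical scalar equivalent of one block of the vectorized loop above.
static int32_t iq4_nl_block_dot(const int8_t values[16], const uint8_t q4[16], const int8_t q8[32]) {
    int32_t sum = 0;
    for (int i = 0; i < 16; ++i) {
        sum += values[q4[i] & 0x0F] * q8[i];        // low nibbles  -> vandq_u8(q4bits, m4b),  dotted with q8b.val[0]
        sum += values[q4[i] >>   4] * q8[i + 16];   // high nibbles -> vshrq_n_u8(q4bits, 4), dotted with q8b.val[1]
    }
    return sum;
}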