@@ -543,11 +543,14 @@ __device__ __forceinline__ float vec_dot_iq2_ks_q8_1(
543
543
const uint8_t * a8 = (const uint8_t *)&aux32;
544
544
int v1, v2;
545
545
546
- int8_t s8[4 ];
547
- s8[0 ] = ((bq2->scales [2 *(i4/4 )+0 ] & 0xf ) | ((extra >> 4 ) & 0x10 )) - 16 ;
548
- s8[1 ] = ((bq2->scales [2 *(i4/4 )+0 ] >> 4 ) | ((extra >> 5 ) & 0x10 )) - 16 ;
549
- s8[2 ] = ((bq2->scales [2 *(i4/4 )+1 ] & 0xf ) | ((extra >> 6 ) & 0x10 )) - 16 ;
550
- s8[3 ] = ((bq2->scales [2 *(i4/4 )+1 ] >> 4 ) | ((extra >> 7 ) & 0x10 )) - 16 ;
546
+ int32_t scales32;
547
+ const uint16_t * scales16 = (const uint16_t *)bq2->scales ;
548
+ scales32 = __vsub4 ((scales16[i4/4 ] | (scales16[i4/4 ] << 12 )) & 0x0f0f0f0f , 0x10101010 );
549
+ int8_t * s8 = (int8_t *)&scales32;
550
+ s8[0 ] += ((extra >> 4 ) & 0x10 );
551
+ s8[1 ] += ((extra >> 6 ) & 0x10 );
552
+ s8[2 ] += ((extra >> 5 ) & 0x10 );
553
+ s8[3 ] += ((extra >> 7 ) & 0x10 );
551
554
552
555
aux32[0 ] = ((val1 >> 0 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 0 ) & 0x03030303 ); values = all_values + ((extra & 0x01 ) << 8 );
553
556
v1 = int_from_table_4 (a8 + 0 , values);
@@ -557,12 +560,12 @@ __device__ __forceinline__ float vec_dot_iq2_ks_q8_1(
557
560
aux32[0 ] = ((val1 >> 2 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 2 ) & 0x03030303 ); values = all_values + ((extra & 0x02 ) << 7 );
558
561
v1 = int_from_table_4 (a8 + 0 , values);
559
562
v2 = int_from_table_4 (a8 + 4 , values);
560
- int sumi2 = ggml_cuda_dp4a (v2, q8_2[1 ], ggml_cuda_dp4a (v1, q8_2[0 ], 0 )) * s8[1 ];
563
+ int sumi2 = ggml_cuda_dp4a (v2, q8_2[1 ], ggml_cuda_dp4a (v1, q8_2[0 ], 0 )) * s8[2 ];
561
564
562
565
aux32[0 ] = ((val1 >> 4 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 4 ) & 0x03030303 ); values = all_values + ((extra & 0x04 ) << 6 );
563
566
v1 = int_from_table_4 (a8 + 0 , values);
564
567
v2 = int_from_table_4 (a8 + 4 , values);
565
- int sumi3 = ggml_cuda_dp4a (v2, q8_3[1 ], ggml_cuda_dp4a (v1, q8_3[0 ], 0 )) * s8[2 ];
568
+ int sumi3 = ggml_cuda_dp4a (v2, q8_3[1 ], ggml_cuda_dp4a (v1, q8_3[0 ], 0 )) * s8[1 ];
566
569
567
570
aux32[0 ] = ((val1 >> 6 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 6 ) & 0x03030303 ); values = all_values + ((extra & 0x08 ) << 5 );
568
571
v1 = int_from_table_4 (a8 + 0 , values);
0 commit comments