5
5
//
6
6
7
7
#include " iqk_mmvq.cuh"
8
+ #include " iqk_cuda_common.h"
8
9
9
10
typedef void (*vec_dot_q_cuda_t )(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float *);
10
11
@@ -785,77 +786,6 @@ __device__ __forceinline__ void vec_dot_iq6_k_q8_1(
785
786
*result += d6 * (__low2float (bq8_1[2 *(i4/2 )+0 ].ds ) * sumi1 * bq6->scales [4 *(i4/2 )+(i4%2 )] + __low2float (bq8_1[2 *(i4/2 )+1 ].ds ) * sumi2 * bq6->scales [4 *(i4/2 )+(i4%2 )+2 ]);
786
787
}
787
788
788
- static const __device__ uint32_t iq2k_table[512 ] = {
789
- 0xe1e1e1e1 , 0xe1e1e1f3 , 0xe1e1e101 , 0xe1e1e111 , 0xe1e1f3e1 , 0xe1e1f3f3 , 0xe1e1f301 , 0xe1e1f311 ,
790
- 0xe1e101e1 , 0xe1e101f3 , 0xe1e10101 , 0xe1e10111 , 0xe1e111e1 , 0xe1e111f3 , 0xe1e11101 , 0xe1e11111 ,
791
- 0xe1f3e1e1 , 0xe1f3e1f3 , 0xe1f3e101 , 0xe1f3e111 , 0xe1f3f3e1 , 0xe1f3f3f3 , 0xe1f3f301 , 0xe1f3f311 ,
792
- 0xe1f301e1 , 0xe1f301f3 , 0xe1f30101 , 0xe1f30111 , 0xe1f311e1 , 0xe1f311f3 , 0xe1f31101 , 0xe1f31111 ,
793
- 0xe101e1e1 , 0xe101e1f3 , 0xe101e101 , 0xe101e111 , 0xe101f3e1 , 0xe101f3f3 , 0xe101f301 , 0xe101f311 ,
794
- 0xe10101e1 , 0xe10101f3 , 0xe1010101 , 0xe1010111 , 0xe10111e1 , 0xe10111f3 , 0xe1011101 , 0xe1011111 ,
795
- 0xe111e1e1 , 0xe111e1f3 , 0xe111e101 , 0xe111e111 , 0xe111f3e1 , 0xe111f3f3 , 0xe111f301 , 0xe111f311 ,
796
- 0xe11101e1 , 0xe11101f3 , 0xe1110101 , 0xe1110111 , 0xe11111e1 , 0xe11111f3 , 0xe1111101 , 0xe1111111 ,
797
- 0xf3e1e1e1 , 0xf3e1e1f3 , 0xf3e1e101 , 0xf3e1e111 , 0xf3e1f3e1 , 0xf3e1f3f3 , 0xf3e1f301 , 0xf3e1f311 ,
798
- 0xf3e101e1 , 0xf3e101f3 , 0xf3e10101 , 0xf3e10111 , 0xf3e111e1 , 0xf3e111f3 , 0xf3e11101 , 0xf3e11111 ,
799
- 0xf3f3e1e1 , 0xf3f3e1f3 , 0xf3f3e101 , 0xf3f3e111 , 0xf3f3f3e1 , 0xf3f3f3f3 , 0xf3f3f301 , 0xf3f3f311 ,
800
- 0xf3f301e1 , 0xf3f301f3 , 0xf3f30101 , 0xf3f30111 , 0xf3f311e1 , 0xf3f311f3 , 0xf3f31101 , 0xf3f31111 ,
801
- 0xf301e1e1 , 0xf301e1f3 , 0xf301e101 , 0xf301e111 , 0xf301f3e1 , 0xf301f3f3 , 0xf301f301 , 0xf301f311 ,
802
- 0xf30101e1 , 0xf30101f3 , 0xf3010101 , 0xf3010111 , 0xf30111e1 , 0xf30111f3 , 0xf3011101 , 0xf3011111 ,
803
- 0xf311e1e1 , 0xf311e1f3 , 0xf311e101 , 0xf311e111 , 0xf311f3e1 , 0xf311f3f3 , 0xf311f301 , 0xf311f311 ,
804
- 0xf31101e1 , 0xf31101f3 , 0xf3110101 , 0xf3110111 , 0xf31111e1 , 0xf31111f3 , 0xf3111101 , 0xf3111111 ,
805
- 0x01e1e1e1 , 0x01e1e1f3 , 0x01e1e101 , 0x01e1e111 , 0x01e1f3e1 , 0x01e1f3f3 , 0x01e1f301 , 0x01e1f311 ,
806
- 0x01e101e1 , 0x01e101f3 , 0x01e10101 , 0x01e10111 , 0x01e111e1 , 0x01e111f3 , 0x01e11101 , 0x01e11111 ,
807
- 0x01f3e1e1 , 0x01f3e1f3 , 0x01f3e101 , 0x01f3e111 , 0x01f3f3e1 , 0x01f3f3f3 , 0x01f3f301 , 0x01f3f311 ,
808
- 0x01f301e1 , 0x01f301f3 , 0x01f30101 , 0x01f30111 , 0x01f311e1 , 0x01f311f3 , 0x01f31101 , 0x01f31111 ,
809
- 0x0101e1e1 , 0x0101e1f3 , 0x0101e101 , 0x0101e111 , 0x0101f3e1 , 0x0101f3f3 , 0x0101f301 , 0x0101f311 ,
810
- 0x010101e1 , 0x010101f3 , 0x01010101 , 0x01010111 , 0x010111e1 , 0x010111f3 , 0x01011101 , 0x01011111 ,
811
- 0x0111e1e1 , 0x0111e1f3 , 0x0111e101 , 0x0111e111 , 0x0111f3e1 , 0x0111f3f3 , 0x0111f301 , 0x0111f311 ,
812
- 0x011101e1 , 0x011101f3 , 0x01110101 , 0x01110111 , 0x011111e1 , 0x011111f3 , 0x01111101 , 0x01111111 ,
813
- 0x11e1e1e1 , 0x11e1e1f3 , 0x11e1e101 , 0x11e1e111 , 0x11e1f3e1 , 0x11e1f3f3 , 0x11e1f301 , 0x11e1f311 ,
814
- 0x11e101e1 , 0x11e101f3 , 0x11e10101 , 0x11e10111 , 0x11e111e1 , 0x11e111f3 , 0x11e11101 , 0x11e11111 ,
815
- 0x11f3e1e1 , 0x11f3e1f3 , 0x11f3e101 , 0x11f3e111 , 0x11f3f3e1 , 0x11f3f3f3 , 0x11f3f301 , 0x11f3f311 ,
816
- 0x11f301e1 , 0x11f301f3 , 0x11f30101 , 0x11f30111 , 0x11f311e1 , 0x11f311f3 , 0x11f31101 , 0x11f31111 ,
817
- 0x1101e1e1 , 0x1101e1f3 , 0x1101e101 , 0x1101e111 , 0x1101f3e1 , 0x1101f3f3 , 0x1101f301 , 0x1101f311 ,
818
- 0x110101e1 , 0x110101f3 , 0x11010101 , 0x11010111 , 0x110111e1 , 0x110111f3 , 0x11011101 , 0x11011111 ,
819
- 0x1111e1e1 , 0x1111e1f3 , 0x1111e101 , 0x1111e111 , 0x1111f3e1 , 0x1111f3f3 , 0x1111f301 , 0x1111f311 ,
820
- 0x111101e1 , 0x111101f3 , 0x11110101 , 0x11110111 , 0x111111e1 , 0x111111f3 , 0x11111101 , 0x11111111 ,
821
- 0xe6e6e6e6 , 0xe6e6e6f8 , 0xe6e6e606 , 0xe6e6e616 , 0xe6e6f8e6 , 0xe6e6f8f8 , 0xe6e6f806 , 0xe6e6f816 ,
822
- 0xe6e606e6 , 0xe6e606f8 , 0xe6e60606 , 0xe6e60616 , 0xe6e616e6 , 0xe6e616f8 , 0xe6e61606 , 0xe6e61616 ,
823
- 0xe6f8e6e6 , 0xe6f8e6f8 , 0xe6f8e606 , 0xe6f8e616 , 0xe6f8f8e6 , 0xe6f8f8f8 , 0xe6f8f806 , 0xe6f8f816 ,
824
- 0xe6f806e6 , 0xe6f806f8 , 0xe6f80606 , 0xe6f80616 , 0xe6f816e6 , 0xe6f816f8 , 0xe6f81606 , 0xe6f81616 ,
825
- 0xe606e6e6 , 0xe606e6f8 , 0xe606e606 , 0xe606e616 , 0xe606f8e6 , 0xe606f8f8 , 0xe606f806 , 0xe606f816 ,
826
- 0xe60606e6 , 0xe60606f8 , 0xe6060606 , 0xe6060616 , 0xe60616e6 , 0xe60616f8 , 0xe6061606 , 0xe6061616 ,
827
- 0xe616e6e6 , 0xe616e6f8 , 0xe616e606 , 0xe616e616 , 0xe616f8e6 , 0xe616f8f8 , 0xe616f806 , 0xe616f816 ,
828
- 0xe61606e6 , 0xe61606f8 , 0xe6160606 , 0xe6160616 , 0xe61616e6 , 0xe61616f8 , 0xe6161606 , 0xe6161616 ,
829
- 0xf8e6e6e6 , 0xf8e6e6f8 , 0xf8e6e606 , 0xf8e6e616 , 0xf8e6f8e6 , 0xf8e6f8f8 , 0xf8e6f806 , 0xf8e6f816 ,
830
- 0xf8e606e6 , 0xf8e606f8 , 0xf8e60606 , 0xf8e60616 , 0xf8e616e6 , 0xf8e616f8 , 0xf8e61606 , 0xf8e61616 ,
831
- 0xf8f8e6e6 , 0xf8f8e6f8 , 0xf8f8e606 , 0xf8f8e616 , 0xf8f8f8e6 , 0xf8f8f8f8 , 0xf8f8f806 , 0xf8f8f816 ,
832
- 0xf8f806e6 , 0xf8f806f8 , 0xf8f80606 , 0xf8f80616 , 0xf8f816e6 , 0xf8f816f8 , 0xf8f81606 , 0xf8f81616 ,
833
- 0xf806e6e6 , 0xf806e6f8 , 0xf806e606 , 0xf806e616 , 0xf806f8e6 , 0xf806f8f8 , 0xf806f806 , 0xf806f816 ,
834
- 0xf80606e6 , 0xf80606f8 , 0xf8060606 , 0xf8060616 , 0xf80616e6 , 0xf80616f8 , 0xf8061606 , 0xf8061616 ,
835
- 0xf816e6e6 , 0xf816e6f8 , 0xf816e606 , 0xf816e616 , 0xf816f8e6 , 0xf816f8f8 , 0xf816f806 , 0xf816f816 ,
836
- 0xf81606e6 , 0xf81606f8 , 0xf8160606 , 0xf8160616 , 0xf81616e6 , 0xf81616f8 , 0xf8161606 , 0xf8161616 ,
837
- 0x06e6e6e6 , 0x06e6e6f8 , 0x06e6e606 , 0x06e6e616 , 0x06e6f8e6 , 0x06e6f8f8 , 0x06e6f806 , 0x06e6f816 ,
838
- 0x06e606e6 , 0x06e606f8 , 0x06e60606 , 0x06e60616 , 0x06e616e6 , 0x06e616f8 , 0x06e61606 , 0x06e61616 ,
839
- 0x06f8e6e6 , 0x06f8e6f8 , 0x06f8e606 , 0x06f8e616 , 0x06f8f8e6 , 0x06f8f8f8 , 0x06f8f806 , 0x06f8f816 ,
840
- 0x06f806e6 , 0x06f806f8 , 0x06f80606 , 0x06f80616 , 0x06f816e6 , 0x06f816f8 , 0x06f81606 , 0x06f81616 ,
841
- 0x0606e6e6 , 0x0606e6f8 , 0x0606e606 , 0x0606e616 , 0x0606f8e6 , 0x0606f8f8 , 0x0606f806 , 0x0606f816 ,
842
- 0x060606e6 , 0x060606f8 , 0x06060606 , 0x06060616 , 0x060616e6 , 0x060616f8 , 0x06061606 , 0x06061616 ,
843
- 0x0616e6e6 , 0x0616e6f8 , 0x0616e606 , 0x0616e616 , 0x0616f8e6 , 0x0616f8f8 , 0x0616f806 , 0x0616f816 ,
844
- 0x061606e6 , 0x061606f8 , 0x06160606 , 0x06160616 , 0x061616e6 , 0x061616f8 , 0x06161606 , 0x06161616 ,
845
- 0x16e6e6e6 , 0x16e6e6f8 , 0x16e6e606 , 0x16e6e616 , 0x16e6f8e6 , 0x16e6f8f8 , 0x16e6f806 , 0x16e6f816 ,
846
- 0x16e606e6 , 0x16e606f8 , 0x16e60606 , 0x16e60616 , 0x16e616e6 , 0x16e616f8 , 0x16e61606 , 0x16e61616 ,
847
- 0x16f8e6e6 , 0x16f8e6f8 , 0x16f8e606 , 0x16f8e616 , 0x16f8f8e6 , 0x16f8f8f8 , 0x16f8f806 , 0x16f8f816 ,
848
- 0x16f806e6 , 0x16f806f8 , 0x16f80606 , 0x16f80616 , 0x16f816e6 , 0x16f816f8 , 0x16f81606 , 0x16f81616 ,
849
- 0x1606e6e6 , 0x1606e6f8 , 0x1606e606 , 0x1606e616 , 0x1606f8e6 , 0x1606f8f8 , 0x1606f806 , 0x1606f816 ,
850
- 0x160606e6 , 0x160606f8 , 0x16060606 , 0x16060616 , 0x160616e6 , 0x160616f8 , 0x16061606 , 0x16061616 ,
851
- 0x1616e6e6 , 0x1616e6f8 , 0x1616e606 , 0x1616e616 , 0x1616f8e6 , 0x1616f8f8 , 0x1616f806 , 0x1616f816 ,
852
- 0x161606e6 , 0x161606f8 , 0x16160606 , 0x16160616 , 0x161616e6 , 0x161616f8 , 0x16161606 , 0x16161616 ,
853
- };
854
-
855
- __device__ __forceinline__ int int_from_table_4 (const uint8_t * a8, const int * values) {
856
- return values[a8[0 ] | (a8[1 ] << 2 ) | (a8[2 ] << 4 ) | (a8[3 ] << 6 )];
857
- }
858
-
859
789
#define VDR_IQ2_K_Q8_1_MMVQ 4
860
790
#define VDR_IQ2_K_Q8_1_MMQ 4
861
791
@@ -881,7 +811,6 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
881
811
uint32_t val1 = q2[0 ], val2 = q2[1 ];
882
812
883
813
uint32_t aux32[2 ];
884
- const uint8_t * a8 = (const uint8_t *)&aux32;
885
814
int v1, v2;
886
815
887
816
// Block of 16: (32*(4*(i4/4)+k)+8*(i4%4))/16 = 8*(i4/4) + 2*k + (i4%4)/2
@@ -892,23 +821,23 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
892
821
const int8_t * s8 = (const int8_t *)&s32;
893
822
894
823
aux32[0 ] = ((val1 >> 0 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 0 ) & 0x03030303 ); values = all_values + ((extra & 0x01 ) << 8 );
895
- v1 = int_from_table_4 (a8 + 0 , values);
896
- v2 = int_from_table_4 (a8 + 4 , values);
824
+ v1 = int_from_table_4 (aux32[ 0 ] , values);
825
+ v2 = int_from_table_4 (aux32[ 1 ] , values);
897
826
int sumi1 = ggml_cuda_dp4a (v2, q8_1[1 ], ggml_cuda_dp4a (v1, q8_1[0 ], 0 )) * s8[0 ];
898
827
899
828
aux32[0 ] = ((val1 >> 2 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 2 ) & 0x03030303 ); values = all_values + ((extra & 0x04 ) << 6 );
900
- v1 = int_from_table_4 (a8 + 0 , values);
901
- v2 = int_from_table_4 (a8 + 4 , values);
829
+ v1 = int_from_table_4 (aux32[ 0 ] , values);
830
+ v2 = int_from_table_4 (aux32[ 1 ] , values);
902
831
int sumi2 = ggml_cuda_dp4a (v2, q8_2[1 ], ggml_cuda_dp4a (v1, q8_2[0 ], 0 )) * s8[1 ];
903
832
904
833
aux32[0 ] = ((val1 >> 4 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 4 ) & 0x03030303 ); values = all_values + ((extra & 0x10 ) << 4 );
905
- v1 = int_from_table_4 (a8 + 0 , values);
906
- v2 = int_from_table_4 (a8 + 4 , values);
834
+ v1 = int_from_table_4 (aux32[ 0 ] , values);
835
+ v2 = int_from_table_4 (aux32[ 1 ] , values);
907
836
int sumi3 = ggml_cuda_dp4a (v2, q8_3[1 ], ggml_cuda_dp4a (v1, q8_3[0 ], 0 )) * s8[2 ];
908
837
909
838
aux32[0 ] = ((val1 >> 6 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 6 ) & 0x03030303 ); values = all_values + ((extra & 0x40 ) << 2 );
910
- v1 = int_from_table_4 (a8 + 0 , values);
911
- v2 = int_from_table_4 (a8 + 4 , values);
839
+ v1 = int_from_table_4 (aux32[ 0 ] , values);
840
+ v2 = int_from_table_4 (aux32[ 1 ] , values);
912
841
int sumi4 = ggml_cuda_dp4a (v2, q8_4[1 ], ggml_cuda_dp4a (v1, q8_4[0 ], 0 )) * s8[3 ];
913
842
914
843
*result += __half2float (bq2->d ) * (__low2float (bq8_1[4 *(i4/4 )+0 ].ds ) * sumi1
@@ -941,7 +870,6 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
941
870
uint32_t val1 = q2[0 ] | (q2[1 ] << 16 ), val2 = q2[2 ] | (q2[3 ] << 16 );
942
871
943
872
uint32_t aux32[2 ];
944
- const uint8_t * a8 = (const uint8_t *)&aux32;
945
873
int v1, v2;
946
874
947
875
int32_t scales32;
@@ -954,23 +882,23 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
954
882
s8[3 ] += ((extra >> 7 ) & 0x10 );
955
883
956
884
aux32[0 ] = ((val1 >> 0 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 0 ) & 0x03030303 ); values = all_values + ((extra & 0x01 ) << 8 );
957
- v1 = int_from_table_4 (a8 + 0 , values);
958
- v2 = int_from_table_4 (a8 + 4 , values);
885
+ v1 = int_from_table_4 (aux32[ 0 ] , values);
886
+ v2 = int_from_table_4 (aux32[ 1 ] , values);
959
887
int sumi1 = ggml_cuda_dp4a (v2, q8_1[1 ], ggml_cuda_dp4a (v1, q8_1[0 ], 0 )) * s8[0 ];
960
888
961
889
aux32[0 ] = ((val1 >> 2 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 2 ) & 0x03030303 ); values = all_values + ((extra & 0x02 ) << 7 );
962
- v1 = int_from_table_4 (a8 + 0 , values);
963
- v2 = int_from_table_4 (a8 + 4 , values);
890
+ v1 = int_from_table_4 (aux32[ 0 ] , values);
891
+ v2 = int_from_table_4 (aux32[ 1 ] , values);
964
892
int sumi2 = ggml_cuda_dp4a (v2, q8_2[1 ], ggml_cuda_dp4a (v1, q8_2[0 ], 0 )) * s8[2 ];
965
893
966
894
aux32[0 ] = ((val1 >> 4 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 4 ) & 0x03030303 ); values = all_values + ((extra & 0x04 ) << 6 );
967
- v1 = int_from_table_4 (a8 + 0 , values);
968
- v2 = int_from_table_4 (a8 + 4 , values);
895
+ v1 = int_from_table_4 (aux32[ 0 ] , values);
896
+ v2 = int_from_table_4 (aux32[ 1 ] , values);
969
897
int sumi3 = ggml_cuda_dp4a (v2, q8_3[1 ], ggml_cuda_dp4a (v1, q8_3[0 ], 0 )) * s8[1 ];
970
898
971
899
aux32[0 ] = ((val1 >> 6 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 6 ) & 0x03030303 ); values = all_values + ((extra & 0x08 ) << 5 );
972
- v1 = int_from_table_4 (a8 + 0 , values);
973
- v2 = int_from_table_4 (a8 + 4 , values);
900
+ v1 = int_from_table_4 (aux32[ 0 ] , values);
901
+ v2 = int_from_table_4 (aux32[ 1 ] , values);
974
902
int sumi4 = ggml_cuda_dp4a (v2, q8_4[1 ], ggml_cuda_dp4a (v1, q8_4[0 ], 0 )) * s8[3 ];
975
903
976
904
*result += scale * (__low2float (bq8_1[4 *(i4/4 )+0 ].ds ) * sumi1
@@ -1000,20 +928,19 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
1000
928
int2 val1;
1001
929
const int * q2 = (const int *)bq2->qs + 8 *ib32 + 4 *is;
1002
930
int aux32[2 ];
1003
- const uint8_t * aux8 = (const uint8_t *)aux32;
1004
931
#pragma unroll
1005
932
for (int i = 0 ; i < 4 ; ++i) {
1006
933
auto values1 = all_values + (((bq2->extra [i+4 *is] >> ib32) & 1 ) << 8 );
1007
934
int sumi1 = 0 ;
1008
935
aux32[0 ] = ((q2[i] >> 0 ) & 0x03030303 );
1009
936
aux32[1 ] = ((q2[i] >> 2 ) & 0x03030303 );
1010
- val1.x = int_from_table_4 (aux8+ 0 , values1);
1011
- val1.y = int_from_table_4 (aux8+ 4 , values1);
937
+ val1.x = int_from_table_4 (aux32[ 0 ] , values1);
938
+ val1.y = int_from_table_4 (aux32[ 1 ] , values1);
1012
939
sumi1 = ggml_cuda_dp4a (val1.x , q8[0 ], ggml_cuda_dp4a (val1.y , q8[1 ], sumi1));
1013
940
aux32[0 ] = ((q2[i] >> 4 ) & 0x03030303 );
1014
941
aux32[1 ] = ((q2[i] >> 6 ) & 0x03030303 );
1015
- val1.x = int_from_table_4 (aux8+ 0 , values1);
1016
- val1.y = int_from_table_4 (aux8+ 4 , values1);
942
+ val1.x = int_from_table_4 (aux32[ 0 ] , values1);
943
+ val1.y = int_from_table_4 (aux32[ 1 ] , values1);
1017
944
sumi1 = ggml_cuda_dp4a (val1.x , q8[2 ], ggml_cuda_dp4a (val1.y , q8[3 ], sumi1));
1018
945
const float d = __half2float (bq2->d [i]) * d8;
1019
946
result[i] += d * sumi1 * s8[i];
@@ -1114,7 +1041,6 @@ __device__ __forceinline__ void vec_dot_iq3_ks_q8_1(
1114
1041
const int ib128 = iqs/4 ; // 0 or 1. 0 works on quants 0...127, 1 on quants 128...255
1115
1042
// Each thread processes 8 quants in each of the 4 32-blocks
1116
1043
const int il8 = iqs%4 ; // 0...3. 0 works on quants 0...7, 1 on quants 8...15, 2 on 16...23, 3 on 24...31
1117
- const int shift = 4 *(il8/2 );
1118
1044
1119
1045
const uint16_t * ql = (const uint16_t *)bq3->qs + 16 *ib128 + 4 *il8;
1120
1046
const uint16_t * qh = (const uint16_t *)bq3->qh + 4 *il8;
0 commit comments