Skip to content

Commit 283753c

Browse files
KawrakowIwan Kawrakow
andauthored
CUDA: Faster prompt processing for several quantization types (ikawrakow#595)
* cuda: slightly faster MMQ for iq3_k, iq3_k_r4 * cuda: slightly faster MMQ for iq4_k, iq4_k_r4 * cuda: slightly faster MMQ for iq4_ks_r4 * cuda: slightly faster MMQ for iq4_ks * cuda: slightly faster MMQ for iq4_xs --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 5446ccc commit 283753c

File tree

5 files changed

+139
-107
lines changed

5 files changed

+139
-107
lines changed

ggml/src/ggml-cuda/iqk_cuda_common.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,57 @@ __device__ __forceinline__ int int_from_table_4(const uint32_t idx, const int *
7373
return values[ggml_cuda_dp4a(idx, 0x40100401, 0)];
7474
}
7575

76+
static const __device__ uint16_t iq3k_table[128] = {
77+
0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f,
78+
0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f,
79+
0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f,
80+
0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f,
81+
0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33,
82+
0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33,
83+
0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133,
84+
0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333,
85+
};
86+
87+
__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) {
88+
return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16);
89+
}
90+
91+
static const __device__ uint16_t iq4k_table[512] = {
92+
0x8181, 0x8198, 0x81ad, 0x81bf, 0x81cf, 0x81dd, 0x81ea, 0x81f6, 0x8101, 0x810d, 0x8119, 0x8126, 0x8135, 0x8145, 0x8159, 0x8171,
93+
0x9881, 0x9898, 0x98ad, 0x98bf, 0x98cf, 0x98dd, 0x98ea, 0x98f6, 0x9801, 0x980d, 0x9819, 0x9826, 0x9835, 0x9845, 0x9859, 0x9871,
94+
0xad81, 0xad98, 0xadad, 0xadbf, 0xadcf, 0xaddd, 0xadea, 0xadf6, 0xad01, 0xad0d, 0xad19, 0xad26, 0xad35, 0xad45, 0xad59, 0xad71,
95+
0xbf81, 0xbf98, 0xbfad, 0xbfbf, 0xbfcf, 0xbfdd, 0xbfea, 0xbff6, 0xbf01, 0xbf0d, 0xbf19, 0xbf26, 0xbf35, 0xbf45, 0xbf59, 0xbf71,
96+
0xcf81, 0xcf98, 0xcfad, 0xcfbf, 0xcfcf, 0xcfdd, 0xcfea, 0xcff6, 0xcf01, 0xcf0d, 0xcf19, 0xcf26, 0xcf35, 0xcf45, 0xcf59, 0xcf71,
97+
0xdd81, 0xdd98, 0xddad, 0xddbf, 0xddcf, 0xdddd, 0xddea, 0xddf6, 0xdd01, 0xdd0d, 0xdd19, 0xdd26, 0xdd35, 0xdd45, 0xdd59, 0xdd71,
98+
0xea81, 0xea98, 0xeaad, 0xeabf, 0xeacf, 0xeadd, 0xeaea, 0xeaf6, 0xea01, 0xea0d, 0xea19, 0xea26, 0xea35, 0xea45, 0xea59, 0xea71,
99+
0xf681, 0xf698, 0xf6ad, 0xf6bf, 0xf6cf, 0xf6dd, 0xf6ea, 0xf6f6, 0xf601, 0xf60d, 0xf619, 0xf626, 0xf635, 0xf645, 0xf659, 0xf671,
100+
0x0181, 0x0198, 0x01ad, 0x01bf, 0x01cf, 0x01dd, 0x01ea, 0x01f6, 0x0101, 0x010d, 0x0119, 0x0126, 0x0135, 0x0145, 0x0159, 0x0171,
101+
0x0d81, 0x0d98, 0x0dad, 0x0dbf, 0x0dcf, 0x0ddd, 0x0dea, 0x0df6, 0x0d01, 0x0d0d, 0x0d19, 0x0d26, 0x0d35, 0x0d45, 0x0d59, 0x0d71,
102+
0x1981, 0x1998, 0x19ad, 0x19bf, 0x19cf, 0x19dd, 0x19ea, 0x19f6, 0x1901, 0x190d, 0x1919, 0x1926, 0x1935, 0x1945, 0x1959, 0x1971,
103+
0x2681, 0x2698, 0x26ad, 0x26bf, 0x26cf, 0x26dd, 0x26ea, 0x26f6, 0x2601, 0x260d, 0x2619, 0x2626, 0x2635, 0x2645, 0x2659, 0x2671,
104+
0x3581, 0x3598, 0x35ad, 0x35bf, 0x35cf, 0x35dd, 0x35ea, 0x35f6, 0x3501, 0x350d, 0x3519, 0x3526, 0x3535, 0x3545, 0x3559, 0x3571,
105+
0x4581, 0x4598, 0x45ad, 0x45bf, 0x45cf, 0x45dd, 0x45ea, 0x45f6, 0x4501, 0x450d, 0x4519, 0x4526, 0x4535, 0x4545, 0x4559, 0x4571,
106+
0x5981, 0x5998, 0x59ad, 0x59bf, 0x59cf, 0x59dd, 0x59ea, 0x59f6, 0x5901, 0x590d, 0x5919, 0x5926, 0x5935, 0x5945, 0x5959, 0x5971,
107+
0x7181, 0x7198, 0x71ad, 0x71bf, 0x71cf, 0x71dd, 0x71ea, 0x71f6, 0x7101, 0x710d, 0x7119, 0x7126, 0x7135, 0x7145, 0x7159, 0x7171,
108+
0x8585, 0x859c, 0x85b1, 0x85c3, 0x85d3, 0x85e1, 0x85ee, 0x85fa, 0x8505, 0x8511, 0x851d, 0x852a, 0x8539, 0x8549, 0x855d, 0x8575,
109+
0x9c85, 0x9c9c, 0x9cb1, 0x9cc3, 0x9cd3, 0x9ce1, 0x9cee, 0x9cfa, 0x9c05, 0x9c11, 0x9c1d, 0x9c2a, 0x9c39, 0x9c49, 0x9c5d, 0x9c75,
110+
0xb185, 0xb19c, 0xb1b1, 0xb1c3, 0xb1d3, 0xb1e1, 0xb1ee, 0xb1fa, 0xb105, 0xb111, 0xb11d, 0xb12a, 0xb139, 0xb149, 0xb15d, 0xb175,
111+
0xc385, 0xc39c, 0xc3b1, 0xc3c3, 0xc3d3, 0xc3e1, 0xc3ee, 0xc3fa, 0xc305, 0xc311, 0xc31d, 0xc32a, 0xc339, 0xc349, 0xc35d, 0xc375,
112+
0xd385, 0xd39c, 0xd3b1, 0xd3c3, 0xd3d3, 0xd3e1, 0xd3ee, 0xd3fa, 0xd305, 0xd311, 0xd31d, 0xd32a, 0xd339, 0xd349, 0xd35d, 0xd375,
113+
0xe185, 0xe19c, 0xe1b1, 0xe1c3, 0xe1d3, 0xe1e1, 0xe1ee, 0xe1fa, 0xe105, 0xe111, 0xe11d, 0xe12a, 0xe139, 0xe149, 0xe15d, 0xe175,
114+
0xee85, 0xee9c, 0xeeb1, 0xeec3, 0xeed3, 0xeee1, 0xeeee, 0xeefa, 0xee05, 0xee11, 0xee1d, 0xee2a, 0xee39, 0xee49, 0xee5d, 0xee75,
115+
0xfa85, 0xfa9c, 0xfab1, 0xfac3, 0xfad3, 0xfae1, 0xfaee, 0xfafa, 0xfa05, 0xfa11, 0xfa1d, 0xfa2a, 0xfa39, 0xfa49, 0xfa5d, 0xfa75,
116+
0x0585, 0x059c, 0x05b1, 0x05c3, 0x05d3, 0x05e1, 0x05ee, 0x05fa, 0x0505, 0x0511, 0x051d, 0x052a, 0x0539, 0x0549, 0x055d, 0x0575,
117+
0x1185, 0x119c, 0x11b1, 0x11c3, 0x11d3, 0x11e1, 0x11ee, 0x11fa, 0x1105, 0x1111, 0x111d, 0x112a, 0x1139, 0x1149, 0x115d, 0x1175,
118+
0x1d85, 0x1d9c, 0x1db1, 0x1dc3, 0x1dd3, 0x1de1, 0x1dee, 0x1dfa, 0x1d05, 0x1d11, 0x1d1d, 0x1d2a, 0x1d39, 0x1d49, 0x1d5d, 0x1d75,
119+
0x2a85, 0x2a9c, 0x2ab1, 0x2ac3, 0x2ad3, 0x2ae1, 0x2aee, 0x2afa, 0x2a05, 0x2a11, 0x2a1d, 0x2a2a, 0x2a39, 0x2a49, 0x2a5d, 0x2a75,
120+
0x3985, 0x399c, 0x39b1, 0x39c3, 0x39d3, 0x39e1, 0x39ee, 0x39fa, 0x3905, 0x3911, 0x391d, 0x392a, 0x3939, 0x3949, 0x395d, 0x3975,
121+
0x4985, 0x499c, 0x49b1, 0x49c3, 0x49d3, 0x49e1, 0x49ee, 0x49fa, 0x4905, 0x4911, 0x491d, 0x492a, 0x4939, 0x4949, 0x495d, 0x4975,
122+
0x5d85, 0x5d9c, 0x5db1, 0x5dc3, 0x5dd3, 0x5de1, 0x5dee, 0x5dfa, 0x5d05, 0x5d11, 0x5d1d, 0x5d2a, 0x5d39, 0x5d49, 0x5d5d, 0x5d75,
123+
0x7585, 0x759c, 0x75b1, 0x75c3, 0x75d3, 0x75e1, 0x75ee, 0x75fa, 0x7505, 0x7511, 0x751d, 0x752a, 0x7539, 0x7549, 0x755d, 0x7575,
124+
};
125+
126+
__device__ __forceinline__ int int_from_table_x(const uint8_t * a8, const uint16_t * values) {
127+
return values[a8[0] | (a8[1] << 4)] | (values[a8[2] | (a8[3] << 4)] << 16);
128+
}
129+

ggml/src/ggml-cuda/iqk_mmvq.cu

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -950,21 +950,6 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
950950
#define VDR_IQ3_K_Q8_1_MMVQ 4
951951
#define VDR_IQ3_K_Q8_1_MMQ 4
952952

953-
static const __device__ uint16_t iq3k_table[128] = {
954-
0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f,
955-
0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f,
956-
0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f,
957-
0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f,
958-
0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33,
959-
0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33,
960-
0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133,
961-
0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333,
962-
};
963-
964-
__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) {
965-
return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16);
966-
}
967-
968953
__device__ __forceinline__ void vec_dot_iq3_k_q8_1(
969954
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs, float * result) {
970955
const block_iq3_k * bq3 = (const block_iq3_k *) vbq + kbx;

0 commit comments

Comments
 (0)