Skip to content

Commit 79cde2c

Browse files
committed
ggml-cuda : add TQ2_0 kernels, for ternary inference on GPU ggml-org#11183
Credit: @compilade
1 parent 7d91690 commit 79cde2c

File tree

10 files changed

+195
-0
lines changed

10 files changed

+195
-0
lines changed

ggml/src/ggml-common.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ typedef sycl::half2 ggml_half2;
129129
#define QI6_K (QK_K / (4*QR6_K))
130130
#define QR6_K 2
131131

132+
#define QI2_0 (QK_K / (4*QR2_0))
133+
#define QR2_0 4
134+
132135
#define QI2_XXS (QK_K / (4*QR2_XXS))
133136
#define QR2_XXS 4
134137

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
453453
static constexpr int qi = QI6_K;
454454
};
455455

456+
// Type traits for the ternary TQ2_0 type: QK_K values per block,
// QR2_0 == 4 two-bit values packed per byte, and QI2_0 == QK_K/(4*QR2_0)
// 32-bit ints of packed quants per block (see ggml-common.h).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_TQ2_0> {
    static constexpr int qk = QK_K;   // values per quant block
    static constexpr int qr = QR2_0;  // quants per byte (quantization ratio)
    static constexpr int qi = QI2_0;  // ints of packed quant data per block
};
462+
456463
template<>
457464
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
458465
static constexpr int qk = QK_K;

ggml/src/ggml-cuda/convert.cu

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,26 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
311311
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
312312
}
313313

314+
// Dequantize one TQ2_0 block per CUDA block.
// Launch config: grid = number of quant blocks, 64 threads per block;
// threads 0..31 produce the first 128 output values, threads 32..63 the rest.
template<typename dst_t>
static __global__ void dequantize_block_tq2_0(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const block_tq2_0 * x = (const block_tq2_0 *) vx;
    const int64_t i = blockIdx.x;

    const int64_t tid  = threadIdx.x; // 0..63
    const int64_t part = tid / 32;    // 0 or 1: which half of the block
    const int64_t l    = tid % 32;    // 0..31: byte index within the half

    const float   d = __half2float(x[i].d);
    const uint8_t q = x[i].qs[32*part + l];
    dst_t * y = yy + i*QK_K + 128*part;

    // Each byte packs four 2-bit values; a raw value v in 0..3
    // dequantizes to d*(v - 1), i.e. one of {-d, 0, d, 2*d}.
#pragma unroll
    for (int j = 0; j < 4; ++j) {
        y[l + 32*j] = d * ((q >> (2*j)) & 3) - d;
    }
}
333+
314334
template<typename dst_t>
315335
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
316336

@@ -1008,6 +1028,13 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_
10081028
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
10091029
}
10101030

1031+
template<typename dst_t>
1032+
static void dequantize_row_tq2_0_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
1033+
const int64_t k = nrows * n_per_row;
1034+
const int nb = k / QK_K;
1035+
dequantize_block_tq2_0<<<nb, 64, 0, stream>>>(vx, y);
1036+
}
1037+
10111038
template<typename dst_t>
10121039
static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
10131040
const int64_t k = nrows * n_per_row;
@@ -1268,6 +1295,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
12681295
return dequantize_row_q5_K_cuda;
12691296
case GGML_TYPE_Q6_K:
12701297
return dequantize_row_q6_K_cuda;
1298+
case GGML_TYPE_TQ2_0:
1299+
return dequantize_row_tq2_0_cuda;
12711300
case GGML_TYPE_IQ2_XXS:
12721301
return dequantize_row_iq2_xxs_cuda;
12731302
case GGML_TYPE_IQ2_KT:
@@ -1345,6 +1374,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
13451374
return dequantize_row_q5_K_cuda;
13461375
case GGML_TYPE_Q6_K:
13471376
return dequantize_row_q6_K_cuda;
1377+
case GGML_TYPE_TQ2_0:
1378+
return dequantize_row_tq2_0_cuda;
13481379
case GGML_TYPE_IQ2_XXS:
13491380
return dequantize_row_iq2_xxs_cuda;
13501381
case GGML_TYPE_IQ2_KT:

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3112,6 +3112,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31123112
case GGML_TYPE_Q5_K:
31133113
case GGML_TYPE_Q6_K:
31143114
case GGML_TYPE_Q8_K:
3115+
case GGML_TYPE_TQ2_0:
31153116
case GGML_TYPE_IQ1_M:
31163117
case GGML_TYPE_IQ1_S:
31173118
case GGML_TYPE_IQ2_S:

ggml/src/ggml-cuda/mmq.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ void ggml_cuda_op_mul_mat_q(
6464
case GGML_TYPE_Q6_K:
6565
mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
6666
break;
67+
case GGML_TYPE_TQ2_0:
68+
mul_mat_q_case<GGML_TYPE_TQ2_0>(ctx, args, stream);
69+
break;
6770
case GGML_TYPE_IQ2_XXS:
6871
mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
6972
break;
@@ -119,6 +122,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
119122
case GGML_TYPE_Q4_K:
120123
case GGML_TYPE_Q5_K:
121124
case GGML_TYPE_Q6_K:
125+
case GGML_TYPE_TQ2_0:
122126
case GGML_TYPE_IQ2_XXS:
123127
case GGML_TYPE_IQ2_XS:
124128
case GGML_TYPE_IQ2_S:

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
6666
case GGML_TYPE_Q5_K:
6767
return MMQ_Q8_1_DS_LAYOUT_DS4;
6868
case GGML_TYPE_Q6_K:
69+
case GGML_TYPE_TQ2_0:
6970
case GGML_TYPE_IQ2_XXS:
7071
case GGML_TYPE_IQ2_XS:
7172
case GGML_TYPE_IQ2_S:
@@ -165,6 +166,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
165166
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
166167
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
167168
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
169+
type == GGML_TYPE_TQ2_0 ? MMQ_DP4A_TXS_Q8_0 :
168170
type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
169171
type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
170172
type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
@@ -200,6 +202,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
200202
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
201203
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
202204
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
205+
type == GGML_TYPE_TQ2_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
203206
type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
204207
type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
205208
type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
@@ -1876,6 +1879,68 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
18761879
#endif // INT8_MMA_AVAILABLE
18771880
}
18781881

1882+
// Load a tile of TQ2_0 data into shared memory for the MMQ kernels.
// The 2-bit quants are expanded to signed 8-bit values here (offset of -1
// applied via __vsub4), so the tile uses the Q8_0 layout and the generic
// Q8_0 vec-dot kernels can consume it. x_qs receives the expanded quants,
// x_df the per-block scales. The two layouts (MMA vs dp4a) differ only in
// their strides/padding.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_tile + 2*WARP_SIZE);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y);
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    const int kqsx = threadIdx.x % QI2_0; // which int of packed quants within the block

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_0) {
        int i = i0 + threadIdx.y*(WARP_SIZE/QI2_0) + threadIdx.x/QI2_0;

        if (need_check) {
            // clamp instead of branching out so the warp stays converged
            i = min(i, i_max);
        }

        const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;
        const int qs0 = get_int_b2(bxi->qs, kqsx);

#pragma unroll
        for (int l = 0; l < QR2_0; ++l) {
            // k ranges covered per (kqsx, l), matching the TQ2_0 packing order:
            // 0..7, 32..39
            // 8..15, 40..47
            // 16..23, 48..55
            // 24..31, 56..63
            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
            // extract 2-bit plane l (one value per byte), shift 0..3 to -1..2
            const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);

#ifdef INT8_MMA_AVAILABLE
            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = q;
#else
            x_qs[i*(2*WARP_SIZE + 1) + k] = q;
#endif // INT8_MMA_AVAILABLE
        }
    }

    // Second pass: load the single per-block scale (TQ2_0 has no sub-scales).
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_0/2)) {
        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_0) + threadIdx.x/(QI2_0/2);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;

        const int k = threadIdx.x % (QI2_0/2);

#ifdef INT8_MMA_AVAILABLE
        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d;
#else
        // dp4a layout: the `i/4` term is presumably padding to avoid shared
        // memory bank conflicts — NOTE(review): verify against the matching
        // vec_dot_q8_0_q8_1_dp4a indexing.
        x_df[i*(WARP_SIZE/4) + i/4 + k] = bxi->d;
#endif // INT8_MMA_AVAILABLE
    }
}
1943+
18791944
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
18801945
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
18811946

@@ -2503,6 +2568,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
25032568
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
25042569
};
25052570

2571+
// MMQ dispatch for TQ2_0: load_tiles_tq2_0 already expands the 2-bit quants
// to signed 8-bit values in a Q8_0-layout tile, so the generic Q8_0 vec-dot
// kernels are reused for both the MMA and dp4a paths.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_TQ2_0> {
    static constexpr int              vdr          = VDR_TQ2_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_tq2_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
2578+
25062579
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
25072580
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
25082581
static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
@@ -2993,6 +3066,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
29933066
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
29943067
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
29953068
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
3069+
extern DECL_MMQ_CASE(GGML_TYPE_TQ2_0);
29963070
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
29973071
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
29983072
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
1616
type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
1717
type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
1818
type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
19+
type == GGML_TYPE_TQ2_0 ? vec_dot_tq2_0_q8_1 :
1920
type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
2021
type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
2122
type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
@@ -40,6 +41,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
4041
type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
4142
type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
4243
type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
44+
type == GGML_TYPE_TQ2_0 ? VDR_TQ2_0_Q8_1_MMVQ :
4345
type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
4446
type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ :
4547
type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ :
@@ -281,6 +283,13 @@ static void mul_mat_vec_q6_K_q8_1_cuda(
281283
mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
282284
}
283285

286+
// Host-side launcher: dispatches the templated quantized mat-vec kernel
// specialized for TQ2_0 weights against q8_1-quantized activations.
static void mul_mat_vec_tq2_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_TQ2_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
292+
284293
static void mul_mat_vec_iq2_xxs_q8_1_cuda(
285294
const void * vx, const void * vy, float * dst,
286295
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -398,6 +407,9 @@ void ggml_cuda_op_mul_mat_vec_q(
398407
case GGML_TYPE_Q6_K:
399408
mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
400409
break;
410+
case GGML_TYPE_TQ2_0:
411+
mul_mat_vec_tq2_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
412+
break;
401413
case GGML_TYPE_IQ2_XXS:
402414
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
403415
break;

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
TYPES_MMQ = [
2424
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
2525
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
26+
"GGML_TYPE_TQ2_0",
2627
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
2728
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_Q6_0"
2829
]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

// Emits the mul_mat_q kernel instantiation for GGML_TYPE_TQ2_0
// (declared `extern` in mmq.cuh).
DECL_MMQ_CASE(GGML_TYPE_TQ2_0);

ggml/src/ggml-cuda/vecdotq.cuh

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,32 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
555555
return d6 * sumf_d;
556556
}
557557

558+
#define VDR_TQ2_0_Q8_1_MMVQ 2
#define VDR_TQ2_0_Q8_1_MMQ  8

// Dot product of packed TQ2_0 quants v against q8_1 values u.
// Shared by the mmvq and mmq paths: a TQ2_0 block carries a single scale d2
// and no sub-scales, so the same accumulation works for both.
// d8 holds one scale per participating q8_1 block.
template <int vdr> static __device__ __forceinline__ float vec_dot_tq2_0_q8_1_impl(
    const int * __restrict__ v, const int * __restrict__ u, const float & d2, const float * __restrict__ d8) {

    float sumf = 0.0f;

#pragma unroll
    for (int l = 0; l < QR2_0; ++l) {
        // Integer dot product of the l-th 2-bit plane of v against u.
        int sumi = 0;

#pragma unroll
        for (int k = 0; k < vdr; ++k) {
            // One 2-bit value per byte; shift from unsigned 0..3 to signed -1..2.
            const int vl = (v[k] >> (2*l)) & 0x03030303;
            sumi = ggml_cuda_dp4a(__vsub4(vl, 0x01010101), u[vdr*l + k], sumi); // SIMD dot product
        }

        sumf += d8[l] * sumi;
    }

    return d2 * sumf;
}
583+
558584
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
559585
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
560586

@@ -837,6 +863,37 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
837863
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
838864
}
839865

866+
// mmvq entry point: dot product of one TQ2_0 block against the matching
// q8_1 blocks. kbx selects the TQ2_0 block, iqs the int offset within it.
static __device__ __forceinline__ float vec_dot_tq2_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_tq2_0 * btq2_0 = (const block_tq2_0 *) vbq + kbx;

    // iqs 0..7  all need bq8_offset 0, 1, 2, 3
    // iqs 8..15 all need bq8_offset 4, 5, 6, 7
    const int bq8_offset = QR2_0 * (iqs / 8);

    int   v[VDR_TQ2_0_Q8_1_MMVQ];        // packed 2-bit quants
    int   u[QR2_0*VDR_TQ2_0_Q8_1_MMVQ];  // q8_1 values, one group per 2-bit plane
    float d8[QR2_0];                     // per-q8_1-block scales

#pragma unroll
    for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
        v[i] = get_int_b2(btq2_0->qs, iqs + i);
    }

#pragma unroll
    for (int i0 = 0; i0 < QR2_0; ++i0) {
        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i0;

        // Fix: unroll this fixed-trip loop like every sibling vec_dot loader
        // in this file, so u[] indexing stays compile-time constant and the
        // local arrays remain in registers instead of spilling.
#pragma unroll
        for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
            u[VDR_TQ2_0_Q8_1_MMVQ*i0 + i] = get_int_b4(bq8i->qs, (iqs % QI8_1) + i);
        }
        d8[i0] = __low2float(bq8i->ds);
    }

    return vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMVQ>(v, u, btq2_0->d, d8);
}
896+
840897
#define VDR_IQ2_XXS_Q8_1_MMVQ 2
841898
#define VDR_IQ2_XXS_Q8_1_MMQ 2
842899

0 commit comments

Comments
 (0)