
Commit f989fb0

ikawrakow and Iwan Kawrakow authored
Adding IQ1_KT - 1.75 bpw SOTA quants (#616)
* iq1_kt: basics

* iq1_kt: CUDA dequantize

  Testing with LLaMA-3.1-8B-Instruct, we get almost the same PPL as iq2_xxs,
  so about 0.2 bpw fewer bits for the same quality.

* iq1_kt: CUDA MMQ

* iq1_kt: CUDA MMVQ

* iq1_kt: AVX2 GEMM/GEMV

* iq1_kt: convert/repack to q8_0_r8 (AVX2)

* iq1_kt: slightly faster GEMV (18.6 t/s -> 19.4 t/s)

* iq1_kt: NEON GEMM/GEMV (pathetic as usual)

* iq1_kt: slightly faster NEON - still pathetic

* iq1_kt: tiny bit better GEMV on NEON

* iq1_kt: convert/repack to q8_0_r8 (NEON)

* iq1_kt: very slightly faster convert/repack to q8_0_r8 on NEON

* Adding forgotten file

* iq1_kt: add to constants.py

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 07673c6 commit f989fb0

21 files changed: +930 additions, -6 deletions

examples/quantize/quantize.cpp

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ2_K",    LLAMA_FTYPE_MOSTLY_IQ2_K,    " 2.375 bpw non-linear quantization",},
     { "IQ2_K_R4", LLAMA_FTYPE_MOSTLY_IQ2_K_R4, "IQ2_K repacked",},
     { "IQ2_KS",   LLAMA_FTYPE_MOSTLY_IQ2_KS,   " 2.1875 bpw non-linear quantization",},
+    { "IQ1_KT",   LLAMA_FTYPE_MOSTLY_IQ1_KT,   " 1.75 bpw trellis quantization", },
     { "IQ2_KT",   LLAMA_FTYPE_MOSTLY_IQ2_KT,   " 2.125 bpw trellis quantization", },
     { "IQ2_KL",   LLAMA_FTYPE_MOSTLY_IQ2_KL,   " 2.69 bpw non-linear quantization", },
     { "IQ3_KS",   LLAMA_FTYPE_MOSTLY_IQ3_KS,   " 3.19 bpw non-linear quantization", },

ggml/include/ggml.h

Lines changed: 2 additions & 0 deletions
@@ -436,6 +436,7 @@ extern "C" {
         GGML_TYPE_IQ4_KT = 155,
         GGML_TYPE_IQ3_KS = 156,
         GGML_TYPE_IQ2_KL = 157,
+        GGML_TYPE_IQ1_KT = 158,

         GGML_TYPE_Q4_0_R8 = 202,
         GGML_TYPE_Q5_0_R4 = 206,
@@ -530,6 +531,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_KT = 144, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ3_KS = 145, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ2_KL = 146, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ1_KT = 147, // except 1d tensors
         //
         GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors
         GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors

ggml/src/ggml-common.h

Lines changed: 7 additions & 0 deletions
@@ -629,6 +629,13 @@ typedef struct {
 } block_iq2_ks;
 static_assert(sizeof(block_iq2_ks) == sizeof(uint16_t) + QK_K/64 + QK_K/4, "wrong iq2_ks block size/padding");

+typedef struct {
+    uint8_t sh[QK_K/32]; // 4-bit scales + 13th bits for groups of 8
+    uint8_t ql[QK_K/8];  // low 8 bits for groups of 8
+    uint8_t qh[QK_K/16]; // high 4 bits for groups of 8
+} block_iq1_kt;
+static_assert(sizeof(block_iq1_kt) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_kt block size/padding");
+
 typedef struct {
     uint8_t scales[QK_K/64];
     uint8_t ql[QK_K/4];
ggml/src/ggml-cuda.cu

Lines changed: 1 addition & 0 deletions
@@ -3506,6 +3506,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_TYPE_IQ5_KS:
         case GGML_TYPE_IQ2_K:
         case GGML_TYPE_IQ2_KS:
+        case GGML_TYPE_IQ1_KT:
         case GGML_TYPE_IQ2_KT:
         case GGML_TYPE_IQ3_KT:
         case GGML_TYPE_IQ4_KT:

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 0 deletions
@@ -571,6 +571,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KS> {
     static constexpr int qi = QI4_XS;
 };

+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_KT> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_XS;
+    static constexpr int qi = QI4_XS;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> {
     static constexpr int qk = QK_K;

ggml/src/ggml-cuda/convert.cu

Lines changed: 31 additions & 0 deletions
@@ -358,6 +358,26 @@ float __device__ __forceinline__ trellis_next(uint32_t& val) {
     return (float)(h[0]+h[1]);
 }

+template<typename dst_t>
+static __global__ void dequantize_block_iq1_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+    int64_t ii  = blockIdx.x;
+    int64_t row = (QK_K * ii) / n_per_row;
+    const char * cx = (const char *)vx + row * row_size;
+    float scale = *(const float *)cx;
+    const block_iq1_kt * x = (const block_iq1_kt *)(cx + sizeof(float));
+    const int64_t i = ii - (row*n_per_row)/QK_K;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t ib = tid; // 0...31
+    dst_t * y = yy + ii*QK_K + 8*ib;
+    uint32_t idx = (x[i].ql[ib] | ((x[i].qh[ib%16] << (8 - 4*(ib/16))) & 0xf00) | ((x[i].sh[ib/4] << (8 - (ib%4))) & 0x1000)) + 4096;
+    const float dl = scale * iq4k_values[x[i].sh[ib/4] & 0xf];
+    for (int j = 0; j < 8; ++j) {
+        y[j] = dl * trellis_next_int(idx);
+    }
+}
+
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {

@@ -1505,6 +1525,13 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_
     dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
 }

+template<typename dst_t>
+static void dequantize_row_iq1_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+    const int64_t k = nrows * n_per_row;
+    const int nb = k / QK_K;
+    dequantize_block_iq1_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ1_KT, n_per_row));
+}
+
 template<typename dst_t>
 static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
     const int64_t k = nrows * n_per_row;
@@ -1888,6 +1915,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_q6_K_cuda;
         case GGML_TYPE_IQ2_XXS:
             return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ1_KT:
+            return dequantize_row_iq1_kt_cuda;
         case GGML_TYPE_IQ2_KT:
             return dequantize_row_iq2_kt_cuda;
         case GGML_TYPE_IQ3_KT:
@@ -1987,6 +2016,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_q6_K_cuda;
         case GGML_TYPE_IQ2_XXS:
             return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ1_KT:
+            return dequantize_row_iq1_kt_cuda;
         case GGML_TYPE_IQ2_KT:
             return dequantize_row_iq2_kt_cuda;
         case GGML_TYPE_IQ3_KT:
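The idx expression in the kernel reassembles a 13-bit trellis seed per group of 8 weights: ql supplies bits 0-7, qh packs bits 8-11 two groups per byte, and the high nibble of sh donates bit 12, while the low nibble of sh is the scale index into iq4k_values. A scalar restatement of that unpacking, written by us for readability (the helper name is illustrative, not from the tree):

    #include <stdint.h>

    #define QK_K 256

    typedef struct {
        uint8_t sh[QK_K/32]; // low nibble: scale index; high nibble: bit 12 of 4 seeds
        uint8_t ql[QK_K/8];  // seed bits 0..7, one byte per group of 8 weights
        uint8_t qh[QK_K/16]; // seed bits 8..11, two groups per byte
    } block_iq1_kt;

    // ib = group index within the block, 0..31; equivalent to the kernel's idx expression
    static uint32_t iq1_kt_seed(const block_iq1_kt * blk, int ib) {
        uint32_t lo  = blk->ql[ib];                              // bits 0..7
        uint32_t mid = (blk->qh[ib % 16] >> (4*(ib/16))) & 0xf;  // bits 8..11
        uint32_t hi  = (blk->sh[ib/4] >> (4 + ib%4)) & 1;        // bit 12
        return (lo | (mid << 8) | (hi << 12)) + 4096;            // same +4096 offset as the kernel
    }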

ggml/src/ggml-cuda/iqk_mmvq.cu

Lines changed: 41 additions & 0 deletions
@@ -443,6 +443,39 @@ __device__ __forceinline__ void vec_dot_iq4_kt_q8_1(
     *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
 }

+__device__ __forceinline__ void vec_dot_iq1_kt_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
+
+    constexpr uint32_t ka = 0xCBAC1FED;
+    constexpr uint32_t km = 0x3f3f3f3f;
+
+    float scale = *(const float *)vbq;
+    const block_iq1_kt * bq1 = (const block_iq1_kt *)((const char *)vbq + sizeof(float)) + kbx;
+
+    // iqs is 0...28
+    const int ib32 = iqs/4;
+    const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+    const int ls = iq4k_values[bq1->sh[ib32] & 0xf];
+    const float dl = scale * ls;
+    int sumi = 0;
+    for (int j = 0; j < 4; ++j) {
+        uint32_t val = bq1->ql[4*ib32+j] + 4096 + ((bq1->qh[4*(ib32%4)+j] << (8 - 4*(ib32/4))) & 0xf00) + ((bq1->sh[ib32] << (8 - j)) & 0x1000);
+        int v4 = 0;
+        for (int k = 0; k < 4; ++k) {
+            val *= ka;
+            v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+        }
+        sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi);
+        v4 = 0;
+        for (int k = 0; k < 4; ++k) {
+            val *= ka;
+            v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+        }
+        sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi);
+    }
+    *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
+}
+
 __device__ __forceinline__ void vec_dot_iq2_kt_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {

@@ -1350,6 +1383,14 @@ void mul_mat_vec_iq4_kt_q8_1_cuda(
     iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
 }

+void mul_mat_vec_iq1_kt_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const char * ids_data,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq1_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
 void mul_mat_vec_iq2_kt_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const char * ids_data,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
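The inner loops above are the integer form of the trellis generator: each val *= ka step advances a multiplicative LCG, and ggml_cuda_dp4a(val & km, 0x01010101, -126) sums the low 6 bits of each of the four state bytes and subtracts 126, producing a value in [-126, 126] that fits one int8 lane of v4. A scalar sketch of one step, inferred from that arithmetic (our naming; the in-tree trellis_next_int may differ in detail):

    #include <stdint.h>

    // one trellis step: advance the state and emit one int8 sample
    static int8_t trellis_step_ref(uint32_t * val) {
        const uint32_t ka = 0xCBAC1FED;  // LCG multiplier from the kernel
        const uint32_t km = 0x3f3f3f3f;  // keep the low 6 bits of each byte
        *val *= ka;
        uint32_t v = *val & km;
        // dp4a(v, 0x01010101, -126): byte-wise sum with a -126 bias
        int sum = (int)(v & 0xff) + (int)((v >> 8) & 0xff)
                + (int)((v >> 16) & 0xff) + (int)((v >> 24) & 0xff) - 126;
        return (int8_t)sum;  // at most 4*63 - 126 = 126, so it always fits
    }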

ggml/src/ggml-cuda/iqk_mmvq.cuh

Lines changed: 5 additions & 0 deletions
@@ -111,6 +111,11 @@ void mul_mat_vec_iq1_m_r4_q8_1_cuda(
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
     const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);

+void mul_mat_vec_iq1_kt_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const char * ids_data,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+
 void mul_mat_vec_iq2_kt_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const char * ids_data,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,

ggml/src/ggml-cuda/mmq.cu

Lines changed: 4 additions & 0 deletions
@@ -109,6 +109,9 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_IQ4_KT:
             mul_mat_q_case<GGML_TYPE_IQ4_KT>(ctx, args, stream);
             break;
+        case GGML_TYPE_IQ1_KT:
+            mul_mat_q_case<GGML_TYPE_IQ1_KT>(ctx, args, stream);
+            break;
         case GGML_TYPE_IQ2_KT:
             mul_mat_q_case<GGML_TYPE_IQ2_KT>(ctx, args, stream);
             break;
@@ -211,6 +214,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_IQ5_KS:
         case GGML_TYPE_IQ5_KS_R4:
         case GGML_TYPE_IQ2_KS:
+        case GGML_TYPE_IQ1_KT:
         case GGML_TYPE_IQ2_KT:
         case GGML_TYPE_IQ3_KT:
         case GGML_TYPE_IQ4_KT:

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 5 additions & 1 deletion
@@ -100,6 +100,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
         case GGML_TYPE_IQ5_KS:
         case GGML_TYPE_IQ5_KS_R4:
         case GGML_TYPE_IQ6_K:
+        case GGML_TYPE_IQ1_KT:
         case GGML_TYPE_IQ2_KT:
         case GGML_TYPE_IQ3_KT:
         case GGML_TYPE_IQ4_KT:
@@ -218,6 +219,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
         case GGML_TYPE_IQ5_K  : return MMQ_DP4A_TXS_Q8_0_16;
         case GGML_TYPE_IQ5_K_R4: return MMQ_DP4A_TXS_Q8_0_16;
         case GGML_TYPE_IQ6_K  : return MMQ_DP4A_TXS_Q8_0_16;
+        case GGML_TYPE_IQ1_KT : return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ2_KT : return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ3_KT : return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ4_KT : return MMQ_DP4A_TXS_Q8_0;
@@ -275,6 +277,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_IQ5_K  : return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_IQ5_K_R4: return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_IQ6_K  : return MMQ_MMA_TILE_X_K_Q3_K;
+        case GGML_TYPE_IQ1_KT : return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ2_KT : return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ3_KT : return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0;
@@ -4176,9 +4179,10 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);
-extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ1_KT);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KT);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KT);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT);

 // -------------------------------------------------------------------------------------------------------------------------