diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 7dd023132..3de7bc20a 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -75,6 +75,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",}, { "IQ2_K_R4", LLAMA_FTYPE_MOSTLY_IQ2_K_R4, "IQ2_K repacked",}, { "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",}, + { "IQ1_KT", LLAMA_FTYPE_MOSTLY_IQ1_KT, " 1.75 bpw trellis quantization", }, { "IQ2_KT", LLAMA_FTYPE_MOSTLY_IQ2_KT, " 2.125 bpw trellis quantization", }, { "IQ2_KL", LLAMA_FTYPE_MOSTLY_IQ2_KL, " 2.69 bpw non-linear quantization", }, { "IQ3_KS", LLAMA_FTYPE_MOSTLY_IQ3_KS, " 3.19 bpw non-linear quantization", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 644a82b23..5b90c9a50 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -436,6 +436,7 @@ extern "C" { GGML_TYPE_IQ4_KT = 155, GGML_TYPE_IQ3_KS = 156, GGML_TYPE_IQ2_KL = 157, + GGML_TYPE_IQ1_KT = 158, GGML_TYPE_Q4_0_R8 = 202, GGML_TYPE_Q5_0_R4 = 206, @@ -530,6 +531,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_KT = 144, // except 1d tensors GGML_FTYPE_MOSTLY_IQ3_KS = 145, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_KL = 146, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ1_KT = 147, // except 1d tensors // GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 6dc439b8f..1dc1ff6ec 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -629,6 +629,13 @@ typedef struct { } block_iq2_ks; static_assert(sizeof(block_iq2_ks) == sizeof(uint16_t) + QK_K/64 + QK_K/4, "wrong iq2_ks block size/padding"); +typedef struct { + uint8_t sh[QK_K/32]; // 4-bit scales + 13th bits for groups of 8 + uint8_t ql[QK_K/8]; // low 8 bits for groups of 8 + uint8_t qh[QK_K/16]; // high 4 bits for groups of 8 +} block_iq1_kt; +static_assert(sizeof(block_iq1_kt) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_kt block size/padding"); + typedef struct { uint8_t scales[QK_K/64]; uint8_t ql[QK_K/4]; diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index b33c952b5..7fee71d84 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3506,6 +3506,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 38b52fd05..15485f606 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -571,6 +571,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI4_XS; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index c8e02a83d..8c03ae1bc 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -358,6 +358,26 @@ float __device__ __forceinline__ trellis_next(uint32_t& val) { return (float)(h[0]+h[1]); } +template +static __global__ void dequantize_block_iq1_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + int64_t row 
= (QK_K * ii) / n_per_row; + const char * cx = (const char *)vx + row * row_size; + float scale = *(const float *)cx; + const block_iq1_kt * x = (const block_iq1_kt *)(cx + sizeof(float)); + const int64_t i = ii - (row*n_per_row)/QK_K; + + const int64_t tid = threadIdx.x; + const int64_t ib = tid; // 0...31 + dst_t * y = yy + ii*QK_K + 8*ib; + uint32_t idx = (x[i].ql[ib] | ((x[i].qh[ib%16] << (8 - 4*(ib/16))) & 0xf00) | ((x[i].sh[ib/4] << (8 - (ib%4))) & 0x1000)) + 4096; + const float dl = scale * iq4k_values[x[i].sh[ib/4] & 0xf]; + for (int j = 0; j < 8; ++j) { + y[j] = dl * trellis_next_int(idx); + } +} + template static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { @@ -1505,6 +1525,13 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_ dequantize_block_iq2_xxs<<>>(vx, y); } +template +static void dequantize_row_iq1_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int nb = k / QK_K; + dequantize_block_iq1_kt<<>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ1_KT, n_per_row)); +} + template static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; @@ -1888,6 +1915,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_q6_K_cuda; case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda; + case GGML_TYPE_IQ1_KT: + return dequantize_row_iq1_kt_cuda; case GGML_TYPE_IQ2_KT: return dequantize_row_iq2_kt_cuda; case GGML_TYPE_IQ3_KT: @@ -1987,6 +2016,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_q6_K_cuda; case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda; + case GGML_TYPE_IQ1_KT: + return dequantize_row_iq1_kt_cuda; case GGML_TYPE_IQ2_KT: return dequantize_row_iq2_kt_cuda; case GGML_TYPE_IQ3_KT: diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index a669390dd..c7f5dfb48 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -443,6 +443,39 @@ __device__ __forceinline__ void vec_dot_iq4_kt_q8_1( *result += dl * __low2float(bq8_1[ib32].ds) * sumi; } +__device__ __forceinline__ void vec_dot_iq1_kt_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + + float scale = *(const float *)vbq; + const block_iq1_kt * bq1 = (const block_iq1_kt *)((const char *)vbq + sizeof(float)) + kbx; + + // iqs is 0...28 + const int ib32 = iqs/4; + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const int ls = iq4k_values[bq1->sh[ib32] & 0xf]; + const float dl = scale * ls; + int sumi = 0; + for (int j = 0; j < 4; ++j) { + uint32_t val = bq1->ql[4*ib32+j] + 4096 + ((bq1->qh[4*(ib32%4)+j] << (8 - 4*(ib32/4))) & 0xf00) + ((bq1->sh[ib32] << (8 - j)) & 0x1000); + int v4 = 0; + for (int k = 0; k < 4; ++k) { + val *= ka; + v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi); + v4 = 0; + for (int k = 0; k < 4; ++k) { + val *= ka; + v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi); + } + *result += dl * __low2float(bq8_1[ib32].ds) * sumi; +} + __device__ __forceinline__ void 
vec_dot_iq2_kt_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { @@ -1350,6 +1383,14 @@ void mul_mat_vec_iq4_kt_q8_1_cuda( iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } +void mul_mat_vec_iq1_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + void mul_mat_vec_iq2_kt_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index d14c3541e..5d62d02ee 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -111,6 +111,11 @@ void mul_mat_vec_iq1_m_r4_q8_1_cuda( const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); +void mul_mat_vec_iq1_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + void mul_mat_vec_iq2_kt_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index cde5d0448..d417fdc06 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -109,6 +109,9 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ4_KT: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_IQ1_KT: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_IQ2_KT: mul_mat_q_case(ctx, args, stream); break; @@ -211,6 +214,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 21b50082a..aaf02fab9 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -100,6 +100,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: @@ -218,6 +219,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ5_K_R4: return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ1_KT : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_KT : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ3_KT : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_KT : return 
MMQ_DP4A_TXS_Q8_0; @@ -275,6 +277,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ5_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ1_KT : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_KT : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ3_KT : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0; @@ -4176,9 +4179,10 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4); -extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_KT); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KT); extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KT); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index d0746031c..012b3e5e7 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -533,6 +533,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm case GGML_TYPE_IQ4_KSS: mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; + case GGML_TYPE_IQ1_KT: + mul_mat_vec_iq1_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; case GGML_TYPE_IQ2_KT: mul_mat_vec_iq2_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; @@ -704,6 +707,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ1_M_R4: + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt.cu new file mode 100644 index 000000000..1a3590e5b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt.cu @@ -0,0 +1,81 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../mmq.cuh" + +template static __device__ __forceinline__ void load_tiles_iq1_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq1_kt * bxi = (const block_iq1_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + uint32_t val = bxi->ql[kqsx] + ((bxi->qh[kqsx%16] << (8 - 4*(kqsx/16))) & 0xf00) + ((bxi->sh[kqsx/4] << (8 - (kqsx%4))) & 0x1000) + 4096; + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val *= ka; + v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + for (int k = 0; k < 4; ++k) { + val *= ka; + v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0]; + const block_iq1_kt * bxi = (const block_iq1_kt *)(dptr + 1) + kbx0; + const int ls = iq4k_values[bxi->sh[threadIdx.x % 8] & 0xf]; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ1_KT); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e18cee73b..e49417af5 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15421,6 +15421,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_Q6_0: break; case GGML_TYPE_IQ2_K: break; case GGML_TYPE_IQ2_KS: break; + case GGML_TYPE_IQ1_KT: break; case GGML_TYPE_IQ2_KT: break; case GGML_TYPE_IQ3_KT: break; case GGML_TYPE_IQ4_KT: break; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b3982538d..5aec6b0db 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1587,6 +1587,23 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 2, }, + [GGML_TYPE_IQ1_KT] = { + .type_name = "iq1_kt", + .blck_size = QK_K, + .type_size = sizeof(block_iq1_kt), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq1_kt, + .from_float = quantize_row_iq1_kt, + .from_float_ref = (ggml_from_float_t)quantize_row_iq1_kt_ref, + .vec_dot = vec_dot_iq1_kt_q8_k, +#if defined __AVX2__ + .vec_dot_type = 
GGML_TYPE_Q8_2_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif + .nrows = 1, + .row_meta_size = 4, + }, [GGML_TYPE_IQ2_KT] = { .type_name = "iq2_kt", .blck_size = QK_K, @@ -4600,6 +4617,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break; case GGML_FTYPE_MOSTLY_IQ2_K_R4: wtype = GGML_TYPE_IQ2_K_R4; break; case GGML_FTYPE_MOSTLY_IQ2_KS: wtype = GGML_TYPE_IQ2_KS; break; + case GGML_FTYPE_MOSTLY_IQ1_KT: wtype = GGML_TYPE_IQ1_KT; break; case GGML_FTYPE_MOSTLY_IQ2_KT: wtype = GGML_TYPE_IQ2_KT; break; case GGML_FTYPE_MOSTLY_IQ3_KT: wtype = GGML_TYPE_IQ3_KT; break; case GGML_FTYPE_MOSTLY_IQ4_KT: wtype = GGML_TYPE_IQ4_KT; break; @@ -11379,6 +11397,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: @@ -11858,6 +11877,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: @@ -12034,6 +12054,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: @@ -15537,6 +15558,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: @@ -15953,6 +15975,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: @@ -16275,6 +16298,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: @@ -16914,6 +16938,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: @@ -23989,6 +24014,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_K_R4:result = quantize_iq2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_KS: result = quantize_iq2_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ1_KT: result = quantize_iq1_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_KT: result = quantize_iq2_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_KT: result = quantize_iq3_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KT: result = quantize_iq4_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 19c30e2ad..577021992 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -199,6 +199,56 @@ struct Trellis3 { #else auto 
dot = _mm256_maddubs_epi16(aux[i], _mm256_set1_epi32(0x01010101)); aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); +#endif + } + for (int k = 0; k < 4; ++k) { + auto v1 = _mm256_packs_epi32(aux[4*k+0], aux[4*k+1]); + auto v2 = _mm256_packs_epi32(aux[4*k+2], aux[4*k+3]); + result[k] = _mm256_permutevar8x32_epi32(_mm256_packs_epi16(v1, v2), shuffle); + } + if constexpr (is_abs) { + for (int k = 0; k < 4; ++k) { + result[k] = _mm256_sign_epi8(result[k], result[k]); + } + } + } + IQK_ALWAYS_INLINE inline void next_128(__m256i val, __m256i * result) const { + // Even though we only have 16 vector registers nn AVX2, this is still faster + __m256i aux[16]; + __m256i tmp[2]; + tmp[0] = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(val)); + tmp[1] = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(val, 1)); + for (int k = 0; k < 2; ++k) { + auto vl = _mm256_castsi256_si128(tmp[k]); + auto v = MM256_SET_M128I(vl, vl); + aux[8*k+0] = _mm256_shuffle_epi32(v, 0x00); + aux[8*k+1] = _mm256_shuffle_epi32(v, 0x55); + aux[8*k+2] = _mm256_shuffle_epi32(v, 0xaa); + aux[8*k+3] = _mm256_shuffle_epi32(v, 0xff); + auto vh = _mm256_extracti128_si256(tmp[k], 1); + v = MM256_SET_M128I(vh, vh); + aux[8*k+4] = _mm256_shuffle_epi32(v, 0x00); + aux[8*k+5] = _mm256_shuffle_epi32(v, 0x55); + aux[8*k+6] = _mm256_shuffle_epi32(v, 0xaa); + aux[8*k+7] = _mm256_shuffle_epi32(v, 0xff); + } + for (int i = 0; i < 16; ++i) { + aux[i] = _mm256_mullo_epi32(aux[i], mka); + } + auto mask = _mm256_set1_epi32(0x3f3f3f3f); + for (int i = 0; i < 16; ++i) { + aux[i] = _mm256_and_si256(aux[i], mask); + } + auto offset = _mm256_set1_epi32(-126); +#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) + auto m1 = _mm256_set1_epi32(0x01010101); +#endif + for (int i = 0; i < 16; ++i) { +#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) + aux[i] = _mm256_dpbusd_epi32(offset, aux[i], m1); +#else + auto dot = _mm256_maddubs_epi16(aux[i], _mm256_set1_epi32(0x01010101)); + aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); #endif } for (int k = 0; k < 4; ++k) { @@ -463,6 +513,148 @@ void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } +void iqk_dequantize_iq1_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + + Trellis3 trellis; + + auto values = _mm_loadu_si128((const __m128i *)iq4k_values); + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq1_kt * x8[8]; + float dkt[8]; + float ls[8]; + float ls_all[64]; + uint32_t idx[8]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq1_kt *)(dptr + 1); + } + auto vd = _mm256_loadu_ps(dkt); + + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + auto sh = _mm_loadl_epi64((const __m128i *)x8[k][i].sh); + auto s8 = _mm_shuffle_epi8(values, _mm_and_si128(sh, _mm_set1_epi8(0xf))); + auto s32 = _mm256_cvtepi8_epi32(s8); + _mm256_storeu_ps(ls_all + 8*k, _mm256_cvtepi32_ps(s32)); + } + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib]; + auto scales = _mm256_mul_ps(vd, _mm256_loadu_ps(ls)); + _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + for (int j = 0; j < 4; ++j) { + int jj = 4*ib + j; + for (int k = 0; k < 8; ++k) { + idx[k] = 
(x8[k][i].ql[jj] | ((x8[k][i].qh[jj%16] << (8 - 4*(jj/16))) & 0xf00) | ((x8[k][i].sh[jj/4] << (8 - (jj%4))) & 0x1000)) + 4096; + } + __m256i packed[2]; + trellis.next64(idx, packed); + _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, packed[0]); + _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, packed[1]); + } + } + y += 8; // = QK_K/32; + } + } +} + +template +void mul_mat_iq1_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis3 trellis; + + auto values = _mm_loadu_si128((const __m128i *)iq4k_values); + + constexpr int k_acc = nrc_y; + + __m256 accd[k_acc]; + const block_q8_2_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_2_x4 *)info.src1_row(iy); + } + + __m256i xv[4], dot[4]; + __m256 scales[2]; + + auto sum_4 = [&dot] () { + // dot[k] has 8 values from block k + // 0 1 0 1 0 1 0 1 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1])); + // 2 3 2 3 2 3 2 3 + dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3])); + // 0 1 2 3 0 1 2 3 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2])); + return _mm256_cvtepi32_ps(dot[0]); + }; + + auto compute_dot = [&dot, &xv] (const int8_t * y) { + for (int k = 0; k < 4; ++k) { + auto yv = _mm256_loadu_si256((const __m256i *)y + k); +#ifdef HAVE_FANCY_SIMD + //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); +#else + auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); + dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1)); +#endif + } + }; + + __m256i idx[2]; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = _mm256_set1_ps(dptr[0]); + const block_iq1_kt * x = (const block_iq1_kt *)(dptr + 1); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + auto sh = _mm_loadl_epi64((const __m128i *)x[i].sh); + auto s32 = _mm256_cvtepi8_epi32(_mm_shuffle_epi8(values, _mm_and_si128(sh, _mm_set1_epi8(0xf)))); + auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32)); + auto scales_l = _mm256_castps256_ps128(all_scales); + auto scales_h = _mm256_extractf128_ps(all_scales, 1); + scales[0] = _mm256_set_m128(scales_l, scales_l); + scales[1] = _mm256_set_m128(scales_h, scales_h); + auto qs8l = _mm_loadu_si128((const __m128i *)x[i].ql+0); + auto qs8h = _mm_loadu_si128((const __m128i *)x[i].ql+1); + auto qh16 = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[i].qh)); + idx[0] = _mm256_or_si256(_mm256_cvtepu8_epi16(qs8l), _mm256_and_si256(_mm256_set1_epi16(0xf00), _mm256_slli_epi16(qh16, 8))); + idx[1] = _mm256_or_si256(_mm256_cvtepu8_epi16(qs8h), _mm256_and_si256(_mm256_set1_epi16(0xf00), _mm256_slli_epi16(qh16, 4))); + idx[0] = _mm256_add_epi16(idx[0], _mm256_set1_epi16(4096)); + idx[1] = _mm256_add_epi16(idx[1], _mm256_set1_epi16(4096)); + auto sh32 = _mm256_and_si256(_mm256_cvtepu8_epi32(sh), _mm256_set1_epi32(0xf0)); + sh32 = _mm256_and_si256(_mm256_mullo_epi32(sh32, _mm256_set1_epi32(0x01020408)), _mm256_set1_epi8(-128)); + idx[0] = _mm256_add_epi16(idx[0], _mm256_slli_epi16(_mm256_cvtepu8_epi16(_mm256_castsi256_si128(sh32)), 5)); + idx[1] = _mm256_add_epi16(idx[1], 
_mm256_slli_epi16(_mm256_cvtepu8_epi16(_mm256_extracti128_si256(sh32, 1)), 5)); + for (int i128 = 0; i128 < 2; ++i128) { + trellis.next_128(idx[i128], xv); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_2_x4& yb = y[iy][2*i+i128]; + auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16)); + auto dy8 = _mm256_mul_ps(scales[i128], _mm256_set_m128(dy4, dy4)); + compute_dot(yb.qs); + accd[iy] = _mm256_fmadd_ps(dy8, sum_4(), accd[iy]); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + } +} + template void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); @@ -1091,11 +1283,11 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array; + func16 = mul_mat_iq1_kt_q8_2_x4_T<16>; #endif return true; } @@ -1124,6 +1316,17 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array; +#endif + return true; + } + return false; + } + if (ggml_type(typeB) != GGML_TYPE_F32) { return false; } @@ -1148,6 +1351,7 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array +void mul_mat_iq1_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis3 trellis; + + auto values = vld1q_s8(iq4k_values); + + constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y; + + float32x4_t accd[k_acc]; + + const block_q8_0_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_0_x4 *)info.src1_row(iy); + } + + int8x16x2_t xv[8]; + uint16x8x4_t idx; + int32x4x4_t dot; + + auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) { + for (int k = 0; k < 4; ++k) { + auto yv = vld1q_s8_x2(y + 32*k); + dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]); + } + dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]); + dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]); + return vpaddq_s32(dot.val[0], dot.val[2]); + }; + + float32x4x2_t scales; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = vdupq_n_f32(dptr[0]); + const block_iq1_kt * x = (const block_iq1_kt *)(dptr + 1); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0); + + for (int i = 0; i < nb; ++i) { + auto sh = vld1_u8(x[i].sh); + auto s16 = vmovl_s8(vqtbl1_s8(values, vand_u8(sh, vdup_n_u8(0xf)))); + scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (s16)))); + scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16)))); + auto ql = vld1q_u8_x2(x[i].ql); + auto qh = vld1q_u8(x[i].qh); + auto qhl = vmovl_u8(vget_low_u8(qh)); + auto qhh = vmovl_u8(vget_high_u8(qh)); + idx.val[0] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 8))); + idx.val[1] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 8))); + idx.val[2] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 4))); + idx.val[3] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 4))); + for (int k = 0; k < 4; ++k) idx.val[k] = vaddq_u16(idx.val[k], vdupq_n_u16(4096)); + auto sh16 = vandq_u16(vmovl_u8(sh), vdupq_n_u16(0xf0)); + auto sh32l = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_low_u16 (sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80)); + auto sh32h = 
vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_high_u16(sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80)); + idx.val[0] = vaddq_u16(idx.val[0], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32l)), 5)); + idx.val[1] = vaddq_u16(idx.val[1], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32l)), 5)); + idx.val[2] = vaddq_u16(idx.val[2], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32h)), 5)); + idx.val[3] = vaddq_u16(idx.val[3], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32h)), 5)); + if constexpr (nrc_y == 1) { + const block_q8_0_x4& ybl = y[0][2*i+0]; + const block_q8_0_x4& ybh = y[0][2*i+1]; + auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d))); + auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d))); + int32x4x4_t suml = {}; + int32x4x4_t sumh = {}; + for (int ib = 0; ib < 2; ++ib) { + auto xl = trellis.next32(vget_low_u16(idx.val[ib+0])); + auto xh = trellis.next32(vget_low_u16(idx.val[ib+2])); + auto yl = vld1q_s8_x2(ybl.qs + 64*ib); + auto yh = vld1q_s8_x2(ybh.qs + 64*ib); + suml.val[2*ib+0] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]); + sumh.val[2*ib+0] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]); + xl = trellis.next32(vget_high_u16(idx.val[ib+0])); + xh = trellis.next32(vget_high_u16(idx.val[ib+2])); + yl = vld1q_s8_x2(ybl.qs + 64*ib + 32); + yh = vld1q_s8_x2(ybh.qs + 64*ib + 32); + suml.val[2*ib+1] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]); + sumh.val[2*ib+1] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]); + } + auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]); + auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]); + auto sl = vpaddq_s32(sl1, sl2); + auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]); + auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]); + auto sh = vpaddq_s32(sh1, sh2); + accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl)); + accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh)); + } else { + for (int k = 0; k < 4; ++k) { + xv[2*k+0] = trellis.next32(vget_low_u16 (idx.val[k])); + xv[2*k+1] = trellis.next32(vget_high_u16(idx.val[k])); + } + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_0_x4& ybl = y[iy][2*i+0]; + const block_q8_0_x4& ybh = y[iy][2*i+1]; + auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d))); + auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d))); + auto sumil = compute_dot(ybl.qs, xv+0); + auto sumih = compute_dot(ybh.qs, xv+4); + if constexpr (nrc_y == 1) { + accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil)); + accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih)); + } else { + accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil)); + accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih)); + } + } + } + } + + if constexpr (nrc_y == 1) { + info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1]))); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(accd[iy])); + } + } + } +} + template void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); @@ -2284,6 +2699,15 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_IQ1_KT : return nrc_y >= 16 ? 
GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ2_KT : return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ3_KT : return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ4_KT : return nrc_y >= 24 ? GGML_TYPE_Q8_0_R8 : type; @@ -293,6 +294,7 @@ struct MulMat { case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_IQ1_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; @@ -442,6 +444,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, //case GGML_TYPE_IQ4_KS_R4: //case GGML_TYPE_IQ5_KS_R4: return iqk_convert_iqk_quants_q80_r8(typeA, n, vx, bx, vy, nrc_x); + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: @@ -848,6 +851,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ5_KS_R4: return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs, mm.func16); + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: @@ -961,6 +965,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ1_M_R4: return iqk_set_kernels_1bit(ne00, typeA, typeB, m.funcs, m.func16); + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index b38cc51f7..1dfb5218e 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -8552,8 +8552,273 @@ std::vector QuantizerIQKT; + +const QuantizerIQ1KT& iq1kt_quantizer() { + static std::mutex mutex; + static std::unique_ptr quantizer; + std::lock_guard lock(mutex); + if (!quantizer) quantizer = std::make_unique(256, 32); + return *quantizer; +} + +void quantize_row_iq1_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights, + int * all_idx) { + + constexpr float kSigmaScale = 2.0f; + using Q = QuantizerIQ1KT; + + static_assert(Q::kNumVal%8 == 0); + + float * dptr = (float *)vy; + + block_iq1_kt * y = (block_iq1_kt *)(dptr + 1); + + int best_idx[2*Q::kNg]; + + auto& quantizer = iq1kt_quantizer(); + + int nblock = n_per_row / Q::kSuperBlockSize; + + Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); + + float amax_row = 0; + for (int j = 0; j < n_per_row; ++j) { + amax_row = std::max(amax_row, std::abs(x[j])); + } + + float amax_scale = 0, max_scale = 0; + + for (int ibl = 0; ibl < nblock; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq1_kt)); + + const float * xbl = x + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + float amax = 0; + for (int j = 0; j < Q::kBlockSize; ++j) { + float ax = std::abs(xb[j]); + amax = std::max(amax, ax); + } + float scale_0 = std::max(90.f, 124.f*amax/amax_row); + quantizer.find_best_match( amax/scale_0, xb, weight, best_idx); + auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx); + 
quantizer.find_best_match(-amax/scale_0, xb, weight, best_idx + Q::kNg); + auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx + Q::kNg); + + auto idx = best_idx; + if (score_p > score_m) scales[ib] = dp; + else { + scales[ib] = dm; idx += Q::kNg; score_p = score_m; + } + for (int ig = 0; ig < Q::kNg; ++ig) all_idx[(ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize + ig] = idx[ig]; + + scale_0 -= 8; + quantizer.find_best_match( amax/scale_0, xb, weight, best_idx); + auto [dp1, score_p1] = quantizer.find_best_scale(xb, weight, best_idx); + quantizer.find_best_match(-amax/scale_0, xb, weight, best_idx + Q::kNg); + auto [dm1, score_m1] = quantizer.find_best_scale(xb, weight, best_idx + Q::kNg); + + if (score_p1 > score_p || score_m1 > score_p) { + idx = best_idx; + if (score_p1 > score_m1) scales[ib] = dp1; + else { + scales[ib] = dm1; idx += Q::kNg; + } + for (int ig = 0; ig < Q::kNg; ++ig) all_idx[(ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize + ig] = idx[ig]; + } + + float abs_scale = std::abs(scales[ib]); + if (abs_scale > amax_scale) { + amax_scale = abs_scale; + max_scale = scales[ib]; + } + } + + } + + if (!max_scale) { + *dptr = 0; + return; + } + + float d = max_scale/iq4k_values[0]; + float best = 0; + for (int itry = -9; itry <= 9; ++itry) { + float id = (itry + iq4k_values[0])/max_scale; + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * xb = x + ibl*Q::kSuperBlockSize; + const float * wb = all_weights + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock; ++ib) { + int ls = best_index_iq4nl(iq4k_values, id*scales[ib]); + float dl = iq4k_values[ls]; + for (int ig = 0; ig < Q::kNg; ++ig) { + auto qb = quantizer.values() + Q::kGroupSize*all_idx[(ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize + ig]; + for (int j = 0; j < Q::kGroupSize; ++j) { + int jj = ig*Q::kGroupSize + j; + float q = dl*qb[j]; + sumqx += wb[jj]*xb[jj]*q; + sumq2 += wb[jj]*q*q; + } + } + xb += Q::kBlockSize; + wb += Q::kBlockSize; + } + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; + } + } + + float id = d ? 
1/d : 0.f; + for (int ibl = 0; ibl < nblock; ++ibl) { + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock; ++ib) { + int ls = best_index_iq4nl(iq4k_values, id*scales[ib]); + y[ibl].sh[ib] = ls; + } + } + + *dptr = d; + if (!d) return; + + for (int iloop = 0; iloop < 1; ++iloop) { + + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + + const float * xbl = x + ibl*Q::kSuperBlockSize; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + int ls = iq4k_values[y[ibl].sh[ib] & 0xf]; + float dl = d*ls; + quantizer.find_best_match(dl, xb, weight, best_idx); + + auto prev_idx = all_idx + (ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize; + + float mse1 = 0, mse2 = 0; + for (int ig = 0; ig < Q::kNg; ++ig) { + auto q1 = quantizer.values() + Q::kGroupSize*prev_idx[ig]; + auto q2 = quantizer.values() + Q::kGroupSize*best_idx[ig]; + for (int j = 0; j < Q::kGroupSize; ++j) { + int jj = ig*Q::kGroupSize + j; + float diff1 = xb[jj] - dl*q1[j]; + float diff2 = xb[jj] - dl*q2[j]; + mse1 += weight[jj]*diff1*diff1; + mse2 += weight[jj]*diff2*diff2; + } + } + if (mse1 < mse2) { + for (int ig = 0; ig < Q::kNg; ++ig) best_idx[ig] = prev_idx[ig]; + } else { + for (int ig = 0; ig < Q::kNg; ++ig) prev_idx[ig] = best_idx[ig]; + } + + for (int j = 0; j < Q::kNg; ++j) { + y[ibl].ql[ib*Q::kNg+j] = best_idx[j] & 0xff; + y[ibl].qh[(ib%(Q::kNblock/2))*Q::kNg+j] |= (((best_idx[j] >> 8) & 0xf) << 4*(ib/(Q::kNblock/2))); + y[ibl].sh[ib] |= ((best_idx[j] >> 12) << (4+j)); + auto xl = xb + Q::kGroupSize*j; + auto wl = weight + Q::kGroupSize*j; + auto ql = quantizer.values() + best_idx[j]*Q::kGroupSize; + for (int k = 0; k < Q::kGroupSize; ++k) { + float q = ql[k]*ls; + sumqx += wl[k]*xl[k]*q; + sumq2 += wl[k]*q*q; + } + } + } + } + if (sumq2 > 0) { + d = sumqx/sumq2; + *dptr = d * 1.07f; + if (!d) return; + } else { + break; + } + + } + +} +} + +void quantize_row_iq1_kt_ref(const float * GGML_RESTRICT x, block_iq1_kt * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq1_kt(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq1_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_iq1_kt * y = (block_iq1_kt *)vy; + quantize_row_iq1_kt_ref(x, y, k); +} + +size_t quantize_iq1_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ1_KT, n_per_row); + std::vector scales(n_per_row/QuantizerIQ1KT::kBlockSize); + std::vector weights(n_per_row); + std::vector idx(n_per_row/QuantizerIQ1KT::kGroupSize); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq1_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), idx.data()); + src += n_per_row; + qrow += row_size; + } + return nrows * row_size; +} + +void dequantize_row_iq1_kt(const block_iq1_kt * x, float * y, int64_t k) { + assert(k % QuantizerIQ1KT::kSuperBlockSize == 0); + using Q = QuantizerIQ1KT; + const int nb = k / Q::kSuperBlockSize; + const float * dptr = (const float *)x; + const float d = *dptr * Q::kScale; + x = (const block_iq1_kt *)(dptr + 1); + auto& deq = iq1kt_quantizer(); + for (int ibl = 0; ibl < nb; ++ibl) { + for (int ib = 0; ib < Q::kNblock; ++ib) { + float sl = d * iq4k_values[x[ibl].sh[ib] & 0xf]; + for (int ig = 0; ig < Q::kNg; ++ig) { + 
uint16_t idx = x[ibl].ql[ib*Q::kNg + ig] | ((x[ibl].qh[(ib%(Q::kNblock/2))*Q::kNg + ig] << (8 - 4*(ib/(Q::kNblock/2)))) & 0xf00); + idx |= (x[ibl].sh[ib] << (8 - ig) & 0x1000); + deq.set_values(idx, y, sl); + y += Q::kGroupSize; + } + } + } +} + +void vec_dot_iq1_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ1_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + +} + // ========================================== iq2_kt ==================================================== +namespace { + using QuantizerIQ2KT = QuantizerIQKT<32, 8, 16, false, true>; const QuantizerIQ2KT& iq2kt_quantizer() { diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 75fa9b4e2..7d789fbaf 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -79,6 +79,12 @@ size_t quantize_iq2_kl(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst void dequantize_row_iq2_kl(const block_iq2_kl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq2_kl_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq1_kt_ref(const float * GGML_RESTRICT x, block_iq1_kt * GGML_RESTRICT y, int64_t k); +void quantize_row_iq1_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq1_kt(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq1_kt(const block_iq1_kt * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq1_kt_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void quantize_row_iq2_kt_ref(const float * GGML_RESTRICT x, block_iq2_kt * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); size_t quantize_iq2_kt(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 767637c5f..32a667e26 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1322,6 +1322,7 @@ class GGMLQuantizationType(IntEnum): IQ4_KT = 155 IQ3_KS = 156 IQ2_KL = 157 + IQ1_KT = 158 Q4_0_R8 = 202 Q5_0_R4 = 206 Q8_0_R8 = 208 @@ -1539,6 +1540,7 @@ def get_type(val: Any) -> GGUFValueType: GGMLQuantizationType.IQ4_KT : ( 256, 128), GGMLQuantizationType.IQ3_KS : ( 256, 102), GGMLQuantizationType.IQ2_KL : ( 256, 86), + GGMLQuantizationType.IQ1_KT : ( 256, 56), GGMLQuantizationType.Q4_0_R8 : ( 32, 18), GGMLQuantizationType.Q5_0_R4 : ( 32, 22), GGMLQuantizationType.Q8_0_R8 : ( 32, 34), diff --git a/include/llama.h b/include/llama.h index bcd81f4f9..1bc1bdafc 100644 --- a/include/llama.h +++ b/include/llama.h @@ -206,6 +206,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_KT = 153, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_KS = 154, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_KL = 155, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ1_KT = 156, // except 1d tensors // LLAMA_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors diff --git a/src/llama.cpp b/src/llama.cpp 
index 58812fc8a..61e7ed51b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4411,6 +4411,7 @@ struct llama_model_loader { case GGML_TYPE_IQ2_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_M_R4;break; case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; case GGML_TYPE_IQ3_XXS_R4: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4; break; + case GGML_TYPE_IQ1_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ1_KT; break; case GGML_TYPE_IQ2_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KT; break; case GGML_TYPE_IQ3_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ3_KT; break; case GGML_TYPE_IQ4_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KT; break; @@ -5156,6 +5157,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ2_M_R4: return "IQ2_M_R4 - 2.7 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_KT: return "IQ1_KT - 1.75 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_KT: return "IQ2_KT - 2.125 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_KT: return "IQ3_KT - 3.125 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KT: return "IQ4_KT - 4.0 bpw"; @@ -19152,7 +19154,8 @@ static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) { new_type == GGML_TYPE_IQ3_XXS_R4 || new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS_R4 || new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4|| new_type == GGML_TYPE_IQ3_KS || new_type == GGML_TYPE_IQ2_KT || new_type == GGML_TYPE_IQ3_KT || new_type == GGML_TYPE_IQ4_KT || - new_type == GGML_TYPE_IQ5_KS || new_type == GGML_TYPE_IQ5_KS_R4|| new_type == GGML_TYPE_IQ2_KL) { + new_type == GGML_TYPE_IQ5_KS || new_type == GGML_TYPE_IQ5_KS_R4|| new_type == GGML_TYPE_IQ2_KL || + new_type == GGML_TYPE_IQ1_KT) { if (nx % QK_K != 0) { LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type)); convert_incompatible_tensor = true; @@ -19192,6 +19195,7 @@ static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) { case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ4_XS_R8: + case GGML_TYPE_IQ1_KT: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: @@ -19324,7 +19328,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_KL || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M_R4 || - ftype == LLAMA_FTYPE_MOSTLY_IQ2_KT || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT) { + ftype == LLAMA_FTYPE_MOSTLY_IQ2_KT || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT || ftype == LLAMA_FTYPE_MOSTLY_IQ1_KT) { new_type = !qs.has_output ? 
GGML_TYPE_IQ4_K : GGML_TYPE_Q5_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4) { @@ -19918,6 +19922,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break; case LLAMA_FTYPE_MOSTLY_IQ2_XS_R4:default_type = GGML_TYPE_IQ2_XS_R4; break; case LLAMA_FTYPE_MOSTLY_IQ2_KS: default_type = GGML_TYPE_IQ2_KS; break; + case LLAMA_FTYPE_MOSTLY_IQ1_KT: default_type = GGML_TYPE_IQ1_KT; break; case LLAMA_FTYPE_MOSTLY_IQ2_KT: default_type = GGML_TYPE_IQ2_KT; break; case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break; case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break;
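
For reference, here is a minimal CPU-side sketch of how one IQ1_KT group of 8 weights is decoded, assembled from the `dequantize_block_iq1_kt` CUDA kernel and `vec_dot_iq1_kt_q8_1` above: each group stores a 13-bit trellis seed (8 bits in `ql`, a 4-bit nibble in `qh`, one bit in `sh`) offset by 4096, and each 32-weight block takes a 4-bit non-linear scale from the low bits of `sh`, multiplied by the fp32 scale stored at the start of the row. The `trellis_next_int` and `iq4k_values` definitions below are reconstructions for illustration only; the real ones live elsewhere in the tree.

```cpp
// Minimal stand-alone sketch (assumptions noted below) of decoding one IQ1_KT
// group of 8 weights, following the CUDA dequantize_block_iq1_kt kernel in this PR.
#include <cstdint>

// Assumption: low half of the non-linear 4-bit table iq4k_values from ggml-common.h,
// reproduced here only so the sketch compiles on its own.
static const int8_t iq4k_values[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
};

// Reconstructed from vec_dot_iq1_kt_q8_1: one trellis step multiplies the state by
// 0xCBAC1FED and maps it to an int8-range value as (sum of the four masked bytes) - 126.
static inline int trellis_next_int(uint32_t & val) {
    constexpr uint32_t ka = 0xCBAC1FED;
    constexpr uint32_t km = 0x3f3f3f3f;
    val *= ka;
    const uint32_t m = val & km;
    return int((m & 0xff) + ((m >> 8) & 0xff) + ((m >> 16) & 0xff) + ((m >> 24) & 0xff)) - 126;
}

// Decode group ib (0..31) of one 256-weight super-block.
// row_scale is the float stored at the start of the row (row_meta_size = 4);
// sh/ql/qh are the fields of block_iq1_kt.
static void decode_iq1_kt_group(float row_scale, const uint8_t * sh, const uint8_t * ql,
                                const uint8_t * qh, int ib, float * out8) {
    // 13-bit trellis seed: 8 bits from ql, a 4-bit nibble from qh, the 13th bit from sh,
    // then the +4096 offset, exactly as in the CUDA kernel.
    uint32_t idx = (ql[ib]
                 | ((qh[ib % 16] << (8 - 4 * (ib / 16))) & 0xf00)
                 | ((sh[ib / 4] << (8 - (ib % 4))) & 0x1000)) + 4096;
    // Each 32-weight block gets a 4-bit non-linear scale from the low bits of sh.
    const float dl = row_scale * iq4k_values[sh[ib / 4] & 0xf];
    for (int j = 0; j < 8; ++j) {
        out8[j] = dl * trellis_next_int(idx);
    }
}
```

With 32 bytes of `ql`, 16 bytes of `qh` and 8 bytes of `sh` per 256 weights, a block is 56 bytes, i.e. the 1.75 bpw advertised in the quantize option (plus one fp32 scale per row), which matches the `(256, 56)` entry added to `gguf-py/gguf/constants.py`.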