
Commit dbf951d

ikawrakow (Iwan Kawrakow) authored

Enable IQ4_NL for KV-cache in token generation using Flash Attention (#99)

* Enable IQ4_NL for V-cache in token generation
* We don't need these
* Update printout of allowed quantized KV-cache combinations
* Add IQ4_NL + IQ4_NL to FA. This is a better alternative than Q4_0 + Q4_0 for the VRAM poor.
* Remove file added by mistake
* Fix typo, which is not really a bug

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

1 parent: f2d315b

22 files changed, +214 / -37 lines

Makefile

Lines changed: 4 additions & 2 deletions

@@ -246,11 +246,11 @@ endif
 # Compile flags
 #
-# keep standard at C11 and C++11
+# keep standard at C11 and C++17
 MK_CPPFLAGS = -Iggml/include -Iggml/src -Iinclude -Isrc -Icommon
 MK_CFLAGS = -std=c11 -fPIC
 MK_CXXFLAGS = -std=c++17 -fPIC
-MK_NVCCFLAGS = -std=c++11
+MK_NVCCFLAGS = -std=c++17

 ifdef LLAMA_NO_CCACHE
 GGML_NO_CCACHE := 1
@@ -598,6 +598,8 @@ else
 	OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu))
 	OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu))
 	OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu))
+	OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-iq4_nl.cu))
+	OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*iq4_nl-iq4_nl.cu))
 endif # GGML_CUDA_FA_ALL_QUANTS

 ifdef GGML_CUDA

ggml/src/CMakeLists.txt

Lines changed: 4 additions & 0 deletions

@@ -328,6 +328,10 @@ if (GGML_CUDA)
 list(APPEND GGML_SOURCES_CUDA ${SRCS})
 file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
 list(APPEND GGML_SOURCES_CUDA ${SRCS})
+file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-iq4_nl.cu")
+list(APPEND GGML_SOURCES_CUDA ${SRCS})
+file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*iq4_nl-iq4_nl.cu")
+list(APPEND GGML_SOURCES_CUDA ${SRCS})
 endif()

 list(APPEND GGML_CDEF_PUBLIC GGML_USE_CUDA)

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 100 additions & 32 deletions
@@ -136,6 +136,49 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }

+static __device__ __forceinline__ int get_one_int_from_table_16(const int & q4) {
+    const uint8_t * q0_8 = (const uint8_t *) &q4;
+    const char4 val0_8 = make_char4(kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
+    return *((const int *) &val0_8);
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_iq4_nl * K_iq4_nl = (const block_iq4_nl *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib = k_KQ / QI8_1;
+        const int iqs4 = k_KQ % QI4_NL;
+        const int shift = k_KQ & (QI8_1/2);
+
+        const int v = get_one_int_from_table_16((get_int_b2(K_iq4_nl[ib].qs, iqs4) >> shift) & 0x0F0F0F0F);
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2 * Q_ds = (const half2 *) Q_ds_v;
+            sum += (T) (((half)sumi) * K_iq4_nl[ib].d * Q_ds[k_KQ_0/WARP_SIZE].x);
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+            sum += (T) ((float)sumi * __half2float(K_iq4_nl[ib].d) * Q_ds[k_KQ_0/WARP_SIZE].x);
+        }
+    }
+
+    return sum;
+}
+
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
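The new K·Q path treats IQ4_NL as what it is: 4-bit indices into a 16-entry non-linear lookup table (kvalues_iq4nl). get_one_int_from_table_16 expands four indices at once into four packed int8 values, which ggml_cuda_dp4a then dots against the already-quantized Q. Below is a host-side C++ sketch of that per-int step, for illustration only: the table values are assumed to match kvalues_iq4nl from ggml-common.h, and the helper names are made up.

// Host-side sketch (illustration only, not part of the commit) of the lookup
// and 4-way dot product that vec_dot_fattn_vec_KQ_iq4_nl performs per int.
// Assumption: the table below matches kvalues_iq4nl in ggml-common.h.
#include <cstdint>
#include <cstring>
#include <cstdio>

static const int8_t kvalues_iq4nl_host[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

// Expand four 4-bit indices (one per byte of q4, as after the "& 0x0F0F0F0F"
// masking in the kernel) into four int8 table values packed into one word --
// the host analogue of get_one_int_from_table_16().
static uint32_t one_int_from_table_16(uint32_t q4) {
    int8_t v[4];
    for (int j = 0; j < 4; ++j) v[j] = kvalues_iq4nl_host[(q4 >> 8*j) & 0x0F];
    uint32_t packed;
    std::memcpy(&packed, v, sizeof(packed));
    return packed;
}

// Host analogue of ggml_cuda_dp4a: dot two packed int8x4 words, accumulate into c.
static int dp4a_host(uint32_t a, uint32_t b, int c) {
    for (int j = 0; j < 4; ++j) {
        c += (int)(int8_t)(a >> 8*j) * (int)(int8_t)(b >> 8*j);
    }
    return c;
}

int main() {
    const uint32_t q4    = 0x0F0A0500u;       // indices 0, 5, 10, 15 (one per byte)
    const int8_t   q8[4] = {1, -2, 3, -4};    // four already-quantized Q values
    uint32_t u;
    std::memcpy(&u, q8, sizeof(u));
    std::printf("sumi = %d\n", dp4a_host(one_int_from_table_16(q4), u, 0));
    return 0;
}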
@@ -377,6 +420,25 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
     return ((float) d)*((float) q);
 }

+template <typename T>
+static __device__ __forceinline__ T dequantize_1_iq4_nl(const void * __restrict__ vx, const int64_t i) {
+    const block_iq4_nl * x = (const block_iq4_nl *) vx;
+
+    const int64_t ib = i / QK4_NL;
+    const int iqs = i % (QK4_NL/2);
+    const int shift = (i % QK4_NL) / (QK4_NL/2);
+
+#ifdef FP16_AVAILABLE
+    if constexpr (std::is_same<T, half>::value) {
+        return x[ib].d * ((half) kvalues_iq4nl[(x[ib].qs[iqs] >> 4*(shift)) & 0xf]);
+    } else {
+        return (float)x[ib].d * ((float) kvalues_iq4nl[(x[ib].qs[iqs] >> 4*(shift)) & 0xf]);
+    }
+#endif
+    T result = (float)x[ib].d * ((float) kvalues_iq4nl[(x[ib].qs[iqs] >> 4*(shift)) & 0xf]);
+    return result;
+}
+
 template <typename T>
 static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
     const block_q4_1 * x = (const block_q4_1 *) vx;
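dequantize_1_iq4_nl recovers a single V element: index i selects the block (QK4_NL = 32 values per block), the byte i % 16 within it, and the low or high nibble, then scales the looked-up table value by the block's fp16 scale d. A host-side sketch of the same indexing follows, assuming the standard IQ4_NL block layout (16 nibble bytes plus one scale per 32 values); the names and the float scale are illustrative, not the real types.

// Host-side sketch (not part of the commit) of the element indexing in
// dequantize_1_iq4_nl. Assumptions: 32 values per block (QK4_NL), stored as
// 16 nibble bytes plus one scale; the scale is held as float here instead of
// the real fp16.
#include <cstdint>
#include <cstdio>

constexpr int QK4_NL_HOST = 32;

struct block_iq4_nl_host {
    float   d;                    // block scale (ggml_half in the real format)
    uint8_t qs[QK4_NL_HOST / 2];  // 32 x 4-bit indices into the table
};

static const int8_t kvalues_iq4nl_host[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

// Dequantize element i of a contiguous array of IQ4_NL blocks.
static float dequantize_1_iq4_nl_host(const block_iq4_nl_host * x, int64_t i) {
    const int64_t ib    = i / QK4_NL_HOST;                      // block index
    const int     iqs   = i % (QK4_NL_HOST/2);                  // byte within the block
    const int     shift = (i % QK4_NL_HOST) / (QK4_NL_HOST/2);  // 0: low nibble, 1: high nibble
    const int     idx   = (x[ib].qs[iqs] >> 4*shift) & 0xf;
    return x[ib].d * (float) kvalues_iq4nl_host[idx];
}

int main() {
    block_iq4_nl_host blk = {0.25f, {}};
    blk.qs[3] = 0xF2;  // element 3 -> low nibble (index 2), element 19 -> high nibble (index 15)
    std::printf("%g %g\n", dequantize_1_iq4_nl_host(&blk, 3),
                           dequantize_1_iq4_nl_host(&blk, 19));
    return 0;
}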
@@ -476,44 +538,48 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v

 template <int D>
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
-        nullptr;
+    return type_K == GGML_TYPE_Q4_0   ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
+           type_K == GGML_TYPE_Q4_1   ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
+           type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, D> :
+           type_K == GGML_TYPE_Q5_0   ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
+           type_K == GGML_TYPE_Q5_1   ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+           type_K == GGML_TYPE_Q8_0   ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
+           type_K == GGML_TYPE_F16    ? vec_dot_fattn_vec_KQ_f16<half, D> :
+           nullptr;
 }

 template <int D>
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
-        nullptr;
+    return type_K == GGML_TYPE_Q4_0   ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
+           type_K == GGML_TYPE_Q4_1   ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
+           type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, D> :
+           type_K == GGML_TYPE_Q5_0   ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
+           type_K == GGML_TYPE_Q5_1   ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+           type_K == GGML_TYPE_Q8_0   ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
+           type_K == GGML_TYPE_F16    ? vec_dot_fattn_vec_KQ_f16<float, D> :
+           nullptr;
 }

 constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
-    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
-        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
-        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
-        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
-        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
-        type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
-        nullptr;
+    return type_V == GGML_TYPE_Q4_0   ? dequantize_1_q4_0<half> :
+           type_V == GGML_TYPE_Q4_1   ? dequantize_1_q4_1<half> :
+           type_V == GGML_TYPE_Q5_0   ? dequantize_1_q5_0<half> :
+           type_V == GGML_TYPE_Q5_1   ? dequantize_1_q5_1<half> :
+           type_V == GGML_TYPE_Q8_0   ? dequantize_1_q8_0<half> :
+           type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<half> :
+           type_V == GGML_TYPE_F16    ? dequantize_1_f16<half> :
+           nullptr;
 }

 constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
-    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
-        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
-        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
-        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
-        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
-        type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
-        nullptr;
+    return type_V == GGML_TYPE_Q4_0   ? dequantize_1_q4_0<float> :
+           type_V == GGML_TYPE_Q4_1   ? dequantize_1_q4_1<float> :
+           type_V == GGML_TYPE_Q5_0   ? dequantize_1_q5_0<float> :
+           type_V == GGML_TYPE_Q5_1   ? dequantize_1_q5_1<float> :
+           type_V == GGML_TYPE_Q8_0   ? dequantize_1_q8_0<float> :
+           type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<float> :
+           type_V == GGML_TYPE_F16    ? dequantize_1_f16<float> :
+           nullptr;
 }

 template<int D, int parallel_blocks> // D == head size
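These constexpr helpers map a ggml_type to a device function pointer, so each kernel instantiation bakes in the right K·Q dot product and V dequantizer at compile time; the commit simply adds the IQ4_NL entries to the chains. A standalone, plain-C++ illustration of that dispatch pattern (made-up names and values, not ggml's API):

// Standalone illustration (plain C++, made-up names) of the constexpr
// function-pointer dispatch pattern used by get_vec_dot_KQ_f16 and
// get_dequantize_1_f16: because the type enum is a template parameter of the
// kernel, the chain of ?: collapses to a single function pointer at compile time.
#include <cstdio>

enum fake_type { FAKE_Q4_0, FAKE_Q8_0, FAKE_F16 };

static float deq_q4_0(int q) { return 0.0625f * q; }  // stand-in "dequantizers"
static float deq_q8_0(int q) { return 0.0039f * q; }
static float deq_f16 (int q) { return 1.0f    * q; }

using deq_fn = float (*)(int);

constexpr deq_fn get_deq(fake_type t) {
    return t == FAKE_Q4_0 ? deq_q4_0 :
           t == FAKE_Q8_0 ? deq_q8_0 :
           t == FAKE_F16  ? deq_f16  :
                            nullptr;
}

template <fake_type T>
static float process(int q) {
    constexpr deq_fn deq = get_deq(T);  // resolved at compile time, like in the kernels
    return deq(q);
}

int main() {
    std::printf("%g %g\n", process<FAKE_Q4_0>(8), process<FAKE_F16>(8));
    return 0;
}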
@@ -569,10 +635,12 @@ static void on_no_fattn_vec_case(const int D) {
     } else if (D == 128) {
         fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
         fprintf(stderr, "Supported combinations:\n");
-        fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
-        fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
-        fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
-        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
+        fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
+        fprintf(stderr, " - K == iq4_nl, V == iq4_nl, 4.50 BPV\n");
+        fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.50 BPV\n");
+        fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
+        fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
+        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {
         fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");

ggml/src/ggml-cuda/fattn-vec-f16.cuh

Lines changed: 7 additions & 0 deletions

@@ -392,6 +392,13 @@ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
 extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
 extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);

+//extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
+//extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
+//extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
+//extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
+//extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
+
 extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
 extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
 extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
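Each extern DECL_FATTN_VEC_F16_CASE line forward-declares one explicitly instantiated (head size, K type, V type) launcher that is compiled in a separate autogenerated template-instances/*.cu file; the commented-out lines are IQ4_NL combinations that were considered but left disabled. A standalone plain-C++ illustration of the extern-declaration / explicit-instantiation split this relies on (made-up names, not the actual macro):

// Standalone illustration (plain C++, made-up names) of the extern-declaration
// plus explicit-instantiation split that the DECL_FATTN_VEC_F16_CASE lines set
// up between this header and the template-instances/*.cu files.
#include <cstdio>

template <int D> void launch_case(float * dst);      // primary template (header)
extern template void launch_case<128>(float * dst);  // "the D=128 instance exists elsewhere"

template <int D> void launch_case(float * dst) { *dst = (float) D; }
template void launch_case<128>(float * dst);         // explicit instantiation (instance file)

int main() {
    float x = 0.0f;
    launch_case<128>(&x);                            // dispatch to the prebuilt instance
    std::printf("%g\n", x);
    return 0;
}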

ggml/src/ggml-cuda/fattn.cu

Lines changed: 18 additions & 2 deletions

@@ -152,7 +152,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     } \

 static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * Q = dst->src[1];
+    ggml_tensor * Q = dst->src[0];
     ggml_tensor * K = dst->src[1];
     ggml_tensor * V = dst->src[2];

@@ -207,6 +207,14 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)

     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+
+    //FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
+    //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
+    //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
+    //FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+    //FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
 #else
     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)

@@ -215,6 +223,14 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+
+    //FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
+    //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
+    //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
+    //FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+    //FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
 #endif // GGML_CUDA_FA_ALL_QUANTS

     on_no_fattn_vec_case(Q->ne[0]);

@@ -227,7 +243,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
     } \

 static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * Q = dst->src[1];
+    ggml_tensor * Q = dst->src[0];
     ggml_tensor * K = dst->src[1];
     ggml_tensor * V = dst->src[2];
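The Q = dst->src[1] → dst->src[0] change is the "typo, which is not really a bug" from the commit message: Q is only consulted here through Q->ne[0] (the head size), and K = dst->src[1] has the same ne[0], so dispatch was unaffected. Each FATTN_VEC_F16_CASE line adds one runtime check on head size and K/V cache types; below is a standalone analogue of that dispatch (made-up names, not the actual macro, which is only partially visible above as "} \"):

// Standalone analogue (plain C++, made-up names) of the FATTN_VEC_F16_CASE
// dispatch: every CASE line becomes a check on (head size, K type, V type)
// that calls the matching template instance and returns.
#include <cstdio>

enum cache_type { CT_F16, CT_Q8_0, CT_IQ4_NL };

template <int D, cache_type K, cache_type V>
static void run_case() { std::printf("dispatched D=%d K=%d V=%d\n", D, (int) K, (int) V); }

#define VEC_CASE(D, K, V)                            \
    if (head_size == (D) && k == (K) && v == (V)) {  \
        run_case<D, K, V>();                         \
        return;                                      \
    }

static void dispatch(int head_size, cache_type k, cache_type v) {
    VEC_CASE(128, CT_IQ4_NL, CT_IQ4_NL)
    VEC_CASE(128, CT_Q8_0,   CT_IQ4_NL)
    VEC_CASE(128, CT_F16,    CT_F16)
    std::printf("unsupported combination\n");  // analogue of on_no_fattn_vec_case()
}

int main() {
    dispatch(128, CT_Q8_0, CT_IQ4_NL);  // selects the q8_0 / iq4_nl instance
    return 0;
}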

New file under ggml/src/ggml-cuda/template-instances/

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
New file under ggml/src/ggml-cuda/template-instances/

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL);
New file under ggml/src/ggml-cuda/template-instances/

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL);
New file under ggml/src/ggml-cuda/template-instances/

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL);
New file under ggml/src/ggml-cuda/template-instances/

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL);
