
Commit 462c6cd

ikawrakow and Iwan Kawrakow authored
Enable q6_0 for flash attention (#101)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent dbf951d commit 462c6cd

11 files changed: +120 -16 lines

Makefile

Lines changed: 2 additions & 0 deletions
@@ -600,6 +600,8 @@ else
 OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu))
 OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-iq4_nl.cu))
 OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*iq4_nl-iq4_nl.cu))
+OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q6_0-q5_0.cu))
+OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q6_0.cu))
 endif # GGML_CUDA_FA_ALL_QUANTS
 
 ifdef GGML_CUDA

ggml/src/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -332,6 +332,10 @@ if (GGML_CUDA)
         list(APPEND GGML_SOURCES_CUDA ${SRCS})
         file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*iq4_nl-iq4_nl.cu")
         list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q6_0-q5_0.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q6_0.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
     endif()
 
     list(APPEND GGML_CDEF_PUBLIC GGML_USE_CUDA)

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 78 additions & 5 deletions
@@ -277,6 +277,49 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }
 
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q6_0 * K_q6_0 = (const block_q6_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ / QI8_1;
+        const int iqs4  = k_KQ % QI6_0; // 0...3
+        const int shift = k_KQ & (QI8_1/2);
+
+        const int vh = (get_int_b2(K_q6_0[ib].qh, iqs4%2) >> (4*(iqs4/2) + shift/2)) & 0x03030303;
+        const int vl = (get_int_b2(K_q6_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int v  = vl | (vh << 4);
+
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2 * Q_ds = (const half2 *) Q_ds_v;
+
+            const half2 sum2 = __half2half2(K_q6_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(4.0f)) /* *32/QI8_1 == 4 */;
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            sum += (T) (__half2float(K_q6_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (32/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+        }
+    }
+
+    return sum;
+}
+
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
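For readers not fluent in the packed dp4a form, the following is a scalar, host-side sketch of what this kernel accumulates for one 32-value q6_0 block of K against the int8-quantized Q. The names q6, q8, d6, d8 are illustrative assumptions, not identifiers from this repo; the only fact the sketch encodes is that the -32 offset of q6_0 factors out of the dot product, which is the compensation the kernel applies via Q_ds.

    #include <stdint.h>

    // Scalar reference of the per-block K.Q contribution (illustrative names only):
    // q6 holds the unpacked 6-bit values (0..63) of one 32-value q6_0 block of K,
    // q8 the int8-quantized Q values, d6 the q6_0 block scale, d8 the Q scale.
    static float vec_dot_q6_0_q8_ref(const int8_t * q6, const int8_t * q8, float d6, float d8) {
        int sumi = 0; // sum of q6[j]*q8[j]: the part dp4a accumulates four values at a time
        int sum8 = 0; // sum of q8[j]: needed to undo the +32 offset baked into q6_0 storage
        for (int j = 0; j < 32; ++j) {
            sumi += q6[j]*q8[j];
            sum8 += q8[j];
        }
        // (q6[j] - 32)*d6 dotted with q8[j]*d8  ==  d6*d8*sumi - 32*d6*d8*sum8.
        // In the kernel, the Q scale and the q8 sums arrive through Q_ds, so the same
        // correction shows up as the "- (32/QI8_1)*Q_ds[...].y" term (hence the
        // "*32/QI8_1 == 4" comment above).
        return d6*(d8*sumi - 32.0f*d8*sum8);
    }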
@@ -510,6 +553,30 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
     return __low2float(dm)*((float) q) + __high2float(dm);
 }
 
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q6_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q6_0 * x = (const block_q6_0 *) vx;
+
+    const int64_t ib  = i / QK6_0;
+    const int    idq  = i % QK6_0;
+    const int    iqs  = i % (QK6_0/2);
+    const int   shift = idq / (QK6_0/2);
+    //const int shift = (i % QK6_0) / (QK6_0/2);
+
+    const T   d  = x[ib].d;
+    const int ql = x[ib].qs[iqs] >> 4*shift;
+    const int qh = x[ib].qh[idq%(QK6_0/4)] >> (4*((idq/(QK6_0/4))%2) + 2*shift);
+    const int q  = ((ql & 0x0f) | ((qh & 0x03) << 4)) - 32;
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
 template <typename T>
 static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
     const block_q8_0 * x = (const block_q8_0 *) vx;
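As an aid to reading the index arithmetic in dequantize_1_q6_0, here is a host-side scalar sketch of the same decode. The struct below is an assumed reference layout for illustration only (field order and the fp32 scale are assumptions; ggml stores the scale as fp16). What is grounded in the kernel is the sizing: qs holds QK6_0/2 bytes of low nibbles and qh holds QK6_0/4 bytes carrying two high bits per value, four values per byte.

    #include <stdint.h>

    #define QK6_0 32

    // Assumed reference layout (not the canonical ggml struct): one scale per 32 values,
    // 8 bytes of packed 2-bit high bits, 16 bytes of packed 4-bit low nibbles.
    typedef struct {
        float   d;             // block scale (ggml stores this as fp16)
        uint8_t qh[QK6_0/4];   // two high bits per value, four values per byte
        uint8_t qs[QK6_0/2];   // low nibbles, two values per byte
    } block_q6_0_ref;

    // Decode value i (0..31) of one block, mirroring the index math in dequantize_1_q6_0.
    static inline float dequantize_q6_0_ref(const block_q6_0_ref * x, int i) {
        const int shift = i / (QK6_0/2);   // 0: low nibble of the qs byte, 1: high nibble
        const int ql    = (x->qs[i % (QK6_0/2)] >> 4*shift) & 0x0f;
        const int qh    = (x->qh[i % (QK6_0/4)] >> (4*((i/(QK6_0/4)) % 2) + 2*shift)) & 0x03;
        return x->d * (float) (((qh << 4) | ql) - 32);   // signed 6-bit value times scale
    }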
@@ -543,6 +610,7 @@ constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
         type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, D> :
         type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
         type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+        type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<half, D> :
         type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
         type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
         nullptr;
@@ -555,6 +623,7 @@ constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
         type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, D> :
         type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
         type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+        type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<float, D> :
         type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
         type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
         nullptr;
@@ -565,6 +634,7 @@ constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
         type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
         type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
         type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
+        type_V == GGML_TYPE_Q6_0 ? dequantize_1_q6_0<half> :
         type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
         type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<half> :
         type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
@@ -576,6 +646,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
         type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
         type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
+        type_V == GGML_TYPE_Q6_0 ? dequantize_1_q6_0<float> :
         type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
         type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<float> :
         type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
@@ -635,11 +706,13 @@ static void on_no_fattn_vec_case(const int D) {
     } else if (D == 128) {
         fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
         fprintf(stderr, "Supported combinations:\n");
-        fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
-        fprintf(stderr, " - K == iq4_nl, V == iq4_nl, 4.50 BPV\n");
-        fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.50 BPV\n");
-        fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
-        fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
+        fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n");
+        fprintf(stderr, " - K == iq4_nl, V == iq4_nl, 4.5 BPV\n");
+        fprintf(stderr, " - K == q6_0, V == q5_0, 6.0 BPV\n");
+        fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.5 BPV\n");
+        fprintf(stderr, " - K == q8_0, V == q6_0, 7.5 BPV\n");
+        fprintf(stderr, " - K == q8_0, V == q8_0, 8.5 BPV\n");
+        fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n");
         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {
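For reference, the BPV figures printed here appear to be the average of the K and V cache footprints in bits per value. From the indexing above, a q6_0 block packs a 2-byte scale, 8 bytes of high bits and 16 bytes of low nibbles per 32 values, i.e. 26*8/32 = 6.5 bits per value, while q5_0 takes 22 bytes per 32 values (5.5) and q8_0 takes 34 bytes per 32 values (8.5). Hence K == q6_0 with V == q5_0 averages (6.5 + 5.5)/2 = 6.0 BPV, and K == q8_0 with V == q6_0 gives 7.5 BPV.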

ggml/src/ggml-cuda/fattn.cu

Lines changed: 5 additions & 10 deletions
@@ -208,13 +208,11 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
 
     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 
-    //FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
-    //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
-    //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
-    //FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
     FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
-    //FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
 #else
     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
 
@@ -224,13 +222,10 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 
-    //FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
-    //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
-    //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
-    //FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
     FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
-    //FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
     on_no_fattn_vec_case(Q->ne[0]);
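These head-size-128 cases are the K/V cache type pairs selectable at run time. As an illustrative usage sketch only (assuming this fork keeps the usual llama.cpp flash-attention and cache-type flags and accepts q6_0/q5_0 as cache types; the exact binary and flag names may differ), an invocation such as

    llama-cli -m model.gguf -fa -ctk q6_0 -ctv q5_0

or -ctk q8_0 -ctv q6_0 would dispatch into the new FATTN_VEC_F16_CASE(128, ...) kernels declared above.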
New file (under ggml/src/ggml-cuda/template-instances/)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0);
New file (under ggml/src/ggml-cuda/template-instances/)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0);
New file (under ggml/src/ggml-cuda/template-instances/)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0);
New file (under ggml/src/ggml-cuda/template-instances/)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0);
New file (under ggml/src/ggml-cuda/template-instances/)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0);
New file (under ggml/src/ggml-cuda/template-instances/)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0);
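Each of these one-line files exists so that a single (head size, K type, V type) kernel instance is compiled in its own translation unit, keeping build times and binary size manageable. As an illustrative sketch only (based on the upstream llama.cpp convention; the real macro is defined in fattn-vec-f16.cuh / fattn-vec-f32.cuh and may differ in this fork), a line such as DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0); presumably expands to an explicit template instantiation along the lines of

    template void ggml_cuda_flash_attn_ext_vec_f16_case<128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0>(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

so the dispatcher in fattn.cu can link against it without re-instantiating the kernel in every translation unit.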
