@@ -277,6 +277,49 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }
 
+template <typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q6_0 * K_q6_0 = (const block_q6_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
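+        // per-thread indexing into the q6_0 data: ib selects the block, iqs4 the
+        // 32-bit chunk of qs, shift whether the low or high nibbles of that chunk are used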
+        const int ib    = k_KQ / QI8_1;
+        const int iqs4  = k_KQ % QI6_0; // 0...3
+        const int shift = k_KQ & (QI8_1/2);
+
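+        // reassemble four 6-bit quants, one per byte: low 4 bits from qs, upper 2 bits from qh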
+        const int vh = (get_int_b2(K_q6_0[ib].qh, iqs4%2) >> (4*(iqs4/2) + shift/2)) & 0x03030303;
+        const int vl = (get_int_b2(K_q6_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int v  = vl | (vh << 4);
+
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
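+        // scale by the K block delta; the +32 offset of the packed 6-bit values is
+        // corrected using the precomputed Q sums in Q_ds (hence the 32/QI8_1 == 4 factor)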
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2 * Q_ds = (const half2 *) Q_ds_v;
+
+            const half2 sum2 = __half2half2(K_q6_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(4.0f)) /* *32/QI8_1 == 4 */;
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            sum += (T) (__half2float(K_q6_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (32/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+        }
+    }
+
+    return sum;
+}
+
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -510,6 +553,30 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
     return __low2float(dm)*((float) q) + __high2float(dm);
 }
 
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q6_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q6_0 * x = (const block_q6_0 *) vx;
+
+    const int64_t ib    = i / QK6_0;
+    const int     idq   = i % QK6_0;
+    const int     iqs   = i % (QK6_0/2);
+    const int     shift = idq / (QK6_0/2);
+    //const int   shift = (i % QK6_0) / (QK6_0/2);
+
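+    // q6_0: low 4 bits of each quant come from qs, the upper 2 bits from qh;
+    // subtract 32 to recentre the value around zero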
+    const T   d  = x[ib].d;
+    const int ql = x[ib].qs[iqs] >> 4*shift;
+    const int qh = x[ib].qh[idq%(QK6_0/4)] >> (4*((idq/(QK6_0/4))%2) + 2*shift);
+    const int q  = ((ql & 0x0f) | ((qh & 0x03) << 4)) - 32;
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
 template <typename T>
 static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
     const block_q8_0 * x = (const block_q8_0 *) vx;
@@ -543,6 +610,7 @@ constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
         type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, D> :
         type_K == GGML_TYPE_Q5_0   ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
         type_K == GGML_TYPE_Q5_1   ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+        type_K == GGML_TYPE_Q6_0   ? vec_dot_fattn_vec_KQ_q6_0<half, D> :
         type_K == GGML_TYPE_Q8_0   ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
         type_K == GGML_TYPE_F16    ? vec_dot_fattn_vec_KQ_f16<half, D> :
         nullptr;
@@ -555,6 +623,7 @@ constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
         type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, D> :
         type_K == GGML_TYPE_Q5_0   ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
         type_K == GGML_TYPE_Q5_1   ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+        type_K == GGML_TYPE_Q6_0   ? vec_dot_fattn_vec_KQ_q6_0<float, D> :
         type_K == GGML_TYPE_Q8_0   ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
         type_K == GGML_TYPE_F16    ? vec_dot_fattn_vec_KQ_f16<float, D> :
         nullptr;
@@ -565,6 +634,7 @@ constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
         type_V == GGML_TYPE_Q4_1   ? dequantize_1_q4_1<half> :
         type_V == GGML_TYPE_Q5_0   ? dequantize_1_q5_0<half> :
         type_V == GGML_TYPE_Q5_1   ? dequantize_1_q5_1<half> :
+        type_V == GGML_TYPE_Q6_0   ? dequantize_1_q6_0<half> :
         type_V == GGML_TYPE_Q8_0   ? dequantize_1_q8_0<half> :
         type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<half> :
         type_V == GGML_TYPE_F16    ? dequantize_1_f16<half> :
@@ -576,6 +646,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         type_V == GGML_TYPE_Q4_1   ? dequantize_1_q4_1<float> :
         type_V == GGML_TYPE_Q5_0   ? dequantize_1_q5_0<float> :
         type_V == GGML_TYPE_Q5_1   ? dequantize_1_q5_1<float> :
+        type_V == GGML_TYPE_Q6_0   ? dequantize_1_q6_0<float> :
         type_V == GGML_TYPE_Q8_0   ? dequantize_1_q8_0<float> :
         type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<float> :
         type_V == GGML_TYPE_F16    ? dequantize_1_f16<float> :
@@ -635,11 +706,13 @@ static void on_no_fattn_vec_case(const int D) {
     } else if (D == 128) {
         fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
         fprintf(stderr, "Supported combinations:\n");
-        fprintf(stderr, "  - K == q4_0,   V == q4_0,   4.50 BPV\n");
-        fprintf(stderr, "  - K == iq4_nl, V == iq4_nl, 4.50 BPV\n");
-        fprintf(stderr, "  - K == q8_0,   V == iq4_nl, 6.50 BPV\n");
-        fprintf(stderr, "  - K == q8_0,   V == q8_0,   8.50 BPV\n");
-        fprintf(stderr, "  - K == f16,    V == f16,   16.00 BPV\n");
+        fprintf(stderr, "  - K == q4_0,   V == q4_0,   4.5 BPV\n");
+        fprintf(stderr, "  - K == iq4_nl, V == iq4_nl, 4.5 BPV\n");
+        fprintf(stderr, "  - K == q6_0,   V == q5_0,   6.0 BPV\n");
+        fprintf(stderr, "  - K == q8_0,   V == iq4_nl, 6.5 BPV\n");
+        fprintf(stderr, "  - K == q8_0,   V == q6_0,   7.5 BPV\n");
+        fprintf(stderr, "  - K == q8_0,   V == q8_0,   8.5 BPV\n");
+        fprintf(stderr, "  - K == f16,    V == f16,   16.0 BPV\n");
         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {