@@ -136,6 +136,49 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }
 
+static __device__ __forceinline__ int get_one_int_from_table_16(const int & q4) {
+    const uint8_t * q0_8 = (const uint8_t *) &q4;
+    const char4 val0_8 = make_char4(kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
+    return *((const int *) &val0_8);
+}
+
+template <typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl(
+        const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_iq4_nl * K_iq4_nl = (const block_iq4_nl *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI4_NL;
+        const int shift = k_KQ & (QI8_1/2);
+
+        const int v = get_one_int_from_table_16((get_int_b2(K_iq4_nl[ib].qs, iqs4) >> shift) & 0x0F0F0F0F);
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2 * Q_ds = (const half2 *) Q_ds_v;
+            sum += (T) (((half) sumi) * K_iq4_nl[ib].d * Q_ds[k_KQ_0/WARP_SIZE].x);
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+            sum += (T) ((float) sumi * __half2float(K_iq4_nl[ib].d) * Q_ds[k_KQ_0/WARP_SIZE].x);
+        }
+    }
+
+    return sum;
+}
+
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
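
Annotation (not part of the patch): get_one_int_from_table_16 takes a 32-bit word whose four bytes each hold one 4-bit IQ4_NL index (the & 0x0F0F0F0F mask in the caller guarantees this), maps each index through the non-linear lookup table kvalues_iq4nl to a signed 8-bit value, and repacks the four results into one int so that ggml_cuda_dp4a can accumulate a 4-way int8 dot product against the q8_1-quantized Q values. Below is a minimal host-side C++ sketch of that per-word computation; the table values are copied from ggml's kvalues_iq4nl and assumed unchanged, and the function and variable names are hypothetical stand-ins.

// Hypothetical scalar stand-in for get_one_int_from_table_16 + ggml_cuda_dp4a.
#include <cstdint>
#include <cstdio>

static const int8_t kvalues_iq4nl[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

// Expand four 4-bit indices (one per byte of 'packed') through the table and
// accumulate the 4-way int8 dot product with the quantized Q values.
static int dot4_iq4nl_q8(uint32_t packed, const int8_t q8[4]) {
    int sumi = 0;
    for (int j = 0; j < 4; ++j) {
        const uint8_t idx = (packed >> (8*j)) & 0x0F;     // low nibble of byte j
        sumi += int(kvalues_iq4nl[idx]) * int(q8[j]);     // what dp4a does in hardware
    }
    return sumi;
}

int main() {
    const uint32_t packed = 0x0F030702;       // four 4-bit indices, one per byte
    const int8_t   q8[4]  = {3, -5, 7, 11};   // example q8_1-style quantized Q values
    printf("sumi = %d\n", dot4_iq4nl_q8(packed, q8));
    return 0;
}
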
@@ -377,6 +420,25 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
     return ((float) d)*((float) q);
 }
 
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_iq4_nl(const void * __restrict__ vx, const int64_t i) {
+    const block_iq4_nl * x = (const block_iq4_nl *) vx;
+
+    const int64_t ib    =  i           /  QK4_NL;
+    const int     iqs   =  i           % (QK4_NL/2);
+    const int     shift = (i % QK4_NL) / (QK4_NL/2);
+
+#ifdef FP16_AVAILABLE
+    if constexpr (std::is_same<T, half>::value) {
+        return x[ib].d * ((half) kvalues_iq4nl[(x[ib].qs[iqs] >> 4*(shift)) & 0xf]);
+    } else {
+        return (float) x[ib].d * ((float) kvalues_iq4nl[(x[ib].qs[iqs] >> 4*(shift)) & 0xf]);
+    }
+#endif
+    T result = (float) x[ib].d * ((float) kvalues_iq4nl[(x[ib].qs[iqs] >> 4*(shift)) & 0xf]);
+    return result;
+}
+
 template <typename T>
 static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
     const block_q4_1 * x = (const block_q4_1 *) vx;
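
Annotation (not part of the patch): dequantize_1_iq4_nl maps a flat element index i into block/byte/nibble coordinates. With QK4_NL == 32 values per block and 16 qs bytes per block, the first 16 elements of a block come from the low nibbles and the last 16 from the high nibbles. A small host-side sketch of just that index math, under the assumption QK4_NL == 32:

// Hypothetical sketch of the index math used by dequantize_1_iq4_nl.
#include <cstdint>
#include <cstdio>

int main() {
    const int QK4_NL = 32;  // elements per IQ4_NL block (assumption)
    const int64_t test[] = {0, 15, 16, 31, 32, 63};
    for (const int64_t i : test) {
        const int64_t ib    =  i           /  QK4_NL;      // block index
        const int     iqs   =  i           % (QK4_NL/2);   // byte within qs[16]
        const int     shift = (i % QK4_NL) / (QK4_NL/2);   // 0 = low nibble, 1 = high nibble
        printf("i=%2lld -> block %lld, qs byte %2d, %s nibble\n",
               (long long) i, (long long) ib, iqs, shift ? "high" : "low");
    }
    return 0;
}
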
@@ -476,44 +538,48 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
 
 template <int D>
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
-        nullptr;
+    return type_K == GGML_TYPE_Q4_0   ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
+        type_K == GGML_TYPE_Q4_1      ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
+        type_K == GGML_TYPE_IQ4_NL    ? vec_dot_fattn_vec_KQ_iq4_nl<half, D> :
+        type_K == GGML_TYPE_Q5_0      ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
+        type_K == GGML_TYPE_Q5_1      ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+        type_K == GGML_TYPE_Q8_0      ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
+        type_K == GGML_TYPE_F16       ? vec_dot_fattn_vec_KQ_f16<half, D> :
+        nullptr;
 }
 
 template <int D>
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
-        nullptr;
+    return type_K == GGML_TYPE_Q4_0   ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
+        type_K == GGML_TYPE_Q4_1      ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
+        type_K == GGML_TYPE_IQ4_NL    ? vec_dot_fattn_vec_KQ_iq4_nl<float, D> :
+        type_K == GGML_TYPE_Q5_0      ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
+        type_K == GGML_TYPE_Q5_1      ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+        type_K == GGML_TYPE_Q8_0      ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
+        type_K == GGML_TYPE_F16       ? vec_dot_fattn_vec_KQ_f16<float, D> :
+        nullptr;
 }
 
 constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
-    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
-        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
-        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
-        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
-        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
-        type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
-        nullptr;
+    return type_V == GGML_TYPE_Q4_0   ? dequantize_1_q4_0<half> :
+        type_V == GGML_TYPE_Q4_1      ? dequantize_1_q4_1<half> :
+        type_V == GGML_TYPE_Q5_0      ? dequantize_1_q5_0<half> :
+        type_V == GGML_TYPE_Q5_1      ? dequantize_1_q5_1<half> :
+        type_V == GGML_TYPE_Q8_0      ? dequantize_1_q8_0<half> :
+        type_V == GGML_TYPE_IQ4_NL    ? dequantize_1_iq4_nl<half> :
+        type_V == GGML_TYPE_F16       ? dequantize_1_f16<half> :
+        nullptr;
 }
 
 constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
-    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
-        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
-        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
-        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
-        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
-        type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
-        nullptr;
+    return type_V == GGML_TYPE_Q4_0   ? dequantize_1_q4_0<float> :
+        type_V == GGML_TYPE_Q4_1      ? dequantize_1_q4_1<float> :
+        type_V == GGML_TYPE_Q5_0      ? dequantize_1_q5_0<float> :
+        type_V == GGML_TYPE_Q5_1      ? dequantize_1_q5_1<float> :
+        type_V == GGML_TYPE_Q8_0      ? dequantize_1_q8_0<float> :
+        type_V == GGML_TYPE_IQ4_NL    ? dequantize_1_iq4_nl<float> :
+        type_V == GGML_TYPE_F16       ? dequantize_1_f16<float> :
+        nullptr;
 }
 
 template <int D, int parallel_blocks> // D == head size
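
Annotation (not part of the patch): the get_vec_dot_KQ_* / get_dequantize_1_* helpers above are constexpr, so when the K/V type is fixed at compile time the ternary chain collapses to a single function pointer and the new IQ4_NL entries add no runtime dispatch cost. A stand-alone C++ sketch of that pattern, with hypothetical stand-in names rather than the real ggml types:

// Stand-alone sketch of the compile-time dispatch pattern used by the
// get_vec_dot_KQ_* / get_dequantize_1_* helpers: a constexpr function maps an
// enum to a function pointer via a ternary chain, so with the enum as a
// template parameter the pointer is resolved at compile time.
#include <cstdio>

enum my_type { TYPE_A, TYPE_B };   // hypothetical stand-in for ggml_type

using fn_t = float (*)(float);

static float fa(float x) { return 2.0f*x; }
static float fb(float x) { return x + 1.0f; }

constexpr fn_t get_fn(my_type t) {
    return t == TYPE_A ? fa :
           t == TYPE_B ? fb :
           nullptr;
}

template <my_type T>
float apply(float x) {
    constexpr fn_t f = get_fn(T);  // resolved at compile time
    return f(x);
}

int main() {
    printf("%f %f\n", apply<TYPE_A>(3.0f), apply<TYPE_B>(3.0f));
    return 0;
}
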
@@ -569,10 +635,12 @@ static void on_no_fattn_vec_case(const int D) {
     } else if (D == 128) {
        fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
        fprintf(stderr, "Supported combinations:\n");
-       fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
-       fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
-       fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
-       fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
+       fprintf(stderr, "  - K == q4_0,   V == q4_0,    4.50 BPV\n");
+       fprintf(stderr, "  - K == iq4_nl, V == iq4_nl,  4.50 BPV\n");
+       fprintf(stderr, "  - K == q8_0,   V == iq4_nl,  6.50 BPV\n");
+       fprintf(stderr, "  - K == q8_0,   V == q8_0,    8.50 BPV\n");
+       fprintf(stderr, "  - K == f16,    V == f16,    16.00 BPV\n");
+       fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
        GGML_ABORT("fatal error");
     } else {
        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
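
Annotation (not part of the patch): the BPV (bits per value) figures in the new message follow from the block layouts. An IQ4_NL block stores 32 values in 16 quant bytes plus one 16-bit scale, i.e. (16*8 + 16)/32 = 4.5 BPV, the same as q4_0; q8_0 stores 32 values in 32 bytes plus a 16-bit scale, i.e. 8.5 BPV; a mixed K == q8_0, V == iq4_nl cache averages to (8.5 + 4.5)/2 = 6.5 BPV. A tiny sanity-check program for that arithmetic:

// Hypothetical sanity check of the BPV figures printed above.
#include <cstdio>

int main() {
    const double bpv_iq4_nl = (16*8 + 16) / 32.0;  // 16 qs bytes + f16 scale per 32 values
    const double bpv_q8_0   = (32*8 + 16) / 32.0;  // 32 qs bytes + f16 scale per 32 values
    printf("iq4_nl: %.2f BPV, q8_0: %.2f BPV, mixed K=q8_0/V=iq4_nl: %.2f BPV\n",
           bpv_iq4_nl, bpv_q8_0, (bpv_q8_0 + bpv_iq4_nl) / 2.0);
    return 0;
}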