@@ -115,6 +115,8 @@ typedef struct {
115
115
} block_q8_0;
116
116
static_assert (sizeof (block_q8_0) == sizeof(ggml_fp16_t ) + QK8_0, "wrong q8_0 block size/padding");
117
117
118
// Pointer type for the per-block quantized dot-product routines: computes the
// dot product of one type-erased quantized x block (vbq) with one q8_0 y block.
// iqs selects which 32-bit chunk of the block's quantized data to process.
typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_0 * bq8_0, const int iqs);
118
120
// ================================= k-quants
119
121
120
122
#ifdef GGML_QKK_64
@@ -1186,6 +1188,27 @@ static __global__ void quantize_q8_0(const float * x, void * vy, const int k) {
1186
1188
y[ib].d = d;
1187
1189
}
1188
1190
1191
// Dot product of one q4_0 x block with the matching half of one q8_0 y block.
// vbq: type-erased pointer to a block_q4_0 (matches vec_dot_q_cuda_t, and the
//      mul_mat_vec_q<QK4_0, block_q4_0, ...> instantiation below).
// iqs: index of the 32-bit chunk of quantized data this call processes.
// Uses __dp4a, so this requires SM61+.
static __device__ float vec_dot_q4_0_q8_0(const void * vbq, const block_q8_0 * bq8_0, const int iqs) {
    // Fix: vbq points at a q4_0 block, not a q8_0 block. The old cast to
    // block_q8_0 only worked because d/qs happen to share offsets in both
    // struct layouts; cast to the correct type.
    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;

    int vi;       // 8 packed 4-bit quants from the q4_0 block
    int ui0, ui1; // 2 x 4 packed 8-bit quants from the q8_0 block
    // memcpy instead of a direct int load: qs is a byte array with no
    // alignment guarantee for 4-byte accesses.
    memcpy(&vi,  &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
    memcpy(&ui0, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
    memcpy(&ui1, &bq8_0->qs[sizeof(int) * (iqs + 4)], sizeof(int));

    const float d = bq4_0->d * bq8_0->d; // combined per-block scale

    // Unpack the low/high nibbles into bytes and shift them from [0, 15]
    // to the signed range [-8, 7] with a per-byte subtraction.
    const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
    const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);

    // Per-byte dot products accumulated into 32-bit integers.
    const int sumi0 = __dp4a(vi0, ui0, 0);
    const int sumi1 = __dp4a(vi1, ui1, 0);

    return (sumi0 + sumi1)*d;
}
1211
+
1189
1212
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
1190
1213
static __global__ void dequantize_block (const void * vx, float * y, const int k) {
1191
1214
const int i = blockDim .x *blockIdx .x + 2 *threadIdx .x ;
@@ -1207,8 +1230,8 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
1207
1230
y[iybs + iqs + y_offset] = v.y ;
1208
1231
}
1209
1232
1210
- template <int qk, int qr, dequantize_kernel_t dequantize_kernel >
1211
- static __global__ void vec_dot_q (const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
1233
+ template <int qk, typename block_q_t , vec_dot_q_cuda_t vec_dot_q_cuda >
1234
+ static __global__ void mul_mat_vec_q (const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
1212
1235
const int row = blockIdx .y *blockDim .y + threadIdx .y ;
1213
1236
1214
1237
if (row >= nrows) {
@@ -1224,33 +1247,17 @@ static __global__ void vec_dot_q(const void * vx, const void * vy, float * dst,
1224
1247
// partial sum for each thread
1225
1248
float tmp = 0 .0f ;
1226
1249
1227
- const block_q4_0 * x = (const block_q4_0 *) vx;
1228
- const block_q8_0 * y = (block_q8_0 *) vy;
1250
+ const block_q_t * x = (const block_q_t *) vx;
1251
+ const block_q8_0 * y = (const block_q8_0 *) vy;
1229
1252
1230
1253
for (int i = 0 ; i < blocks_per_row; i += blocks_per_warp) {
1231
1254
const int ibx = row*blocks_per_row + i + tid/ints_per_block; // x block index
1232
1255
1233
1256
const int iby = i + tid/ints_per_block;
1234
1257
1235
- const int iqsx = tid % ints_per_block;
1236
- const int iqsy0 = tid % ints_per_block + 0 ;
1237
- const int iqsy1 = tid % ints_per_block + ints_per_block;
1238
-
1239
- int vi;
1240
- int ui0, ui1;
1241
- memcpy (&vi, &x[ibx].qs [sizeof (int ) * iqsx], sizeof (int ));
1242
- memcpy (&ui0, &y[iby].qs [sizeof (int ) * iqsy0], sizeof (int ));
1243
- memcpy (&ui1, &y[iby].qs [sizeof (int ) * iqsy1], sizeof (int ));
1244
-
1245
- const dfloat d = x[ibx].d * y[iby].d ;
1246
-
1247
- const int vi0 = __vsub4 ((vi >> 0 ) & 0x0F0F0F0F , 0x08080808 );
1248
- const int vi1 = __vsub4 ((vi >> 4 ) & 0x0F0F0F0F , 0x08080808 );
1249
-
1250
- const int sumi0 = __dp4a (vi0, ui0, 0 );
1251
- const int sumi1 = __dp4a (vi1, ui1, 0 );
1258
+ const int iqs = tid % ints_per_block;
1252
1259
1253
- tmp += (sumi0 + sumi1)*d ;
1260
+ tmp += vec_dot_q_cuda (&x[ibx], &y[iby], iqs) ;
1254
1261
}
1255
1262
1256
1263
// sum up partial sums and write back result
@@ -1743,7 +1750,7 @@ static void mul_mat_vec_q4_0_q8_0_cuda(const void * vx, const void * vy, float *
1743
1750
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1 ) / GGML_CUDA_DMMV_Y;
1744
1751
const dim3 block_nums (1 , block_num_y, 1 );
1745
1752
const dim3 block_dims (WARP_SIZE, GGML_CUDA_DMMV_Y, 1 );
1746
- vec_dot_q <QK4_0, QR4_0, dequantize_q4_0 >
1753
+ mul_mat_vec_q <QK4_0, block_q4_0, vec_dot_q4_0_q8_0 >
1747
1754
<<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols, nrows);
1748
1755
}
1749
1756
0 commit comments