Skip to content

Commit 104138c

Browse files
mul_mat_vec_q template
1 parent afa8885 commit 104138c

File tree

1 file changed

+30
-23
lines changed

1 file changed

+30
-23
lines changed

ggml-cuda.cu

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ typedef struct {
115115
} block_q8_0;
116116
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
117117

118+
// Function-pointer type for per-block dot-product helpers.
// A matching function receives an untyped pointer to a quantized block (vbq),
// a pointer to a q8_0 block (bq8_0) and an integer index (iqs), and returns a float.
// It is used as a non-type template parameter of the mul_mat_vec_q kernel, which
// calls it once per loop iteration and accumulates the returned value.
typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_0 * bq8_0, const int iqs);
119+
118120
//================================= k-quants
119121

120122
#ifdef GGML_QKK_64
@@ -1186,6 +1188,27 @@ static __global__ void quantize_q8_0(const float * x, void * vy, const int k) {
11861188
y[ib].d = d;
11871189
}
11881190

1191+
// Partial dot product between one 32-bit word of a q4_0 block and the matching
// two 32-bit words of a q8_0 block.
//
//   vbq   - untyped pointer to the q4_0 block (mul_mat_vec_q passes &x[ibx] with
//           block_q_t = block_q4_0)
//   bq8_0 - pointer to the q8_0 block
//   iqs   - index, in units of int, of the word to read from the q4_0 quants
//
// Returns the integer dot product of the selected quants scaled by the product of
// both blocks' d fields.
// NOTE: __dp4a is only available on devices of compute capability 6.1 or higher.
static __device__ float vec_dot_q4_0_q8_0(const void * vbq, const block_q8_0 * bq8_0, const int iqs) {
    // vbq refers to a q4_0 block, so it must be viewed through block_q4_0
    // (the previous cast to block_q8_0 used the wrong struct type).
    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;

    int vi;       // one int holding 8 packed 4-bit quants from the q4_0 block
    int ui0, ui1; // the two ints of 8-bit quants they are multiplied with
    memcpy(&vi,  &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
    memcpy(&ui0, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
    // the high nibbles pair with the q8_0 word located 4 ints (16 bytes) further on
    memcpy(&ui1, &bq8_0->qs[sizeof(int) * (iqs + 4)], sizeof(int));

    const float d = bq4_0->d * bq8_0->d;

    // split into low/high nibbles and subtract 8 from each byte
    const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
    const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);

    // 4-way byte dot products accumulated into an int
    const int sumi0 = __dp4a(vi0, ui0, 0);
    const int sumi1 = __dp4a(vi1, ui1, 0);

    return (sumi0 + sumi1)*d;
}
1211+
11891212
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
11901213
static __global__ void dequantize_block(const void * vx, float * y, const int k) {
11911214
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
@@ -1207,8 +1230,8 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
12071230
y[iybs + iqs + y_offset] = v.y;
12081231
}
12091232

1210-
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
1211-
static __global__ void vec_dot_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
1233+
template <int qk, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
1234+
static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
12121235
const int row = blockIdx.y*blockDim.y + threadIdx.y;
12131236

12141237
if (row >= nrows) {
@@ -1224,33 +1247,17 @@ static __global__ void vec_dot_q(const void * vx, const void * vy, float * dst,
12241247
// partial sum for each thread
12251248
float tmp = 0.0f;
12261249

1227-
const block_q4_0 * x = (const block_q4_0 *) vx;
1228-
const block_q8_0 * y = (block_q8_0 *) vy;
1250+
const block_q_t * x = (const block_q_t *) vx;
1251+
const block_q8_0 * y = (const block_q8_0 *) vy;
12291252

12301253
for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
12311254
const int ibx = row*blocks_per_row + i + tid/ints_per_block; // x block index
12321255

12331256
const int iby = i + tid/ints_per_block;
12341257

1235-
const int iqsx = tid % ints_per_block;
1236-
const int iqsy0 = tid % ints_per_block + 0;
1237-
const int iqsy1 = tid % ints_per_block + ints_per_block;
1238-
1239-
int vi;
1240-
int ui0, ui1;
1241-
memcpy(&vi, &x[ibx].qs[sizeof(int) * iqsx], sizeof(int));
1242-
memcpy(&ui0, &y[iby].qs[sizeof(int) * iqsy0], sizeof(int));
1243-
memcpy(&ui1, &y[iby].qs[sizeof(int) * iqsy1], sizeof(int));
1244-
1245-
const dfloat d = x[ibx].d * y[iby].d;
1246-
1247-
const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
1248-
const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);
1249-
1250-
const int sumi0 = __dp4a(vi0, ui0, 0);
1251-
const int sumi1 = __dp4a(vi1, ui1, 0);
1258+
const int iqs = tid % ints_per_block;
12521259

1253-
tmp += (sumi0 + sumi1)*d;
1260+
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
12541261
}
12551262

12561263
// sum up partial sums and write back result
@@ -1743,7 +1750,7 @@ static void mul_mat_vec_q4_0_q8_0_cuda(const void * vx, const void * vy, float *
17431750
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
17441751
const dim3 block_nums(1, block_num_y, 1);
17451752
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
1746-
vec_dot_q<QK4_0, QR4_0, dequantize_q4_0>
1753+
mul_mat_vec_q<QK4_0, block_q4_0, vec_dot_q4_0_q8_0>
17471754
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
17481755
}
17491756

0 commit comments

Comments
 (0)