Skip to content

Commit 9fb224c

Browse files
mul_mat_vec_q for -ngl 34
1 parent 27c9088 commit 9fb224c

File tree

1 file changed

+103
-56
lines changed

1 file changed

+103
-56
lines changed

ggml-cuda.cu

Lines changed: 103 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3753,22 +3753,24 @@ template <bool need_check> static __global__ void
37533753
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
37543754
static __global__ void mul_mat_vec_q(
37553755
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
3756-
const int ncols, const int nrows, const int row_delta) {
3756+
const int ncols, const int nrows, const int row_delta, const int channel_delta_x, const int channel_delta_y) {
37573757

37583758
const int row = blockIdx.y*blockDim.y + threadIdx.y;
37593759

37603760
if (row >= nrows) {
37613761
return;
37623762
}
37633763

3764+
const int channel = blockIdx.z*blockDim.z + threadIdx.z;
3765+
37643766
const int blocks_per_row = ncols / qk;
37653767
const int blocks_per_warp = vdr * WARP_SIZE / qi;
37663768

37673769
// partial sum for each thread
37683770
float tmp = 0.0f;
37693771

3770-
const block_q_t * x = (const block_q_t *) vx;
3771-
const block_q8_1 * y = (const block_q8_1 *) vy;
3772+
const block_q_t * x = ((const block_q_t *) vx) + channel*channel_delta_x;
3773+
const block_q8_1 * y = ((const block_q8_1 *) vy) + channel*channel_delta_y;
37723774

37733775
for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
37743776
const int ibx = row*row_delta + i + threadIdx.x / (qi/vdr); // x block index
@@ -3787,7 +3789,7 @@ static __global__ void mul_mat_vec_q(
37873789
}
37883790

37893791
if (threadIdx.x == 0) {
3790-
dst[row] = tmp;
3792+
dst[channel*nrows + row] = tmp;
37913793
}
37923794
}
37933795

@@ -4439,94 +4441,124 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
44394441
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
44404442
}
44414443

4442-
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
4444+
static void mul_mat_vec_q4_0_q8_1_cuda(
4445+
const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
4446+
const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
4447+
44434448
GGML_ASSERT(ncols % QK4_0 == 0);
44444449
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4445-
const dim3 block_nums(1, block_num_y, 1);
4450+
const dim3 block_nums(1, block_num_y, nchannels);
44464451
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
44474452
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
4448-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
4453+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
44494454
}
44504455

4451-
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
4456+
static void mul_mat_vec_q4_1_q8_1_cuda(
4457+
const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
4458+
const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
4459+
44524460
GGML_ASSERT(ncols % QK4_1 == 0);
44534461
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4454-
const dim3 block_nums(1, block_num_y, 1);
4462+
const dim3 block_nums(1, block_num_y, nchannels);
44554463
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
44564464
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
4457-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
4465+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
44584466
}
44594467

4460-
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
4468+
static void mul_mat_vec_q5_0_q8_1_cuda(
4469+
const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
4470+
const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
4471+
44614472
GGML_ASSERT(ncols % QK5_0 == 0);
44624473
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4463-
const dim3 block_nums(1, block_num_y, 1);
4474+
const dim3 block_nums(1, block_num_y, nchannels);
44644475
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
44654476
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
4466-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
4477+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
44674478
}
44684479

4469-
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
4480+
static void mul_mat_vec_q5_1_q8_1_cuda(
4481+
const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
4482+
const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
4483+
44704484
GGML_ASSERT(ncols % QK5_1 == 0);
44714485
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4472-
const dim3 block_nums(1, block_num_y, 1);
4486+
const dim3 block_nums(1, block_num_y, nchannels);
44734487
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
44744488
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
4475-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
4489+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
44764490
}
44774491

4478-
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
4492+
static void mul_mat_vec_q8_0_q8_1_cuda(
4493+
const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
4494+
const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
4495+
44794496
GGML_ASSERT(ncols % QK8_0 == 0);
44804497
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4481-
const dim3 block_nums(1, block_num_y, 1);
4498+
const dim3 block_nums(1, block_num_y, nchannels);
44824499
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
44834500
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
4484-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
4501+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
44854502
}
44864503

4487-
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
4504+
static void mul_mat_vec_q2_K_q8_1_cuda(
4505+
const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
4506+
const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
4507+
44884508
GGML_ASSERT(ncols % QK_K == 0);
44894509
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4490-
const dim3 block_nums(1, block_num_y, 1);
4510+
const dim3 block_nums(1, block_num_y, nchannels);
44914511
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
44924512
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
4493-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
4513+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
44944514
}
44954515

4496-
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
4516+
static void mul_mat_vec_q3_K_q8_1_cuda(
4517+
const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
4518+
const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
4519+
44974520
GGML_ASSERT(ncols % QK_K == 0);
44984521
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4499-
const dim3 block_nums(1, block_num_y, 1);
4522+
const dim3 block_nums(1, block_num_y, nchannels);
45004523
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
45014524
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
4502-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
4525+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
45034526
}
45044527

4505-
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
4528+
static void mul_mat_vec_q4_K_q8_1_cuda(
4529+
const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
4530+
const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
4531+
45064532
GGML_ASSERT(ncols % QK_K == 0);
45074533
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4508-
const dim3 block_nums(1, block_num_y, 1);
4534+
const dim3 block_nums(1, block_num_y, nchannels);
45094535
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
45104536
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
4511-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
4537+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
45124538
}
45134539

4514-
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
4540+
static void mul_mat_vec_q5_K_q8_1_cuda(
4541+
const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
4542+
const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
4543+
45154544
GGML_ASSERT(ncols % QK_K == 0);
45164545
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4517-
const dim3 block_nums(1, block_num_y, 1);
4546+
const dim3 block_nums(1, block_num_y, nchannels);
45184547
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
45194548
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
4520-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
4549+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
45214550
}
45224551

4523-
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
4552+
static void mul_mat_vec_q6_K_q8_1_cuda(
4553+
const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
4554+
const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
4555+
45244556
GGML_ASSERT(ncols % QK_K == 0);
45254557
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4526-
const dim3 block_nums(1, block_num_y, 1);
4558+
const dim3 block_nums(1, block_num_y, nchannels);
45274559
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
45284560
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
4529-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
4561+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
45304562
}
45314563

45324564
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -5559,7 +5591,13 @@ inline void ggml_cuda_op_mul_mat_vec(
55595591
GGML_ASSERT(dst_ddf_i != nullptr);
55605592

55615593
const int64_t ne00 = src0->ne[0];
5594+
const int64_t ne02 = src0->ne[2];
5595+
5596+
const int64_t ne10 = src1->ne[0];
5597+
55625598
const int64_t nb01 = src0->nb[1];
5599+
const int64_t nb02 = src0->nb[2];
5600+
55635601
const int64_t nrows = i01_high - i01_low;
55645602

55655603
#ifdef GGML_CUDA_FORCE_DMMV
@@ -5585,46 +5623,48 @@ inline void ggml_cuda_op_mul_mat_vec(
55855623
#endif // QK_K == 256
55865624

55875625
const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= MIN_CC_DP4A && mul_mat_vec_q_implemented;
5588-
#endif
5626+
#endif // GGML_CUDA_FORCE_DMMV
55895627

55905628
if (use_mul_mat_vec_q) {
5591-
const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ?
5592-
ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
5629+
const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ?
5630+
ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
55935631
size_t as;
5594-
void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
5595-
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, 1, padded_row_size, cudaStream_main);
5632+
void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne02*sizeof(block_q8_1)/QK8_1, &as);
5633+
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne02, padded_row_size, cudaStream_main);
55965634

5597-
const int row_delta = nb01 / ggml_type_size(src0->type);
5635+
const int row_delta = nb01 / ggml_type_size(src0->type);
5636+
const int channel_delta = nb02 / ggml_type_size(src0->type);
5637+
const int channel_delta_y = padded_row_size / QK8_1;
55985638
switch (src0->type) {
55995639
case GGML_TYPE_Q4_0:
5600-
mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
5640+
mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
56015641
break;
56025642
case GGML_TYPE_Q4_1:
5603-
mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
5643+
mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
56045644
break;
56055645
case GGML_TYPE_Q5_0:
5606-
mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
5646+
mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
56075647
break;
56085648
case GGML_TYPE_Q5_1:
5609-
mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
5649+
mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
56105650
break;
56115651
case GGML_TYPE_Q8_0:
5612-
mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
5652+
mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
56135653
break;
56145654
case GGML_TYPE_Q2_K:
5615-
mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
5655+
mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
56165656
break;
56175657
case GGML_TYPE_Q3_K:
5618-
mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
5658+
mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
56195659
break;
56205660
case GGML_TYPE_Q4_K:
5621-
mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
5661+
mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
56225662
break;
56235663
case GGML_TYPE_Q5_K:
5624-
mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
5664+
mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
56255665
break;
56265666
case GGML_TYPE_Q6_K:
5627-
mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
5667+
mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
56285668
break;
56295669
default:
56305670
GGML_ASSERT(false);
@@ -5633,6 +5673,8 @@ inline void ggml_cuda_op_mul_mat_vec(
56335673

56345674
ggml_cuda_pool_free(src1_q8_1, as);
56355675
} else {
5676+
GGML_ASSERT(ne02 == 1);
5677+
56365678
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
56375679
#ifdef GGML_CUDA_F16
56385680
size_t ash;
@@ -6320,7 +6362,6 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
63206362
GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
63216363
GGML_ASSERT(!ggml_is_permuted(src0));
63226364
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
6323-
GGML_ASSERT(src0->type == GGML_TYPE_F16);
63246365
GGML_ASSERT(src1->type == GGML_TYPE_F32);
63256366

63266367
const int64_t ne00 = src0->ne[0];
@@ -6336,18 +6377,24 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
63366377
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
63376378

63386379
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
6339-
void * src0_ddq = src0_extra->data_device[g_main_device];
6380+
char * src0_ddq = (char *) src0_extra->data_device[g_main_device];
63406381

63416382
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
63426383
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
63436384

63446385
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
63456386
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
63466387

6347-
const int row_stride_x = nb01 / sizeof(half);
6348-
const int channel_stride_x = nb02 / sizeof(half);
6388+
if (src0->type == GGML_TYPE_F16) {
6389+
const int row_stride_x = nb01 / sizeof(half);
6390+
const int channel_stride_x = nb02 / sizeof(half);
63496391

6350-
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
6392+
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
6393+
} else if (ggml_is_quantized(src0->type)) {
6394+
ggml_cuda_op_mul_mat_vec(src0, src1, dst, src0_ddq, nullptr, src1_ddf, dst_ddf, 0, 0, ne01, 0, cudaStream_main);
6395+
} else {
6396+
GGML_ASSERT(false);
6397+
}
63516398
}
63526399

63536400
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -6357,7 +6404,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
63576404

63586405
if (all_on_device && !src0_is_quantized && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
63596406
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
6360-
} else if (all_on_device && !src0_is_quantized && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
6407+
} else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
63616408
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
63626409
}else if (src0->type == GGML_TYPE_F32) {
63636410
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);

0 commit comments

Comments
 (0)