@@ -3753,22 +3753,24 @@ template <bool need_check> static __global__ void
 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols, const int nrows, const int row_delta) {
+    const int ncols, const int nrows, const int row_delta, const int channel_delta_x, const int channel_delta_y) {

     const int row = blockIdx.y*blockDim.y + threadIdx.y;

     if (row >= nrows) {
         return;
     }

+    const int channel = blockIdx.z*blockDim.z + threadIdx.z;
+
     const int blocks_per_row = ncols / qk;
     const int blocks_per_warp = vdr * WARP_SIZE / qi;

     // partial sum for each thread
     float tmp = 0.0f;

-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
+    const block_q_t  * x = ((const block_q_t  *) vx) + channel*channel_delta_x;
+    const block_q8_1 * y = ((const block_q8_1 *) vy) + channel*channel_delta_y;

     for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
         const int ibx = row*row_delta + i + threadIdx.x / (qi/vdr); // x block index
@@ -3787,7 +3789,7 @@ static __global__ void mul_mat_vec_q(
     }

     if (threadIdx.x == 0) {
-        dst[row] = tmp;
+        dst[channel*nrows + row] = tmp;
     }
 }

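Aside, not part of the commit: a CPU sketch of the index math the kernel now performs, with plain floats standing in for the quantized blocks (in the kernel the deltas count blocks of block_q_t/block_q8_1, here they count elements; the function name is illustrative). Each channel offsets the x and y base pointers and writes its rows contiguously into dst:

#include <cstdio>

// Reference walk of the batched mat-vec indexing used by mul_mat_vec_q:
// x/y advance by a per-channel delta, dst is laid out channel-major.
static void mul_mat_vec_batched_ref(
    const float * x, const float * y, float * dst,
    int ncols, int nrows, int nchannels,
    int row_delta, int channel_delta_x, int channel_delta_y) {
    for (int channel = 0; channel < nchannels; ++channel) {
        const float * xc = x + channel*channel_delta_x; // per-channel base of src0
        const float * yc = y + channel*channel_delta_y; // per-channel base of src1
        for (int row = 0; row < nrows; ++row) {
            float tmp = 0.0f;
            for (int i = 0; i < ncols; ++i) {
                tmp += xc[row*row_delta + i]*yc[i];
            }
            dst[channel*nrows + row] = tmp; // same layout the kernel writes
        }
    }
}

int main() {
    const float x[2*2*3] = {1,2,3, 4,5,6,   7,8,9, 10,11,12}; // 2 channels, 2 rows, 3 cols
    const float y[2*3]   = {1,1,1,          2,2,2};           // 2 channels, 3 cols
    float dst[2*2];
    mul_mat_vec_batched_ref(x, y, dst, 3, 2, 2, /*row_delta=*/3, /*channel_delta_x=*/6, /*channel_delta_y=*/3);
    printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]); // -> 6 15 48 66
}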
@@ -4439,94 +4441,124 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
     dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }

-static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
+static void mul_mat_vec_q4_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
+    const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
+
     GGML_ASSERT(ncols % QK4_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(1, block_num_y, nchannels);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
 }

-static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
+static void mul_mat_vec_q4_1_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
+    const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
+
     GGML_ASSERT(ncols % QK4_1 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(1, block_num_y, nchannels);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
 }

-static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
+static void mul_mat_vec_q5_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
+    const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
+
     GGML_ASSERT(ncols % QK5_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(1, block_num_y, nchannels);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
 }

-static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
+static void mul_mat_vec_q5_1_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
+    const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
+
     GGML_ASSERT(ncols % QK5_1 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(1, block_num_y, nchannels);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
 }

-static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
+static void mul_mat_vec_q8_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
+    const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
+
     GGML_ASSERT(ncols % QK8_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(1, block_num_y, nchannels);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
 }

-static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
+static void mul_mat_vec_q2_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
+    const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
+
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(1, block_num_y, nchannels);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
 }

-static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
+static void mul_mat_vec_q3_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
+    const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
+
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(1, block_num_y, nchannels);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
 }

-static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
+static void mul_mat_vec_q4_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
+    const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
+
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(1, block_num_y, nchannels);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
 }

-static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
+static void mul_mat_vec_q5_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
+    const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
+
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(1, block_num_y, nchannels);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
 }

-static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int row_delta, cudaStream_t stream) {
+static void mul_mat_vec_q6_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels,
+    const int row_delta, const int channel_delta, const int channel_delta_y, cudaStream_t stream) {
+
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(1, block_num_y, nchannels);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta);
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, row_delta, channel_delta, channel_delta_y);
 }

 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
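All ten launchers above change in lockstep: the grid gains a z dimension of size nchannels, so each (y, z) block coordinate covers GGML_CUDA_MMV_Y rows of one channel. A host-side sketch of the grid sizing, with dim3_ and mmvq_block_nums as illustrative stand-ins, not commit code (GGML_CUDA_MMV_Y == 2 in the example is an assumed tuning value):

#include <cstdio>

struct dim3_ { unsigned x, y, z; }; // host stand-in for CUDA's dim3

// Mirrors block_nums in the launchers: one warp per row along x,
// ceil(nrows / mmv_y) blocks along y, one grid layer per channel along z.
static dim3_ mmvq_block_nums(int nrows, int nchannels, int mmv_y) {
    const int block_num_y = (nrows + mmv_y - 1) / mmv_y;
    return dim3_{1u, (unsigned) block_num_y, (unsigned) nchannels};
}

int main() {
    const dim3_ g = mmvq_block_nums(4096, 32, 2); // e.g. GGML_CUDA_MMV_Y == 2
    printf("grid = (%u, %u, %u)\n", g.x, g.y, g.z); // -> grid = (1, 2048, 32)
}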
@@ -5559,7 +5591,13 @@ inline void ggml_cuda_op_mul_mat_vec(
     GGML_ASSERT(dst_ddf_i != nullptr);

     const int64_t ne00 = src0->ne[0];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t ne10 = src1->ne[0];
+
     const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2];
+
     const int64_t nrows = i01_high - i01_low;

 #ifdef GGML_CUDA_FORCE_DMMV
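For orientation (not commit code): ggml's nb[] strides are in bytes, while the kernel steps in whole quantized blocks, which is why nb01 and nb02 are divided by ggml_type_size(src0->type) in the next hunk. The same conversion, shown with floats and made-up sizes:

#include <cassert>
#include <cstddef>

int main() {
    // a contiguous 4 x 3 x 2 float tensor: ne = {4, 3, 2}
    const size_t type_size = sizeof(float);
    const size_t nb01 = 4*type_size;   // byte stride from one row to the next
    const size_t nb02 = 4*3*type_size; // byte stride from one channel to the next

    const int row_delta     = (int)(nb01 / type_size); // 4 elements per row step
    const int channel_delta = (int)(nb02 / type_size); // 12 elements per channel step
    assert(row_delta == 4 && channel_delta == 12);
}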
@@ -5585,46 +5623,48 @@ inline void ggml_cuda_op_mul_mat_vec(
 #endif // QK_K == 256

     const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= MIN_CC_DP4A && mul_mat_vec_q_implemented;
-#endif
+#endif // GGML_CUDA_FORCE_DMMV

     if (use_mul_mat_vec_q) {
-        const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ?
-            ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
+        const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ?
+            ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
         size_t as;
-        void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
-        quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, 1, padded_row_size, cudaStream_main);
+        void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne02*sizeof(block_q8_1)/QK8_1, &as);
+        quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne02, padded_row_size, cudaStream_main);

-        const int row_delta = nb01 / ggml_type_size(src0->type);
+        const int row_delta       = nb01 / ggml_type_size(src0->type);
+        const int channel_delta   = nb02 / ggml_type_size(src0->type);
+        const int channel_delta_y = padded_row_size / QK8_1;
         switch (src0->type) {
             case GGML_TYPE_Q4_0:
-                mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
+                mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
                 break;
             case GGML_TYPE_Q4_1:
-                mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
+                mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
                 break;
             case GGML_TYPE_Q5_0:
-                mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
+                mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
                 break;
             case GGML_TYPE_Q5_1:
-                mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
+                mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
                 break;
             case GGML_TYPE_Q8_0:
-                mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
+                mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
                 break;
             case GGML_TYPE_Q2_K:
-                mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
+                mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
                 break;
             case GGML_TYPE_Q3_K:
-                mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
+                mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
                 break;
             case GGML_TYPE_Q4_K:
-                mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
+                mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
                 break;
             case GGML_TYPE_Q5_K:
-                mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
+                mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
                 break;
             case GGML_TYPE_Q6_K:
-                mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, row_delta, cudaStream_main);
+                mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_delta, channel_delta, channel_delta_y, cudaStream_main);
                 break;
             default:
                 GGML_ASSERT(false);
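Not from the commit: src1 is now quantized once for all ne02 channels, so the scratch buffer holds one padded q8_1 row per channel, and channel_delta_y is that padded row length measured in q8_1 blocks. A sketch of the sizing arithmetic; the 512 padding and 36-byte block size are assumed example values, and only QK8_1 == 32 is relied on as a ggml constant:

#include <cassert>
#include <cstddef>

// Mirrors the pool_malloc size above: ne02 channels of one padded row each,
// one q8_1 block per QK8_1 (= 32) values.
static size_t q8_1_scratch_bytes(long long ne10, long long ne02,
                                 long long padding, size_t sizeof_block_q8_1) {
    const long long padded_row_size =
        ne10 % padding == 0 ? ne10 : ne10 - ne10 % padding + padding;
    return (size_t)(padded_row_size*ne02)*sizeof_block_q8_1/32;
}

int main() {
    // 4097 columns round up to 4608; 2 channels; assumed 36-byte q8_1 blocks
    assert(q8_1_scratch_bytes(4097, 2, 512, 36) == (4608/32)*2*36);
}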
@@ -5633,6 +5673,8 @@ inline void ggml_cuda_op_mul_mat_vec(

         ggml_cuda_pool_free(src1_q8_1, as);
     } else {
+        GGML_ASSERT(ne02 == 1);
+
         // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
 #ifdef GGML_CUDA_F16
         size_t ash;
@@ -6320,7 +6362,6 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);

     const int64_t ne00 = src0->ne[0];
@@ -6336,18 +6377,24 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];

     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
-    void * src0_ddq = src0_extra->data_device[g_main_device];
+    char * src0_ddq = (char *) src0_extra->data_device[g_main_device];

     struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
     float * src1_ddf = (float *) src1_extra->data_device[g_main_device];

     struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];

-    const int row_stride_x = nb01 / sizeof(half);
-    const int channel_stride_x = nb02 / sizeof(half);
+    if (src0->type == GGML_TYPE_F16) {
+        const int row_stride_x = nb01 / sizeof(half);
+        const int channel_stride_x = nb02 / sizeof(half);

-    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
+        ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
+    } else if (ggml_is_quantized(src0->type)) {
+        ggml_cuda_op_mul_mat_vec(src0, src1, dst, src0_ddq, nullptr, src1_ddf, dst_ddf, 0, 0, ne01, 0, cudaStream_main);
+    } else {
+        GGML_ASSERT(false);
+    }
 }

 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
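The char * cast above is load-bearing, not cosmetic: pointer arithmetic on void * is not valid C++, and the new quantized branch hands src0_ddq to ggml_cuda_op_mul_mat_vec, which applies its offsets in bytes. A standalone illustration of why byte-offset code needs a char * base:

#include <cassert>

int main() {
    float buf[8] = {};
    void * p = buf;           // arithmetic on p itself would not compile
    char * base = (char *) p; // char * steps in bytes, like ggml's nb[] strides
    float * second_row = (float *) (base + 4*sizeof(float));
    assert(second_row == buf + 4);
}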
@@ -6357,7 +6404,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_

     if (all_on_device && !src0_is_quantized && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && !src0_is_quantized && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
+    } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
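With !src0_is_quantized dropped from the second branch, a quantized, non-contiguous src0 paired with a contiguous single-column src1 now reaches ggml_cuda_mul_mat_vec_nc, which dispatches on src0->type as shown above. A toy model of the routing (fake_tensor and route are illustrative names, not llama.cpp API):

#include <cstdio>

struct fake_tensor { bool contiguous, permuted, quantized; long long ne1; };

static const char * route(const fake_tensor & src0, const fake_tensor & src1,
                          bool all_on_device) {
    if (all_on_device && !src0.quantized && src0.permuted && src1.permuted && src1.ne1 == 1) {
        return "mul_mat_vec_p021";
    } else if (all_on_device && !src0.contiguous && src1.contiguous && src1.ne1 == 1) {
        return "mul_mat_vec_nc"; // quantized src0 no longer excluded here
    }
    return "generic mul_mat path";
}

int main() {
    const fake_tensor src0 = {false, false, true, 4096}; // quantized, non-contiguous
    const fake_tensor src1 = {true,  false, false, 1};   // contiguous vector
    printf("%s\n", route(src0, src1, true)); // -> mul_mat_vec_nc
}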