2
2
#include " common.cuh"
3
3
#include " mmv.cuh"
4
4
5
+ template <typename T, typename type_acc, int ncols_dst, int block_size>
5
6
template <typename T, typename type_acc, int ncols_dst, int block_size>
6
7
static __global__ void mul_mat_vec (
7
8
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
@@ -15,10 +16,25 @@ static __global__ void mul_mat_vec(
15
16
const int sample_dst = blockIdx .z ;
16
17
const int sample_x = sample_dst / sample_ratio;
17
18
const int sample_y = sample_dst;
19
+ const int tid = threadIdx .x ;
20
+
21
+ const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
22
+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
23
+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
24
+ const int row = blockIdx .x ;
25
+ const int channel_dst = blockIdx .y ;
26
+ const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
27
+ const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
28
+ const int sample_dst = blockIdx .z ;
29
+ const int sample_x = sample_dst / sample_ratio;
30
+ const int sample_y = sample_dst;
18
31
const int tid = threadIdx .x ;
19
32
20
33
constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
21
34
35
+ x += int64_t (sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
36
+ y += int64_t (sample_y) *stride_sample_y + channel_y *stride_channel_y;
37
+ dst += int64_t (sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
22
38
x += int64_t (sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
23
39
y += int64_t (sample_y) *stride_sample_y + channel_y *stride_channel_y;
24
40
dst += int64_t (sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
@@ -456,11 +472,6 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
456
472
return ne11 <= 4 ;
457
473
}
458
474
return ne11 <= 3 ;
459
- } else if (GGML_CUDA_CC_IS_AMD (cc)) {
460
- if (fp32_mma_hardware_available (cc)) {
461
- return ne11 <= 3 ;
462
- }
463
- return ne11 <= 8 ;
464
475
}
465
476
return ne11 <= 8 ;
466
477
case GGML_TYPE_F16:
@@ -473,14 +484,6 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
473
484
return src0_small && ne11 <= 3 ;
474
485
}
475
486
return ne11 <= 8 ;
476
- } else if (GGML_CUDA_CC_IS_AMD (cc)) {
477
- if (fp16_mma_hardware_available (cc)) {
478
- if (GGML_CUDA_CC_IS_RDNA3 (cc) || GGML_CUDA_CC_IS_RDNA4 (cc)) {
479
- return ne11 <= 5 ;
480
- }
481
- return ne11 <= 2 ;
482
- }
483
- return ne11 <= 8 ;
484
487
}
485
488
return ne11 <= 8 ;
486
489
case GGML_TYPE_BF16:
@@ -493,11 +496,6 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
493
496
return src0_small && ne11 <= 3 ;
494
497
}
495
498
return ne11 <= 8 ;
496
- } else if (GGML_CUDA_CC_IS_AMD (cc)) {
497
- if (bf16_mma_hardware_available (cc)) {
498
- return ne11 <= 3 ;
499
- }
500
- return ne11 <= 8 ;
501
499
}
502
500
return ne11 <= 8 ;
503
501
default :
0 commit comments