@@ -2522,7 +2522,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
2522
2522
static __device__ __forceinline__ void mul_mat_q_process_tile (
2523
2523
const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
2524
2524
const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2525
- const int nrows_x, const int ncols_y , const int stride_row_x , const int stride_col_dst,
2525
+ const int nrows_x, const int stride_row_x , const int ncols_y , const int stride_col_dst,
2526
2526
const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
2527
2527
2528
2528
constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -2606,7 +2606,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
2606
2606
static __global__ void mul_mat_q (
2607
2607
const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
2608
2608
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2609
- const int ncols_x, const int nrows_x, const int ncols_y , const int stride_row_x, const int stride_col_dst,
2609
+ const int ncols_x, const int nrows_x, const int ncols_dst , const int stride_row_x, const int ncols_y , const int stride_col_dst,
2610
2610
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
2611
2611
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
2612
2612
@@ -2619,8 +2619,8 @@ static __global__ void mul_mat_q(
2619
2619
constexpr int qk = ggml_cuda_type_traits<type>::qk;
2620
2620
constexpr int mmq_y = get_mmq_y_device ();
2621
2621
2622
- const int ntx = (ncols_y + mmq_x - 1 ) / mmq_x; // Number of tiles x
2623
- const int nty = (nrows_x + mmq_y - 1 ) / mmq_y; // Number of tiles y
2622
+ const int ntx = (ncols_dst + mmq_x - 1 ) / mmq_x; // Number of tiles x
2623
+ const int nty = (nrows_x + mmq_y - 1 ) / mmq_y; // Number of tiles y
2624
2624
2625
2625
// Initialize the ids for writing back data with just the index.
2626
2626
// For regular matrix multiplications this is never changed.
@@ -2648,8 +2648,8 @@ static __global__ void mul_mat_q(
2648
2648
2649
2649
// Defaults for regular matrix multiplication:
2650
2650
int col_low = 0 ;
2651
- int col_high = ncols_y ;
2652
- int col_diff = ncols_y ;
2651
+ int col_high = ncols_dst ;
2652
+ int col_diff = ncols_dst ;
2653
2653
int offset_y = wt*stride_sample_y + zt*stride_channel_y;
2654
2654
int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
2655
2655
@@ -2689,7 +2689,7 @@ static __global__ void mul_mat_q(
2689
2689
2690
2690
constexpr bool fixup = false ;
2691
2691
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2692
- (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x , stride_col_dst,
2692
+ (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y , stride_col_dst,
2693
2693
tile_x_max_i, tile_y_max_j, 0 , ncols_x/qk);
2694
2694
return ;
2695
2695
}
@@ -2720,8 +2720,8 @@ static __global__ void mul_mat_q(
2720
2720
2721
2721
// Defaults for regular matrix multiplication:
2722
2722
int col_low = 0 ;
2723
- int col_high = ncols_y ;
2724
- int col_diff = ncols_y ;
2723
+ int col_high = ncols_dst ;
2724
+ int col_diff = ncols_dst ;
2725
2725
int offset_y = wt*stride_sample_y + zt*stride_channel_y;
2726
2726
int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
2727
2727
@@ -2767,7 +2767,7 @@ static __global__ void mul_mat_q(
2767
2767
2768
2768
constexpr bool fixup = false ; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
2769
2769
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2770
- (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x , stride_col_dst,
2770
+ (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y , stride_col_dst,
2771
2771
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
2772
2772
2773
2773
kbc += blocks_per_ne00;
@@ -2792,8 +2792,8 @@ static __global__ void mul_mat_q(
2792
2792
2793
2793
// Defaults for regular matrix multiplication:
2794
2794
int col_low = 0 ;
2795
- int col_high = ncols_y ;
2796
- int col_diff = ncols_y ;
2795
+ int col_high = ncols_dst ;
2796
+ int col_diff = ncols_dst ;
2797
2797
int offset_y = wt*stride_sample_y + zt*stride_channel_y;
2798
2798
int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
2799
2799
@@ -2834,15 +2834,15 @@ static __global__ void mul_mat_q(
2834
2834
2835
2835
constexpr bool fixup = true ; // Last index writes its data to fixup buffer to avoid data races with other blocks.
2836
2836
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2837
- (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x , stride_col_dst,
2837
+ (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y , stride_col_dst,
2838
2838
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
2839
2839
}
2840
2840
2841
2841
2842
2842
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
2843
2843
static __global__ void mul_mat_q_stream_k_fixup (
2844
2844
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
2845
- const int ncols_x, const int nrows_x, const int ncols_y , const int stride_col_dst,
2845
+ const int ncols_x, const int nrows_x, const int ncols_dst , const int stride_col_dst,
2846
2846
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
2847
2847
constexpr int mmq_y = get_mmq_y_device ();
2848
2848
constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -2851,8 +2851,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
2851
2851
2852
2852
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0 .0f };
2853
2853
2854
- const int ntx = (ncols_y + mmq_x - 1 ) / mmq_x;
2855
- const int nty = (nrows_x + mmq_y - 1 ) / mmq_y;
2854
+ const int ntx = (ncols_dst + mmq_x - 1 ) / mmq_x;
2855
+ const int nty = (nrows_x + mmq_y - 1 ) / mmq_y;
2856
2856
2857
2857
const int bidx0 = blockIdx .x ;
2858
2858
@@ -2925,8 +2925,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
2925
2925
const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
2926
2926
dst += offset_dst;
2927
2927
2928
- const int i_max = nrows_x - it*mmq_y - 1 ;
2929
- const int j_max = ncols_y - jt*mmq_x - 1 ;
2928
+ const int i_max = nrows_x - it*mmq_y - 1 ;
2929
+ const int j_max = ncols_dst - jt*mmq_x - 1 ;
2930
2930
2931
2931
#pragma unroll
2932
2932
for (int j0 = 0 ; j0 < mmq_x; j0 += nwarps) {
@@ -2989,7 +2989,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
2989
2989
2990
2990
struct mmq_args {
2991
2991
const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
2992
- int64_t ncols_x; int64_t nrows_x; int64_t ncols_y ; int64_t stride_row_x; int64_t nrows_dst;
2992
+ int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst ; int64_t stride_row_x; int64_t ncols_y ; int64_t nrows_dst;
2993
2993
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
2994
2994
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
2995
2995
bool use_stream_k;
@@ -3025,8 +3025,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3025
3025
}
3026
3026
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
3027
3027
3028
- const int nty = (args.nrows_x + mmq_y - 1 ) / mmq_y;
3029
- const int ntx = (args.ncols_y + mmq_x - 1 ) / mmq_x;
3028
+ const int nty = (args.nrows_x + mmq_y - 1 ) / mmq_y;
3029
+ const int ntx = (args.ncols_dst + mmq_x - 1 ) / mmq_x;
3030
3030
const int ntzw = args.nchannels_y * args.nsamples_y ;
3031
3031
const dim3 block_nums_xy_tiling (nty, ntx, ntzw);
3032
3032
@@ -3040,14 +3040,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3040
3040
constexpr bool need_check = false ;
3041
3041
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3042
3042
(args.x , args.y , args.ids_dst , args.expert_bounds , args.dst , nullptr ,
3043
- args.ncols_x , args.nrows_x , args.ncols_y , args.stride_row_x , args.nrows_dst ,
3043
+ args.ncols_x , args.nrows_x , args.ncols_dst , args.stride_row_x , args. ncols_y , args.nrows_dst ,
3044
3044
channel_ratio, args.nchannels_y , args.stride_channel_x , args.stride_channel_y , args.stride_channel_dst ,
3045
3045
sample_ratio, args.nsamples_y , args.stride_sample_x , args.stride_sample_y , args.stride_sample_dst );
3046
3046
} else {
3047
3047
constexpr bool need_check = true ;
3048
3048
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3049
3049
(args.x , args.y , args.ids_dst , args.expert_bounds , args.dst , nullptr ,
3050
- args.ncols_x , args.nrows_x , args.ncols_y , args.stride_row_x , args.nrows_dst ,
3050
+ args.ncols_x , args.nrows_x , args.ncols_dst , args.stride_row_x , args. ncols_y , args.nrows_dst ,
3051
3051
channel_ratio, args.nchannels_y , args.stride_channel_x , args.stride_channel_y , args.stride_channel_dst ,
3052
3052
sample_ratio, args.nsamples_y , args.stride_sample_x , args.stride_sample_y , args.stride_sample_dst );
3053
3053
}
@@ -3068,7 +3068,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3068
3068
3069
3069
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3070
3070
(args.x , args.y , args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr ,
3071
- args.ncols_x , args.nrows_x , args.ncols_y , args.stride_row_x , args.nrows_dst ,
3071
+ args.ncols_x , args.nrows_x , args.ncols_dst , args.stride_row_x , args. ncols_y , args.nrows_dst ,
3072
3072
channel_ratio, args.nchannels_y , args.stride_channel_x , args.stride_channel_y , args.stride_channel_dst ,
3073
3073
sample_ratio, args.nsamples_y , args.stride_sample_x , args.stride_sample_y , args.stride_sample_dst );
3074
3074
@@ -3077,14 +3077,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3077
3077
}
3078
3078
3079
3079
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0 , stream>>>
3080
- (args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr , args.ncols_x , args.nrows_x , args.ncols_y ,
3080
+ (args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr , args.ncols_x , args.nrows_x , args.ncols_dst ,
3081
3081
args.nrows_dst , args.nchannels_y , args.stride_channel_dst , args.nsamples_y , args.stride_sample_dst );
3082
3082
} else {
3083
3083
constexpr bool need_check = true ;
3084
3084
3085
3085
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3086
3086
(args.x , args.y , args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr ,
3087
- args.ncols_x , args.nrows_x , args.ncols_y , args.stride_row_x , args.nrows_dst ,
3087
+ args.ncols_x , args.nrows_x , args.ncols_dst , args.stride_row_x , args. ncols_y , args.nrows_dst ,
3088
3088
channel_ratio, args.nchannels_y , args.stride_channel_x , args.stride_channel_y , args.stride_channel_dst ,
3089
3089
sample_ratio, args.nsamples_y , args.stride_sample_x , args.stride_sample_y , args.stride_sample_dst );
3090
3090
@@ -3093,7 +3093,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3093
3093
}
3094
3094
3095
3095
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0 , stream>>>
3096
- (args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr , args.ncols_x , args.nrows_x , args.ncols_y ,
3096
+ (args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr , args.ncols_x , args.nrows_x , args.ncols_dst ,
3097
3097
args.nrows_dst , args.nchannels_y , args.stride_channel_dst , args.nsamples_y , args.stride_sample_dst );
3098
3098
}
3099
3099
}
0 commit comments