Skip to content

Commit 15a28ec

Browse files
CUDA: fix --split-mode row for MMQ (ggml-org#13323)
1 parent a7366fa commit 15a28ec

File tree

2 files changed: +30 additions, −30 deletions

ggml/src/ggml-cuda/mmq.cu

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,7 @@ void ggml_cuda_mul_mat_q(
128128

129129
const mmq_args args = {
130130
src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
131-
ne00, ne01, ne1, s01, s1,
131+
ne00, ne01, ne1, s01, ne11, s1,
132132
ne02, ne12, s02, s12, s2,
133133
ne03, ne13, s03, s13, s3,
134134
use_stream_k};
@@ -212,7 +212,7 @@ void ggml_cuda_mul_mat_q(
212212
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
213213
const mmq_args args = {
214214
src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
215-
ne00, ne01, ne_get_rows, s01, s1,
215+
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
216216
ne02, ne02, s02, s12, s2,
217217
ne03, ne13, s03, s13, s3,
218218
use_stream_k};
@@ -251,7 +251,7 @@ void ggml_cuda_op_mul_mat_q(
251251
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
252252
const mmq_args args = {
253253
src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
254-
ne00, row_diff, src1_ncols, stride01, nrows_dst,
254+
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
255255
1, 1, 0, 0, 0,
256256
1, 1, 0, 0, 0,
257257
use_stream_k};

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 27 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -2522,7 +2522,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
25222522
static __device__ __forceinline__ void mul_mat_q_process_tile(
25232523
const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
25242524
const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2525-
const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
2525+
const int nrows_x, const int stride_row_x, const int ncols_y, const int stride_col_dst,
25262526
const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
25272527

25282528
constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -2606,7 +2606,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
26062606
static __global__ void mul_mat_q(
26072607
const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
26082608
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2609-
const int ncols_x, const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
2609+
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
26102610
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
26112611
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
26122612

@@ -2619,8 +2619,8 @@ static __global__ void mul_mat_q(
26192619
constexpr int qk = ggml_cuda_type_traits<type>::qk;
26202620
constexpr int mmq_y = get_mmq_y_device();
26212621

2622-
const int ntx = (ncols_y + mmq_x - 1) / mmq_x; // Number of tiles x
2623-
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
2622+
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
2623+
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
26242624

26252625
// Initialize the ids for writing back data with just the index.
26262626
// For regular matrix multiplications this is never changed.
@@ -2648,8 +2648,8 @@ static __global__ void mul_mat_q(
26482648

26492649
// Defaults for regular matrix multiplication:
26502650
int col_low = 0;
2651-
int col_high = ncols_y;
2652-
int col_diff = ncols_y;
2651+
int col_high = ncols_dst;
2652+
int col_diff = ncols_dst;
26532653
int offset_y = wt*stride_sample_y + zt*stride_channel_y;
26542654
int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
26552655

@@ -2689,7 +2689,7 @@ static __global__ void mul_mat_q(
26892689

26902690
constexpr bool fixup = false;
26912691
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2692-
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
2692+
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
26932693
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
26942694
return;
26952695
}
@@ -2720,8 +2720,8 @@ static __global__ void mul_mat_q(
27202720

27212721
// Defaults for regular matrix multiplication:
27222722
int col_low = 0;
2723-
int col_high = ncols_y;
2724-
int col_diff = ncols_y;
2723+
int col_high = ncols_dst;
2724+
int col_diff = ncols_dst;
27252725
int offset_y = wt*stride_sample_y + zt*stride_channel_y;
27262726
int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
27272727

@@ -2767,7 +2767,7 @@ static __global__ void mul_mat_q(
27672767

27682768
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
27692769
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2770-
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
2770+
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
27712771
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
27722772

27732773
kbc += blocks_per_ne00;
@@ -2792,8 +2792,8 @@ static __global__ void mul_mat_q(
27922792

27932793
// Defaults for regular matrix multiplication:
27942794
int col_low = 0;
2795-
int col_high = ncols_y;
2796-
int col_diff = ncols_y;
2795+
int col_high = ncols_dst;
2796+
int col_diff = ncols_dst;
27972797
int offset_y = wt*stride_sample_y + zt*stride_channel_y;
27982798
int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
27992799

@@ -2834,15 +2834,15 @@ static __global__ void mul_mat_q(
28342834

28352835
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
28362836
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2837-
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
2837+
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
28382838
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
28392839
}
28402840

28412841

28422842
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
28432843
static __global__ void mul_mat_q_stream_k_fixup(
28442844
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
2845-
const int ncols_x, const int nrows_x, const int ncols_y, const int stride_col_dst,
2845+
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
28462846
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
28472847
constexpr int mmq_y = get_mmq_y_device();
28482848
constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -2851,8 +2851,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
28512851

28522852
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
28532853

2854-
const int ntx = (ncols_y + mmq_x - 1) / mmq_x;
2855-
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
2854+
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
2855+
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
28562856

28572857
const int bidx0 = blockIdx.x;
28582858

@@ -2925,8 +2925,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
29252925
const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
29262926
dst += offset_dst;
29272927

2928-
const int i_max = nrows_x - it*mmq_y - 1;
2929-
const int j_max = ncols_y - jt*mmq_x - 1;
2928+
const int i_max = nrows_x - it*mmq_y - 1;
2929+
const int j_max = ncols_dst - jt*mmq_x - 1;
29302930

29312931
#pragma unroll
29322932
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -2989,7 +2989,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
29892989

29902990
struct mmq_args {
29912991
const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
2992-
int64_t ncols_x; int64_t nrows_x; int64_t ncols_y; int64_t stride_row_x; int64_t nrows_dst;
2992+
int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
29932993
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
29942994
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
29952995
bool use_stream_k;
@@ -3025,8 +3025,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
30253025
}
30263026
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
30273027

3028-
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
3029-
const int ntx = (args.ncols_y + mmq_x - 1) / mmq_x;
3028+
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
3029+
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
30303030
const int ntzw = args.nchannels_y * args.nsamples_y;
30313031
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
30323032

@@ -3040,14 +3040,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
30403040
constexpr bool need_check = false;
30413041
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
30423042
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3043-
args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
3043+
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
30443044
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
30453045
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
30463046
} else {
30473047
constexpr bool need_check = true;
30483048
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
30493049
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3050-
args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
3050+
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
30513051
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
30523052
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
30533053
}
@@ -3068,7 +3068,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
30683068

30693069
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
30703070
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3071-
args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
3071+
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
30723072
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
30733073
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
30743074

@@ -3077,14 +3077,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
30773077
}
30783078

30793079
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3080-
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
3080+
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
30813081
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
30823082
} else {
30833083
constexpr bool need_check = true;
30843084

30853085
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
30863086
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3087-
args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
3087+
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
30883088
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
30893089
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
30903090

@@ -3093,7 +3093,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
30933093
}
30943094

30953095
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3096-
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
3096+
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
30973097
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
30983098
}
30993099
}

Comments (0)