Skip to content

Rebalancing Metal threads workload in dot product kernel kernel_mul_mv_f16_f32_l4 #7522

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
Expand Down Expand Up @@ -533,6 +534,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE, mul_mv_f16_f32_l4_large, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
Expand Down Expand Up @@ -1571,6 +1573,7 @@ static enum ggml_status ggml_metal_graph_compute(
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

id<MTLComputePipelineState> pipeline = nil;
bool is_large = false;

// use custom matrix x vector kernel
switch (src0t) {
Expand All @@ -1588,7 +1591,12 @@ static enum ggml_status ggml_metal_graph_compute(
if (ne11 * ne12 < 4) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
if (ne01 > 128) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE].pipeline;
is_large = true;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
}
nrows = ne11;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
Expand Down Expand Up @@ -1778,7 +1786,11 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
const int64_t ny = (ne11 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
if (is_large) {
[encoder dispatchThreadgroups:MTLSizeMake(ne01/32, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
}
} break;
Expand Down
58 changes: 58 additions & 0 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1598,6 +1598,64 @@ kernel void kernel_mul_mv_f16_f32_l4(
}
}

// Matrix x vector product with src0 read as half4 and src1 read as float4,
// where each threadgroup processes a tile of 32 consecutive src0 rows
// instead of a single row.
//
// Work split:
//   tgpig.x  - index of the 32-row tile of src0 (rows tgpig.x*32 .. tgpig.x*32 + 31)
//   tgpig.z  - flattened batch index; split below into i12 = im % ne12, i13 = im / ne12
//   tiisg    - lane inside the simdgroup; lanes stride the ne00/4 packed
//              elements of a row in steps of 32
//
// For every (r0, r1) pair the lanes accumulate a private partial dot product,
// which is combined with simd_sum() and stored by lane 0 to
// dst[im*ne1*ne0 + r1*ne0 + r0].
//
// Rows at or beyond ne01 are skipped, so the host may round the number of
// threadgroups in x up to cover a tail when ne01 is not a multiple of 32.
//
// NOTE(review): the lane stride of 32 assumes a 32-wide simdgroup - confirm
// against the threadsPerThreadgroup used at dispatch.
kernel void kernel_mul_mv_f16_f32_l4_large(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant      uint & r2,
        constant      uint & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]]) {

    const int     nrows   = ne11;
    const int64_t base_r0 = tgpig.x*32;
    const int64_t im      = tgpig.z;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    for (int j = 0; j < 32; ++j) {
        const int64_t r0 = base_r0 + j;

        // r0 depends only on the threadgroup position and j, so every lane
        // takes this exit together and simd_sum() below stays in uniform
        // control flow.
        if (r0 >= ne01) {
            break;
        }

        // 64-bit byte offset: the strides are 64-bit and their sum can exceed
        // the range of a 32-bit uint.
        const uint64_t offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

        device const half4 * x4 = (device const half4 *) (src0 + offset0);

        for (int r1 = 0; r1 < nrows; ++r1) {
            device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);

            // Per-lane accumulator, reset for every (r0, r1) pair so a result
            // never carries over the sum of a previous src1 row.
            float sumf = 0.0f;
            for (int i = tiisg; i < ne00/4; i += 32) {
                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
            }

            // simd_sum() combines the lanes' private values directly; no
            // threadgroup memory or barrier is involved.
            const float all_sum = simd_sum(sumf);

            if (tiisg == 0) {
                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
            }
        }
    }
}

static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
Expand Down