From ba74a2473045deda1be08a2cac42b8731644a6a8 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 15 Jul 2025 17:03:38 -0600 Subject: [PATCH 1/2] feat: Add s_off as a parameter in the args struct This may not be necessary, but it more closely mirrors the CUDA kernel Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-impl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 752d55c216604..6f4427c55c50e 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -519,6 +519,7 @@ typedef struct { int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; + int64_t s_off; uint64_t nb01; uint64_t nb02; uint64_t nb03; From 8d5a25d3562617a95d800645c38d1b7b746bff5a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 15 Jul 2025 17:04:31 -0600 Subject: [PATCH 2/2] perf: Parallelize mamba2 SSM_SCAN metal kernel over d_state This is a first attempt at optimizing the metal kernel. The changes here are: - Launch the kernel with a thread group of size d_state - Use simd groups and shared memory to do the summation for the y computation When tested with G4 tiny preview, this shows roughly a 3x speedup on prefill and 15% speedup on decode. Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.m | 4 ++- ggml/src/ggml-metal/ggml-metal.metal | 39 +++++++++++++++++++++------- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 44ddc69d08f1c..de7d33046fc23 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2986,6 +2986,7 @@ static bool ggml_metal_encode_node( /*.n_group =*/ n_group, /*.n_seq_tokens =*/ n_seq_tokens, /*.n_seqs =*/ n_seqs, + /*.s_off =*/ ggml_nelements(src1) * sizeof(float), /*.nb01 =*/ nb01, /*.nb02 =*/ nb02, /*.nb03 =*/ nb03, @@ -3016,7 +3017,8 @@ static bool ggml_metal_encode_node( if (ne30 == 1) { // Mamba-2 - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; // SIMD size + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } else { GGML_ASSERT(d_inner == 1); [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 13235e2885241..ac2895b5164f6 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1752,7 +1752,6 @@ kernel void kernel_ssm_scan_f32( } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -// TODO: optimize (e.g. by parallelizing over d_state) kernel void kernel_ssm_scan_f32_group( device const void * src0, device const void * src1, @@ -1762,10 +1761,14 @@ kernel void kernel_ssm_scan_f32_group( device const void * src5, device const void * src6, device float * dst, + threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i1 = tgpig.x; const int64_t ir = tgpig.y; // current head const int64_t i3 = tgpig.z; // current seq @@ -1780,7 +1783,7 @@ kernel void kernel_ssm_scan_f32_group( const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + const int64_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; @@ -1798,15 +1801,31 @@ kernel void kernel_ssm_scan_f32_group( const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; const float dA = exp(dt_soft_plus * A[0]); + + threadgroup_barrier(mem_flags::mem_threadgroup); + float sumf = 0.0f; - for (int64_t i0 = 0; i0 < nc; ++i0) { - const int64_t i = i0 + i1*nc; - const float state = (s0[i] * dA) + (B[i0] * x_dt); - sumf += state * C[i0]; - s[i] = state; + const int64_t i = tpitg.x + i1*nc; + const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt); + sumf += state * C[tpitg.x]; + s[i] = state; + + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Use the shared buffer to hold the sum of each simd group + if (tiisg == 0) { + shared[sgitg] = sumf; } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Sum the simd buckets + sumf = shared[tiisg]; + sumf = simd_sum(sumf); + y[0] = sumf; // recurse