diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 752d55c216604..6f4427c55c50e 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -519,6 +519,7 @@ typedef struct { int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; + int64_t s_off; uint64_t nb01; uint64_t nb02; uint64_t nb03; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 44ddc69d08f1c..de7d33046fc23 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2986,6 +2986,7 @@ static bool ggml_metal_encode_node( /*.n_group =*/ n_group, /*.n_seq_tokens =*/ n_seq_tokens, /*.n_seqs =*/ n_seqs, + /*.s_off =*/ ggml_nelements(src1) * sizeof(float), /*.nb01 =*/ nb01, /*.nb02 =*/ nb02, /*.nb03 =*/ nb03, @@ -3016,7 +3017,8 @@ static bool ggml_metal_encode_node( if (ne30 == 1) { // Mamba-2 - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; // SIMD size + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } else { GGML_ASSERT(d_inner == 1); [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 13235e2885241..ac2895b5164f6 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1752,7 +1752,6 @@ kernel void kernel_ssm_scan_f32( } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -// TODO: optimize (e.g. by parallelizing over d_state) kernel void kernel_ssm_scan_f32_group( device const void * src0, device const void * src1, @@ -1762,10 +1761,14 @@ kernel void kernel_ssm_scan_f32_group( device const void * src5, device const void * src6, device float * dst, + threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i1 = tgpig.x; const int64_t ir = tgpig.y; // current head const int64_t i3 = tgpig.z; // current seq @@ -1780,7 +1783,7 @@ kernel void kernel_ssm_scan_f32_group( const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + const int64_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; @@ -1798,15 +1801,31 @@ kernel void kernel_ssm_scan_f32_group( const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; const float dA = exp(dt_soft_plus * A[0]); + + threadgroup_barrier(mem_flags::mem_threadgroup); + float sumf = 0.0f; - for (int64_t i0 = 0; i0 < nc; ++i0) { - const int64_t i = i0 + i1*nc; - const float state = (s0[i] * dA) + (B[i0] * x_dt); - sumf += state * C[i0]; - s[i] = state; + const int64_t i = tpitg.x + i1*nc; + const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt); + sumf += state * C[tpitg.x]; + s[i] = state; + + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Use the shared buffer to hold the sum of each simd group + if (tiisg == 0) { + shared[sgitg] = sumf; } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Sum the simd buckets + sumf = shared[tiisg]; + sumf = simd_sum(sumf); + y[0] = sumf; // recurse