Skip to content

metal: SSM_SCAN performance #14743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ typedef struct {
int64_t n_group;
int64_t n_seq_tokens;
int64_t n_seqs;
int64_t s_off;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
Expand Down
4 changes: 3 additions & 1 deletion ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -2986,6 +2986,7 @@ static bool ggml_metal_encode_node(
/*.n_group =*/ n_group,
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.s_off =*/ ggml_nelements(src1) * sizeof(float),
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
Expand Down Expand Up @@ -3016,7 +3017,8 @@ static bool ggml_metal_encode_node(

if (ne30 == 1) {
// Mamba-2
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; // SIMD size
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
} else {
GGML_ASSERT(d_inner == 1);
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
Expand Down
39 changes: 29 additions & 10 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1752,7 +1752,6 @@ kernel void kernel_ssm_scan_f32(
}

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// TODO: optimize (e.g. by parallelizing over d_state)
kernel void kernel_ssm_scan_f32_group(
device const void * src0,
device const void * src1,
Expand All @@ -1762,10 +1761,14 @@ kernel void kernel_ssm_scan_f32_group(
device const void * src5,
device const void * src6,
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
constant ggml_metal_kargs_ssm_scan & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
uint3 ntg[[threads_per_threadgroup]]) {

const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
const int64_t i3 = tgpig.z; // current seq
Expand All @@ -1780,7 +1783,7 @@ kernel void kernel_ssm_scan_f32_group(
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;

const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
const int64_t s_off = args.s_off;

device const int32_t * ids = (device const int32_t *) src6;

Expand All @@ -1798,15 +1801,31 @@ kernel void kernel_ssm_scan_f32_group(
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);

threadgroup_barrier(mem_flags::mem_threadgroup);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one appears to be unnecessary


float sumf = 0.0f;

for (int64_t i0 = 0; i0 < nc; ++i0) {
const int64_t i = i0 + i1*nc;
const float state = (s0[i] * dA) + (B[i0] * x_dt);
sumf += state * C[i0];
s[i] = state;
const int64_t i = tpitg.x + i1*nc;
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
sumf += state * C[tpitg.x];
s[i] = state;

sumf = simd_sum(sumf);

threadgroup_barrier(mem_flags::mem_threadgroup);

// Use the shared buffer to hold the sum of each simd group
if (tiisg == 0) {
shared[sgitg] = sumf;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// Sum the simd buckets
sumf = shared[tiisg];
sumf = simd_sum(sumf);

y[0] = sumf;

// recurse
Expand Down
Loading