metal: SSM_SCAN performance #14743
Conversation
This may not be necessary, but it more closely mirrors the CUDA kernel.
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

This is a first attempt at optimizing the metal kernel. The changes here are:
- Launch the kernel with a thread group of size d_state
- Use simd groups and shared memory to do the summation for the y computation

When tested with G4 tiny preview, this shows roughly a 3x speedup on prefill and a 15% speedup on decode.
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@compilade I'd particularly love your feedback on this!
It looks like the use of …

I see that in …
Force-pushed from 26524d0 to 8d5a25d
Adding a version check did not do the trick: https://github.com/ggml-org/llama.cpp/actions/runs/16354752616/job/46210213747?pr=14743#step:6:27203
I'm able to repro the failure on my old M1 Max 64GB MacBook Pro (macOS 14.4.1). Support printouts look like:
This indicates that there's some difference between …
```
@@ -1798,15 +1801,31 @@ kernel void kernel_ssm_scan_f32_group(
        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
        const float x_dt = x[0] * dt_soft_plus;
        const float dA = exp(dt_soft_plus * A[0]);

        threadgroup_barrier(mem_flags::mem_threadgroup);
```
This one appears to be unnecessary; nothing has been written to threadgroup memory at this point, so there is nothing for the barrier to order.
Description

This is an attempt to improve the overall performance of the `mamba2` implementation of `SSM_SCAN` for `metal`. I'm specifically interested in improving performance for Granite Four, but will hopefully achieve a nice speedup for other models as well.

Changes

- In the `mamba2` clauses of the `SSM_SCAN` case for `ggml_metal_encode_node`, launch the kernel with `threadsPerThreadgroup: MTLSizeMake(d_state, 1, 1)` (`d_state` threads) and a shared memory buffer of size `32 * sizeof(float)` (the SIMD group size).
- In `kernel_ssm_scan_f32_group`, remove the loop over `nc` (`d_state`) and instead use `simd_sum` to perform the final `y` calculation (see the sketch after this list).
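Since this reduction is the crux of the change, here is a minimal, self-contained sketch of the pattern as I understand it. All names here are hypothetical, and `dA`/`x_dt` are passed in precomputed to keep the example short; the real kernel computes them inline and handles the full multi-sequence tensor indexing.

```metal
#include <metal_stdlib>
using namespace metal;

// Sketch: one threadgroup of exactly d_state threads computes one scan step.
// Each thread owns one element of the recurrent state; y = sum(state * C)
// is produced with simd_sum instead of a serial loop over d_state.
kernel void ssm_scan_step_sketch(
        device const float * s_in  [[buffer(0)]],  // previous state (d_state)
        device const float * B     [[buffer(1)]],  // input projection (d_state)
        device const float * C     [[buffer(2)]],  // output projection (d_state)
        constant     float & dA    [[buffer(3)]],  // precomputed exp(dt * A)
        constant     float & x_dt  [[buffer(4)]],  // precomputed x * softplus(dt)
        device       float * s_out [[buffer(5)]],  // updated state (d_state)
        device       float * y     [[buffer(6)]],  // scalar output
        threadgroup  float * shmem [[threadgroup(0)]],  // 32 floats
        uint tid   [[thread_position_in_threadgroup]],
        uint sgitg [[simdgroup_index_in_threadgroup]],
        uint tiisg [[thread_index_in_simdgroup]],
        uint nsg   [[simdgroups_per_threadgroup]]) {
    // 1) per-thread state update for this thread's d_state element
    const float state = s_in[tid] * dA + B[tid] * x_dt;
    s_out[tid] = state;

    // 2) sum state * C within each 32-wide SIMD group, then stash one
    //    partial per group in threadgroup memory
    float sumf = simd_sum(state * C[tid]);
    if (tiisg == 0) {
        shmem[sgitg] = sumf;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // 3) the first SIMD group folds the per-group partials into y
    if (sgitg == 0) {
        sumf = simd_sum(tiisg < nsg ? shmem[tiisg] : 0.0f);
        if (tiisg == 0) {
            y[0] = sumf;
        }
    }
}
```

The barrier in step 2 is the one that actually matters: it orders the writes of the per-group partials before the reads in step 3.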
Testing
All testing was done on my M3 Max 64GB MacBook Pro running macOS 15.5.
Validity testing
To ensure correctness, I used `test-backend-ops`:

output
Performance Testing
To test the performance improvements, I used `llama-batched-bench`. I ran with a baseline on `master` (`01612b`) and a comparison against raw CPU (`-ngl 0`).

Metal (baseline `01612b`)

Metal (with changes)

CPU
Discussion
This is my very first foray into kernel implementation, so there are very likely problems here that an experienced eye will catch. The first part that I would love eyes on is the implementation of the sequential `simd_sum` calls, the use of `threadgroup_barrier`, and the sizing of the shared memory. This was guesswork based on the implementation of `kernel_sum_rows`, and it passes `test-backend-ops`, but I don't feel 100% solid about it.
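On the shared memory sizing specifically, my back-of-envelope check (assuming Metal's 32-wide SIMD groups and the usual 1024-thread-per-threadgroup ceiling) is:

$$ n_{\text{partials}} = \left\lceil \frac{d_{\text{state}}}{32} \right\rceil \le \frac{1024}{32} = 32 $$

so a `32 * sizeof(float)` buffer always has one slot per SIMD group, and a typical mamba2 `d_state` of 128 touches only 4 of the 32 slots.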