metal: SSM_SCAN performance #14743


Open · wants to merge 2 commits into master
Conversation

gabe-l-hart (Collaborator)

Description

This is an attempt to improve the overall performance of the mamba2 implementation of SSM_SCAN for Metal. I'm specifically interested in improving performance for Granite Four, but this should hopefully yield a nice speedup for other models as well.

Changes

  • In the mamba2 clauses of the SSM_SCAN case in ggml_metal_encode_node, launch the kernel with threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1) (d_state threads per threadgroup) and a shared memory buffer of size 32 * sizeof(float) (one float per SIMD group, 32 being the SIMD group width)
  • In kernel_ssm_scan_f32_group, remove the per-thread loop over nc (d_state) and instead use simd_sum to perform the final y reduction.
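For context, the host-side dispatch change described above might look roughly like the following. This is a sketch, not the exact PR diff: the grid dimensions, variable names, and threadgroup-memory index are illustrative assumptions.

```objc
// Sketch only: names and argument indices are illustrative, not the PR's code.
const int64_t d_state = 128; // taken from the tensor shape (ne[0]) in practice

[encoder setComputePipelineState:pipeline];
// 32 floats of threadgroup memory: one partial sum per SIMD group
[encoder setThreadgroupMemoryLength:32 * sizeof(float) atIndex:0];
// one threadgroup per output element, d_state threads per threadgroup
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs)
        threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
```

With this launch shape, each thread owns one element of the state dimension, which is what allows the kernel-side loop over nc to be replaced by a SIMD-group reduction.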

Testing

All testing was done on my M3 Max 64GB MacBook Pro running macOS 15.5.

Validity testing

To ensure correctness, I used test-backend-ops:

./bin/test-backend-ops -o SSM_SCAN
Output:
Backend 1/3: Metal
  Device description: Apple M3 Max
  Device memory: 49152 MB (49146 MB free)

  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4): OK
  SSM_SCAN(type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4): OK
  SSM_SCAN(type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4): OK
  6551/6551 tests passed
  Backend Metal: OK
ggml_metal_free: deallocating
ggml_metal_mem_pool_free: freeing memory pool, num heaps = 0 (total = 0)
(line repeated 8 times in total)
Backend 2/3: BLAS
  Device description: Accelerate
  Device memory: 0 MB (0 MB free)

  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4): not supported [BLAS] 
  SSM_SCAN(type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4): not supported [BLAS] 
  SSM_SCAN(type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4): not supported [BLAS] 
  6551/6551 tests passed
  Backend BLAS: OK
Backend 3/3: CPU
  Skipping CPU backend
3/3 backends passed
OK

Performance Testing

To test the performance improvements, I used llama-batched-bench. I ran a baseline on master (01612b) and also compared against pure CPU execution (-ngl 0).

# Test 256/256/1 for simple usage
./bin/llama-batched-bench -m ~/models/granite-4.0-tiny-preview/ggml-model-Q4_K_M.gguf -c 131072 -b 2048 -ub 512 -npp 256 -ntg 256 -npl 1 -ngl 99

# Test 2560/256/[1,2] for more realistic usage with context
./bin/llama-batched-bench -m ~/models/granite-4.0-tiny-preview/ggml-model-Q4_K_M.gguf -c 131072 -b 2048 -ub 512 -npp 2560 -ntg 256 -npl 1,2 -ngl 99

Metal (baseline 01612b)

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---:|---:|--:|-----:|-------:|---------:|-------:|---------:|----:|------:|
| 256 | 256 | 1 | 512 | 0.852 | 300.60 | 6.391 | 40.06 | 7.243 | 70.69 |
| 2560 | 256 | 1 | 2816 | 8.399 | 304.80 | 6.314 | 40.54 | 14.713 | 191.39 |
| 2560 | 256 | 2 | 5632 | 16.503 | 310.25 | 11.220 | 45.63 | 27.723 | 203.15 |

Metal (with changes)

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---:|---:|--:|-----:|-------:|---------:|-------:|---------:|----:|------:|
| 256 | 256 | 1 | 512 | 0.296 | 863.45 | 5.839 | 43.84 | 6.136 | 83.44 |
| 2560 | 256 | 1 | 2816 | 2.690 | 951.77 | 5.780 | 44.29 | 8.470 | 332.49 |
| 2560 | 256 | 2 | 5632 | 5.398 | 948.50 | 9.754 | 52.49 | 15.152 | 371.71 |

CPU

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---:|---:|--:|-----:|-------:|---------:|-------:|---------:|----:|------:|
| 256 | 256 | 1 | 512 | 0.876 | 292.33 | 3.509 | 72.95 | 4.385 | 116.76 |
| 2560 | 256 | 1 | 2816 | 7.263 | 352.46 | 3.652 | 70.11 | 10.915 | 258.00 |
| 2560 | 256 | 2 | 5632 | 16.695 | 306.68 | 5.105 | 100.30 | 21.800 | 258.35 |

Discussion

This is my very first foray into kernel implementation, so there are very likely problems here that an experienced eye will catch. The first part I would love eyes on is the sequence of simd_sum calls, the use of threadgroup_barrier, and the sizing of the shared memory. This was guesswork based on the implementation of kernel_sum_rows, and it passes test-backend-ops, but I don't feel 100% solid about it.

This may not be necessary, but it more closely mirrors the CUDA kernel.

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This is a first attempt at optimizing the metal kernel. The changes here
are:

- Launch the kernel with a thread group of size d_state
- Use simd groups and shared memory to do the summation for the y
  computation

When tested with G4 tiny preview, this shows roughly a 3x speedup on
prefill and 15% speedup on decode.

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
gabe-l-hart (Collaborator, Author)

@compilade I'd particularly love your feedback on this!

@github-actions github-actions bot added ggml changes relating to the ggml tensor library for machine learning Apple Metal https://en.wikipedia.org/wiki/Metal_(API) labels Jul 17, 2025
@gabe-l-hart gabe-l-hart changed the title metail: SSM_SCAN performance metal: SSM_SCAN performance Jul 17, 2025
gabe-l-hart (Collaborator, Author)

gabe-l-hart commented Jul 17, 2025

It looks like the use of simd_sum means that this now requires MTLGPUFamilyApple7 and the CI runner is on MTLGPUFamilyApple5. My laptop where this is passing shows:

ggml_metal_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)

I see that in ggml-metal.m, sum_rows, which also uses simd_sum in its implementation, does not do any version checking in its enablement flag, but I'm wondering if using simd_sum in SSM_SCAN should do a version check somehow?
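If a check turns out to be wanted, Metal exposes GPU family support directly on the device. A minimal sketch (the flag name and surrounding context plumbing are assumptions, not existing ggml-metal.m code):

```objc
// Sketch: gate the simd_sum-based SSM_SCAN path on GPU family support.
// Per Apple's feature set tables, SIMD-group reduction functions such as
// simd_sum require MTLGPUFamilyApple7 or later.
bool has_simdgroup_reduction = [device supportsFamily:MTLGPUFamilyApple7];
ctx->support_ssm_scan_simd = has_simdgroup_reduction;  // hypothetical flag
```

Note that this alone would not explain the M1 repro below, since the M1 Max already reports MTLGPUFamilyApple7.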

gabe-l-hart (Collaborator, Author)
I'm able to repro the failure on my old M1 Max 64GB MacBook Pro (macOS 14.4.1). Support printouts look like:

ggml_metal_init: GPU family: MTLGPUFamilyApple7  (1007)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)

This indicates that some behavioral difference between MTLGPUFamilyApple7 and MTLGPUFamilyApple9 is at play.

@@ -1798,15 +1801,31 @@ kernel void kernel_ssm_scan_f32_group(
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);

threadgroup_barrier(mem_flags::mem_threadgroup);
gabe-l-hart (Collaborator, Author):
This one appears to be unnecessary
