[BUG] Blackwell MLA perf for split-k #2222

Open
divchenko opened this issue Apr 4, 2025 · 3 comments

@divchenko

Describe the bug

I want to pass a variable split_k per batch entry via Arguments.ptr_split_kv on device, which also requires setting the maximum split value in Arguments.split_kv. The problem is that this maximum has to stay static (around 16 or so), but as soon as I set it higher than 1, performance degrades a lot, even though I pass only 1 through ptr_split_kv. This makes the on-device ptr_split_kv path not practically usable.

Another data point: performance degrades by 2x if I simply split by 4x. That is unexpected, since the kernel is not doing 2x the work (yes, there is some overhead, but it should be much smaller).
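
For reference, here is a minimal sketch of the setup I mean. The field names (split_kv, ptr_split_kv, ptr_seq) are the ones from the 77_blackwell_mla sample's Arguments/MainloopArguments; the buffer handling around them is only illustrative.

```cpp
// Sketch only: prepare per-batch split / seqlen buffers on the device.
// Field names (split_kv, ptr_split_kv, ptr_seq) come from the sample's
// Arguments / MainloopArguments; everything else here is schematic.
#include <vector>
#include <cuda_runtime.h>

void prepare_split_buffers(int B) {
  std::vector<int> h_split(B, 1);    // actual per-batch split: always 1
  std::vector<int> h_seq(B, 2048);   // fixed sequence length per batch

  int *d_split = nullptr, *d_seq = nullptr;
  cudaMalloc(&d_split, B * sizeof(int));
  cudaMalloc(&d_seq,   B * sizeof(int));
  cudaMemcpy(d_split, h_split.data(), B * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_seq,   h_seq.data(),   B * sizeof(int), cudaMemcpyHostToDevice);

  // The intent is then roughly (exact member paths depend on the sample):
  //   arguments.split_kv         = 16;      // static upper bound -> perf drops
  //   arguments.ptr_split_kv     = d_split; // per-batch values, all 1 here
  //   arguments.mainloop.ptr_seq = d_seq;
}
```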

$ /77_blackwell_mla_2sm_fp16 --b=512 --k=2048 --page=128 --verbose --split_kv=1
###### B 512 MLA H 128 D_rope 64 D_latent 512 Q 1 K 2048 Gen None Split 1 Gen None #SM 148
 [--] 128x128 fp16 persistent          : 1073.23 TFLOPS/s 4.13682 TB/s
       t=326.475 us, smem=210944b
 [--] 128x128 fp16 individual          : 1071.13 TFLOPS/s 4.12872 TB/s
       t=327.115 us, smem=210944b

$ /77_blackwell_mla_2sm_fp16 --b=512 --k=2048 --page=128 --verbose --split_kv=4
###### B 512 MLA H 128 D_rope 64 D_latent 512 Q 1 K 2048 Gen None Split 4 Gen None #SM 148
 [--] 128x128 fp16 persistent          : 547.292 TFLOPS/s 2.10956 TB/s
       t=640.213 us, smem=210944b
 [--] 128x128 fp16 individual          : 508.058 TFLOPS/s 1.95833 TB/s
       t=689.653 us, smem=210944b

Steps/Code to reproduce bug
See description for repro.

Expected behavior
I expect performance not to degrade significantly when the maximum split_kv value is increased.

Environment details
B200


divchenko added the "? - Needs Triage" and "bug" labels on Apr 4, 2025
@v0i0

v0i0 commented Apr 7, 2025

Hey @divchenko,

Thank you for your bug report. I investigated this a little.

When you enable split-k, what effectively happens is that B' = B * split, K' = K / split, ElementOut' = ElementAcc, and a second kernel runs that reads the B' output buffers in ElementAcc and writes B output buffers in ElementOut.
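
To make the extra traffic concrete, here is a rough sketch of what such a reduction pass moves. The partial-output layout and the plain summation are just illustrative (the actual reduction also has to rescale each split by its softmax statistics), so read it only as a picture of the data movement:

```cpp
// Memory-traffic sketch of the reduction pass described above: it reads
// B * split partial outputs in ElementAcc (float) and writes B outputs in
// ElementOut (half). Layout assumed [split, B, H, D_latent].
#include <cuda_fp16.h>

__global__ void reduce_splits(
    float const* partial,     // [split, B, H, D_latent] partial accumulators
    __half* out,              // [B, H, D_latent] final output
    int split, long per_batch_elems /* = H * D_latent */, int B) {
  long idx   = blockIdx.x * (long)blockDim.x + threadIdx.x;
  long total = (long)B * per_batch_elems;
  if (idx >= total) return;
  float acc = 0.f;
  for (int s = 0; s < split; ++s) {
    acc += partial[(long)s * total + idx];  // read: split * B * H * D * 4 bytes
  }
  out[idx] = __float2half(acc);             // write: B * H * D * 2 bytes
}
```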

Based on that, I measured the achieved bandwidth of b=512 k=2048 split=1, b=2048 k=512 split=1 (as a proxy for kernel-only split-k), and b=512 k=2048 split=4. All of these achieve similar bandwidth.

The numbers printed are based on an "ideal" kernel, i.e. they do not account for that extra reduction traffic.
To get the algorithmic view, use algorithmic_bw = ideal_bw / ideal_bytes * (ideal_bytes + 2 * B * H * D_latent * split * sizeof(ElementAcc)), which is roughly ideal_bw * 1.8 in your case.

Based on that, it makes sense that perf is 2x worse, given that you transfer ~1.8x more bytes.
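
As a sanity check on that factor, here is the back-of-envelope arithmetic for b=512, k=2048, split=4, assuming the ideal traffic is the fp16 KV-cache read plus the (much smaller) Q read and O write:

```cpp
// Back-of-envelope check of the ~1.8x factor for b=512, k=2048, split=4.
// Assumes ideal traffic = read KV cache (fp16) + read Q (fp16) + write O (fp16);
// the split-k overhead is 2 * B * H * D_latent * split * sizeof(float)
// (partials written once by the mainloop kernel, read once by the reduction).
#include <cstdio>

int main() {
  double B = 512, K = 2048, H = 128, D_latent = 512, D_rope = 64, split = 4;
  double kv    = B * K * (D_latent + D_rope) * 2;   // fp16 KV cache, ~1.21 GB
  double q     = B * H * (D_latent + D_rope) * 2;   // fp16 queries
  double o     = B * H * D_latent * 2;              // fp16 output
  double ideal = kv + q + o;                        // ~1.35 GB
  double extra = 2 * B * H * D_latent * split * 4;  // ~1.07 GB of fp32 partials
  std::printf("factor = %.2f\n", (ideal + extra) / ideal);  // ~1.80
  return 0;
}
```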

Nevertheless, I think we could still improve performance for short sequences. So far, the sample is mainly focused on demonstrating how to get good mainloop performance.

Separately, I am curious about your issue with ptr_split_kv.
Could you provide an example input, in terms of the contents of ptr_split_kv and ptr_seqlen, that demonstrates the issue you are facing?
We could then investigate ways to speed this up.

@divchenko
Author

Makes sense. That example might be a red herring.

I'm running with batch size = 64 and a fixed sequence length of 2048:

  • MainloopArguments.ptr_seq pointing to [2048, 2048, ..., 2048], 64 entries total
  • Arguments.ptr_split_kv pointing to [1, 1, ..., 1], 64 entries total

Now, I'm varying Arguments.split_kv:

  • split_kv = 1 results in a runtime of 67us
  • split_kv = 4 results in 86us
  • split_kv = 16 results in 122us

Yes, split_kv affects the grid size, but the extra blocks should do almost no work, so I would expect the penalty to be much smaller.
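
As a back-of-envelope estimate from those numbers, under the purely hypothetical assumption that the grid grows roughly as batch * split_kv (so the additional work units are exactly the ones that should be near no-ops):

```cpp
// Rough per-unit cost estimate from the runtimes above, assuming the grid
// grows roughly as batch * split_kv; the extra units should all exit early,
// since ptr_split_kv is 1 everywhere.
#include <cstdio>

int main() {
  int batch = 64;
  double t1 = 67e-6, t16 = 122e-6;     // measured runtimes for split_kv = 1 and 16
  int extra_units = batch * (16 - 1);  // 960 mostly-empty work units
  std::printf("%.0f ns per extra unit\n", (t16 - t1) / extra_units * 1e9);  // ~57 ns
  return 0;
}
```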


github-actions bot commented May 8, 2025

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.
