I want to pass a variable split factor on device via Arguments.ptr_split_kv. This also requires setting the maximum split factor in Arguments.split_kv. The problem is that I need to keep that maximum static (e.g. 16), but as soon as I set it higher than 1, performance degrades significantly, even though I pass only 1 in ptr_split_kv. This makes on-device ptr_split_kv impractical.
Another data point: performance degrades 2x when I simply split by 4. This is unexpected, since the kernel is not doing 2x the work (there is some overhead, but it should be much smaller).
Thank you for your bug report. I investigated this a little.
When you enable split-k, what effectively happens is that the main kernel runs with B' = B * split, K' = K / split and ElementOut' = ElementAcc, and then a second kernel reads the B' output buffers in ElementAcc and writes the B output buffers in ElementOut.
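A minimal pure-Python sketch of this two-phase scheme for a single query row. It assumes the reduction kernel merges the partial outputs with the usual log-sum-exp (LSE) rescaling used in split-KV decoding; the function names are hypothetical, this is not the CUTLASS code, and Python floats (double) stand in for ElementAcc.

```python
import math

def attend_chunk(q, keys, values):
    """Softmax attention of one query over one KV chunk.
    Returns the normalized partial output and the log-sum-exp of the scores."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    out = [sum(w * v[j] for w, v in zip(weights, values)) / denom
           for j in range(len(values[0]))]
    return out, m + math.log(denom)

def split_kv_attention(q, keys, values, split):
    K = len(keys)
    assert K % split == 0, "sketch assumes K is divisible by split"
    chunk = K // split  # K' = K / split
    # Phase 1 ("main kernel"): one independent problem per split (B' = B * split).
    # Each one writes a partial output in accumulator precision plus its LSE.
    partials = [attend_chunk(q, keys[i * chunk:(i + 1) * chunk],
                             values[i * chunk:(i + 1) * chunk])
                for i in range(split)]
    # Phase 2 ("reduction kernel"): read all partials, compute the global LSE,
    # and rescale each partial by exp(lse_i - lse_total) before summing.
    m = max(lse for _, lse in partials)
    lse_total = m + math.log(sum(math.exp(lse - m) for _, lse in partials))
    return [sum(math.exp(lse - lse_total) * out[j] for out, lse in partials)
            for j in range(len(values[0]))]

# Usage: splitting by 4 reproduces the unsplit result up to rounding.
q = [0.1, -0.2, 0.3]
keys = [[math.sin(i + j) for j in range(3)] for i in range(8)]
values = [[math.cos(i * j) for j in range(2)] for i in range(8)]
ref = split_kv_attention(q, keys, values, split=1)
out = split_kv_attention(q, keys, values, split=4)
print(max(abs(a - b) for a, b in zip(ref, out)))
```

The point of the sketch is the data flow: phase 1 produces `split` partial outputs per original problem, and phase 2 has to read all of them back, which is where the extra memory traffic comes from.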
Based on that, I measured the achieved bandwidth of b=512 k=2048 split=1, b=2048 k=512 split=1 (as a proxy for kernel-only split-k), and b=512 k=2048 split=4. All of these achieve similar bandwidth.
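Why b=2048 k=512 split=1 serves as a proxy follows from the mapping B' = B * split, K' = K / split. A small sketch (hypothetical helper, assuming K is divisible by split) shows that b=512 k=2048 split=4 gives the main kernel the same shape and the same number of streamed KV tokens, so only the partial-output write and the reduction pass differ:

```python
def split_kv_proxy(batch, seqlen_kv, split):
    """Effective shape seen by the main kernel with split-KV enabled:
    B' = B * split problems, each over K' = K / split KV tokens."""
    assert seqlen_kv % split == 0
    return batch * split, seqlen_kv // split

for b, k, s in [(512, 2048, 1), (2048, 512, 1), (512, 2048, 4)]:
    bp, kp = split_kv_proxy(b, k, s)
    # B' * K' == B * K: the main kernel streams the same number of KV tokens.
    print(f"b={b} k={k} split={s} -> B'={bp} K'={kp} KV tokens={bp * kp}")
```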
The bandwidth number printed by the sample assumes an "ideal" kernel, i.e. it does not account for that extra traffic.
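To get the algorithmic view, use `algorithmic_bw = ideal_bw / ideal_bytes * (ideal_bytes + 2 * B * H * D_latent * split * sizeof(ElementAcc))`, which is roughly `ideal_bw * 1.8` in your case. The extra term counts the partial outputs being written once by the main kernel and read once by the reduction kernel.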
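The formula can be evaluated directly. The problem sizes below are assumptions for illustration only and are not taken from the report: an MLA-style decode with H=128, D_latent=512, a 64-wide rope part, fp16 Q/KV/O, a float accumulator, and an "ideal" byte count that reads the KV cache and Q once and writes O once.

```python
def algorithmic_bw(ideal_bw, ideal_bytes, B, H, D_latent, split, sizeof_acc):
    # Extra traffic: partial outputs written once by the main kernel and
    # read once by the reduction kernel, in accumulator precision.
    extra_bytes = 2 * B * H * D_latent * split * sizeof_acc
    return ideal_bw / ideal_bytes * (ideal_bytes + extra_bytes)

# Hypothetical decode shape (assumed, not from the report).
B, K, H = 512, 2048, 128
D_latent, D_rope = 512, 64
sizeof_in, sizeof_out, sizeof_acc = 2, 2, 4  # fp16 Q/KV/O, float accumulator
split = 4

# Assumed "ideal" bytes: read latent+rope KV cache and Q once, write O once.
ideal_bytes = B * (K * (D_latent + D_rope) * sizeof_in
                   + H * (D_latent + D_rope) * sizeof_in
                   + H * D_latent * sizeof_out)
ratio = algorithmic_bw(1.0, ideal_bytes, B, H, D_latent, split, sizeof_acc)
print(f"algorithmic_bw / ideal_bw = {ratio:.3f}")
```

With these assumed sizes the ratio works out to roughly 1.8. Because the KV-cache term dominates `ideal_bytes`, the extra traffic grows with `split` and becomes relatively larger as K shrinks, which is why short sequences suffer most.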
Given that you transfer ~1.8x more bytes, it makes sense that performance is ~2x worse.
Nevertheless, I think we could still improve performance for short sequences. So far, the sample mainly focuses on demonstrating good mainloop performance.
Separately, I am curious about your issue with ptr_split_kv. Could you provide an example input, in terms of the contents of ptr_split_kv and ptr_seqlen, that demonstrates the issue you are facing?
We could then investigate ways to speed this up.
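For reference, an input of the requested kind pairs per-batch sequence lengths with per-batch split factors that never exceed the static maximum. The sketch below assumes both ptr_seqlen and ptr_split_kv point to one integer per batch entry; the helper `build_split_kv` and its `tokens_per_split` heuristic are hypothetical and only meant to produce plausible contents.

```python
def build_split_kv(seqlens, max_split, tokens_per_split):
    """Hypothetical heuristic for per-batch ptr_split_kv contents: give each
    split roughly tokens_per_split KV tokens, clamped to [1, max_split]
    (max_split playing the role of Arguments.split_kv)."""
    # -(-s // n) is integer ceil(s / n).
    return [max(1, min(max_split, -(-s // tokens_per_split))) for s in seqlens]

seqlens = [128, 2048, 8192, 65536]  # example ptr_seqlen contents
split_kv = build_split_kv(seqlens, max_split=16, tokens_per_split=1024)
print("ptr_seqlen  :", seqlens)
print("ptr_split_kv:", split_kv)
```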
This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.
Steps/Code to reproduce bug
See description for repro.
Expected behavior
I expect performance not to degrade significantly when the maximum split_kv is increased.
Environment details:
B200