CUDA: skip masked out KQ slices in mma FA kernel #14735
Conversation
I think it is not very difficult to provide the list of indices (or some other meta information about the mask) without making breaking changes. What we can do is, during the construction of the KQ mask, also create a second tensor that contains the non-masked block indices. This tensor could be attached to the flash attention op as an optional src:

    cur = ggml_flash_attn_ext(ctx, ...);
    ggml_flash_attn_ext_add_idxs(cur, kq_idxs);

The backend can decide whether to use these indices or not. We could also pass a bitmask if that would be more appropriate. Or even both. So from the PoV of …

On the backend side - if you are seeing significant benefits, I think it would be worth it. I think the existing FA Metal kernel can benefit from a bitmask. Not sure about the list of indices, but once it is provided, we can attempt to write an alternative kernel that uses it.

The first test that you did with …
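Purely as an illustration of the wiring proposed above (not existing ggml/llama.cpp code), here is a sketch of how a backend could pick up such an optional indices tensor and fall back to the current path when it is absent. The src slot used for the indices and the two dispatched kernels are made-up assumptions.

```cpp
// Hypothetical backend dispatch for the proposal above. Assumptions: the
// proposed ggml_flash_attn_ext_add_idxs() stores the index tensor as an extra
// src of the op (slot 4 here is a guess), and the two functions below are
// made-up placeholders for a sparse and a dense FlashAttention path.
#include "ggml.h"

// hypothetical kernels, declared only to keep the sketch self-contained:
void flash_attn_ext_with_idxs(const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v,
                              const ggml_tensor * mask, const ggml_tensor * idxs, ggml_tensor * dst);
void flash_attn_ext_dense(const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v,
                          const ggml_tensor * mask, ggml_tensor * dst);

static void backend_flash_attn_ext(ggml_tensor * dst) {
    const ggml_tensor * q    = dst->src[0];
    const ggml_tensor * k    = dst->src[1];
    const ggml_tensor * v    = dst->src[2];
    const ggml_tensor * mask = dst->src[3]; // may be NULL
    const ggml_tensor * idxs = dst->src[4]; // optional list of non-masked KQ slices, may be NULL

    if (idxs != NULL) {
        flash_attn_ext_with_idxs(q, k, v, mask, idxs, dst); // iterate only over the listed slices
    } else {
        flash_attn_ext_dense(q, k, v, mask, dst);           // dense path: check mask tiles on the fly
    }
}
```

With this shape, backends that do not implement the sparse path can simply ignore the extra src, which matches the "the backend can decide whether to use these indices" idea above.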
Yes, that is how I would have implemented it as well.
For CUDA a bitmask would not work well because I wouldn't know where in the bitmask to set the start and end indices needed for evenly distributing data across streaming multiprocessors.
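To make the distribution issue concrete, here is a rough sketch (my own illustration, not code from this PR) of how a flat list of slice indices makes an even split across SMs trivial, which a bare bitmask does not.

```cuda
// Hypothetical kernel skeleton, only to show the scheduling aspect: each block
// takes a contiguous, roughly equal chunk of the list of non-masked KQ slice
// indices. With only a bitmask there is no cheap way to know where each
// block's chunk would start and end without first doing a prefix sum over the
// set bits.
__global__ void fattn_over_slice_list(const int * __restrict__ slice_idxs, const int n_slices) {
    const int per_block = (n_slices + gridDim.x - 1) / gridDim.x;
    const int begin     = blockIdx.x*per_block;
    const int end       = min(begin + per_block, n_slices);

    for (int i = begin; i < end; ++i) {
        const int slice = slice_idxs[i]; // index of a KQ slice that is not fully masked out
        // ... preload K/V + mask for this slice and do the FlashAttention update ...
        (void) slice;
    }
}
```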
For a single sequence the preloading definitely provides a benefit; the difference is something like a 10% end-to-end speedup. Also, I just noticed that I had forgotten that master and the PR are running different kernels on the RTX 4090: master runs the vector kernel, the PR runs the mma kernel. There is also a difference for the RTX 3090; what could be happening there is that the lower shared memory use of a single-stage pipeline gives some benefit from higher occupancy.
(Benchmark tables: master vs. PR on RTX 4090 and RTX 3090; data not reproduced.)
I'll make a prototype for a kernel using a list of mask slices.

I think I did the math in my head wrong, sorry.
Try to make the list of indices compatible with both unified and split KV cache buffers. @slaren You mentioned recently that such a modification might not be practical (#14363 (comment)). What do you think about the approach discussed above of passing the indices as an optional tensor that the backends can choose to ignore?
Could the backend itself generate the list of indices from the KQ mask in a pre-processing step?
I considered this, but the problem is that I think that will be inefficient to parallelize. With CUDA I can efficiently generate e.g. a bitmask for which mask slices are all -inf, but not a compact list of indices. I think I'll just make a prototype with an extended interface for ggml_flash_attn_ext.
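For illustration of what the efficient part of such a pre-processing step could look like, here is a sketch of a bitmask-building kernel; the names, the tile handling, and the assumed mask layout are my own assumptions rather than anything from this PR.

```cuda
// Sketch of a pre-processing kernel that builds a bitmask of which KQ slices
// contain at least one non-masked value. One block scans one candidate slice.
// Assumptions: the mask is F16 with layout [n_kv, n_q] (KV contiguous) and the
// bitmask is zero-initialized by the caller. Turning this bitmask into a
// compact, ordered list of slice indices would additionally require a prefix
// sum / stream compaction, which is the awkward part to parallelize.
#include <cstdint>
#include <cuda_fp16.h>

__global__ void build_mask_slice_bitmask(
        const half * __restrict__ mask, const int ne_kv,
        const int slice_kv, const int slice_q, unsigned int * __restrict__ bitmask) {
    const int slice       = blockIdx.x;             // one block per KQ slice
    const int n_kv_slices = ne_kv / slice_kv;
    const int kv0         = (slice % n_kv_slices)*slice_kv;
    const int q0          = (slice / n_kv_slices)*slice_q;

    int has_data = 0;
    for (int i = threadIdx.x; i < slice_kv*slice_q; i += blockDim.x) {
        const half val = mask[(int64_t)(q0 + i/slice_kv)*ne_kv + kv0 + i % slice_kv];
        has_data |= __hisinf(val) != -1;            // anything other than -inf needs compute
    }

    if (__syncthreads_or(has_data) && threadIdx.x == 0) {
        atomicOr(&bitmask[slice/32], 1u << (slice % 32));
    }
}
```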
This PR extends the mma-based CUDA FlashAttention kernel with logic for skipping fully masked-out KQ slices. However, this kernel uses asynchronous data loading to preload KV data, so this is not straightforward: the mask and K data are preloaded at the same time, so if it turns out that the KQ slice should be skipped, the K I/O is wasted and the GPU compute pipes idle until the mask and K data for the next potential KQ slice can be fetched. It turns out that for the new batched inference setup it's faster not to preload any data at all, but this is going to hamper the overall performance.

With LLaMA 3 8b q4_0 and the command … I'm seeing net speedups of ~3% for my RTX 3090/4090; with … I'm seeing net speedups of ~10% (LLAMA_SET_ROWS=1 set for both cases).

I think that the approach of loading a tile of the mask and checking whether it's all == -inf is bad. I think I could write a much better kernel if I instead had a list of those 256x64 mask slices that are not all == -inf. Then I could simply iterate over those indices and preload data without potentially wasting any I/O. It would also allow me to solve the problem where I cannot distribute work to streaming multiprocessors in an optimal way because I don't know ahead of time how the KQ slices needing compute are distributed; if one of the sequences is much longer than the rest I would ideally assign more SMs to that sequence. For non-batched inference with a diagonal mask I could also skip a few KQ slices; it would reduce the amount of KV data to iterate over per token by half the physical batch size (probably not worthwhile on its own). Ideally one bit of the indices should be reserved to indicate whether all elements in the mask slice are == 0; the mask for those slices would then not need to be loaded at all (as of yet unclear whether that would make a meaningful difference). @ggerganov thoughts?
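For reference, a minimal sketch of the tile-check idea described above, i.e. scan the mask tile and skip the slice if every element is -inf; it is illustrative only (names, tile sizes, and the surrounding loop structure are assumptions), not the actual kernel from this PR.

```cuda
// Minimal sketch of the "check the mask tile, skip the slice" approach: before
// doing the work for a KQ slice, the threads of the block scan the
// corresponding mask tile and report whether every element is -inf.
#include <cuda_fp16.h>

template <int NCOLS, int KQ_PER_ITER> // e.g. 256 Q columns x 64 KV positions
__device__ bool mask_tile_all_masked(const half * __restrict__ mask_tile, const int stride_q) {
    int has_data = 0;
    for (int i = threadIdx.x; i < NCOLS*KQ_PER_ITER; i += blockDim.x) {
        const half val = mask_tile[(i / KQ_PER_ITER)*stride_q + (i % KQ_PER_ITER)];
        has_data |= __hisinf(val) != -1;    // anything other than -inf means the slice matters
    }
    return __syncthreads_or(has_data) == 0; // true if the whole tile is -inf
}

// Usage inside a (hypothetical) main loop over KV slices:
//
//   for (int kv0 = 0; kv0 < ne_kv; kv0 += KQ_PER_ITER) {
//       if (mask_tile_all_masked<256, 64>(mask + kv0, stride_q)) {
//           continue; // skip the K/V loads and the KQ/VKQ matrix multiplications
//       }
//       // ... load K/V tiles and accumulate the FlashAttention update ...
//   }
```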