SM100 Cutlass MLA decode with unrestricted num_heads (< 128) for DeepSeek TP #20769
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a reduced set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 …
This pull request has merge conflicts that must be resolved before it can be merged.
Summary of Changes
Hello @alexm-redhat, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a highly optimized CUTLASS MLA (Multi-head Latent Attention) decode kernel tailored for NVIDIA's SM100 (Blackwell) architecture. The core objective is to eliminate the previous num_heads == 128 constraint, enabling greater flexibility in the number of attention heads and facilitating more efficient tensor parallelism. The changes encompass new CUDA kernels, a dedicated reduction kernel for split-KV scenarios, and a tile scheduler, all integrated into the vLLM attention backend system, including conditional use of FlashInfer for prefill operations.
Highlights
- New SM100 Optimized Kernel: Introduced a new CUTLASS MLA decode kernel specifically designed for NVIDIA's SM100 (Blackwell) architecture, leveraging features like TMA (Tensor Memory Accelerator) and warp specialization for improved performance.
- Unrestricted Head Count: Removed the previous num_heads == 128 restriction, allowing head counts below 128. This is achieved by internally padding smaller head counts to 128 for kernel execution (a minimal sketch of this padding idea follows the list).
- Enhanced Tensor Parallelism Support: The newfound flexibility in num_heads directly improves the efficiency and applicability of Tensor Parallelism (TP) configurations.
- Significant Performance Gains: Demonstrated up to a 10% TPOT improvement for DeepSeek V2 on B200 GPUs, showing the practical benefit of the new kernel's optimizations.
- Split KV and Reduction Logic: Implemented a mechanism for splitting the Key-Value (KV) computation across multiple blocks/SMs, plus a corresponding reduction kernel that correctly combines the partial results whenever the KV computation is split.
- FlashInfer Integration for Prefill: Prefill on SM100a uses FlashInfer for efficient attention computation.
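As a rough illustration of the head-count padding described above: the 128-head kernel width comes from the PR description, while the helper names and the [batch, num_heads, head_dim] query layout below are assumptions for the sketch, not the actual wrapper code in this PR.

```python
import torch
import torch.nn.functional as F

KERNEL_NUM_HEADS = 128  # the SM100 CUTLASS MLA kernel operates on 128 query heads


def pad_query_heads(q: torch.Tensor) -> torch.Tensor:
    """Pad q of shape [batch, num_heads, head_dim] up to 128 heads."""
    num_heads = q.size(1)
    if num_heads >= KERNEL_NUM_HEADS:
        return q
    # Extra zero-filled heads are computed by the kernel and discarded afterwards.
    return F.pad(q, (0, 0, 0, KERNEL_NUM_HEADS - num_heads))


def unpad_output(out: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Drop the rows produced for the padded (dummy) heads."""
    return out[:, :num_heads, :]
```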
Code Review
This pull request introduces a new attention backend for SM100 GPUs using CUTLASS MLA kernels, a significant feature that lifts the restriction on the number of query heads. The changes are comprehensive, including new CUDA/C++ kernels, a Python backend implementation, and integration into the vLLM framework. The code appears well-structured and follows modern high-performance computing practices. My review identified a potential issue with workspace management in the Python backend implementation that should be addressed to ensure correctness and robustness.
bs = 128
PAGE_SIZE = 128
max_seqlen_pad = 131072
sm_count = 0

workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
    max_seqlen_pad * PAGE_SIZE, bs, sm_count, num_kv_splits=1)

# workspace_size = 1024 * 5 * 1024 * 1024
# self.workspace = torch.empty(workspace_size,
#                              device="cuda",
#                              dtype=torch.uint8)

self.workspace = g_workspace
Using a global variable g_workspace (defined at lines 29-30) for the workspace buffer is not thread-safe and can lead to memory-management issues. It's allocated once when the module is imported and never freed, and it would be shared across all instances of Sm100CutlassMLAImpl, which is not safe.

The workspace should be allocated per-instance inside __init__. The dynamic size calculation that is currently commented out is a good approach. Also, torch.empty should be used instead of torch.zeros for efficiency, as zero-initialization is not necessary for a workspace buffer.

Here is a suggested implementation. This change also implies that the global g_workspace at lines 29-30 should be removed.
Suggested change:

# TODO: The workspace size calculation depends on parameters that might
# not be fixed. This should be revisited to ensure the workspace is
# always large enough. For now, using a large pre-calculated value.
bs = 128
PAGE_SIZE = 128
max_seqlen_pad = 131072
sm_count = 0
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
    max_seqlen_pad * PAGE_SIZE, bs, sm_count, num_kv_splits=1)
self.workspace = torch.empty(workspace_size,
                             device="cuda",
                             dtype=torch.uint8)
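As a further design note, a lazily grown per-instance buffer would avoid both the module-level global and a fixed oversized allocation. The class below is only a sketch of that pattern; the ensure_size name mirrors the helper discussed later in this thread, and nothing here is the code that actually landed in the PR.

```python
import torch


class GrowOnlyWorkspace:
    """Per-instance scratch buffer that is reallocated only when it must grow."""

    def __init__(self, device: torch.device):
        self.device = device
        self.buffer = torch.empty(0, device=device, dtype=torch.uint8)

    def ensure_size(self, required_bytes: int) -> torch.Tensor:
        # Grow only when the request exceeds current capacity, so steady-state
        # decode steps keep reusing the same allocation (torch.empty skips zero-fill).
        if self.buffer.numel() < required_bytes:
            self.buffer = torch.empty(required_bytes,
                                      device=self.device,
                                      dtype=torch.uint8)
        return self.buffer
```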
This seems reasonable to me now, thanks for combining the backends. I still hope we can remove the old kernel ASAP to prevent this duplication, as that is my only issue with the PR.
def ensure_size(self, attn_metadata: MLACommonMetadata,
                num_kv_splits: int):
    batch_size = attn_metadata.num_reqs
    max_seq_len = attn_metadata.max_query_len
This overestimates the workspace, since num_reqs and max_query_len are for the whole batch and not just the decode portion; we should use max_query_len = 1 and batch_size = attn_metadata.num_decodes.

nit: max_seq_len generally refers to the max KV length, but here it's referring to max_query_len; we should try to keep the naming consistent.
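Concretely, the sizing change being suggested might look like the small sketch below; attn_metadata.num_decodes is an attribute assumed from the comment above, and this is illustrative rather than the eventual fix.

```python
def decode_sizing(attn_metadata) -> tuple[int, int]:
    """Return (batch_size, max_query_len) for sizing the decode workspace."""
    # Only the decode portion of the batch reaches this kernel, and each
    # decode query is exactly one token.
    batch_size = attn_metadata.num_decodes  # assumed attribute, per the review
    max_query_len = 1
    return batch_size, max_query_len
```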
I see, I will fix it in a follow-up PR.
@@ -716,6 +719,8 @@ def build(self, common_prefix_len: int,
        )

        attn_metadata = self.metadata_cls(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
nit: I don't think these batch-wise stats are needed; see https://github.com/vllm-project/vllm/pull/20769/files#r2206197677
Apologies for being slow on the review! Left some comments we may want to fix up in a future PR.
After building the latest main branch on an A800 device, I encountered the following error:
It looks like this PR is the root cause. cc @mgoin @LucasWilkinson
I'm seeing this as well; I'm on an H100. Since there are multiple independent reports, could we revert this please? @alexm-redhat @mgoin @LucasWilkinson
How are you triggering this? I just installed from main on H100 with
and running
… Are you sure you have updated your binary?
testing a fix now: #21020
This PR ports the SGLang changes that remove the num_heads == 128 restriction in the CUTLASS MLA decode kernel to vLLM.
Here are performance results for deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct on a single B200 (this model has fewer than 128 num_heads, so it could not run before this PR): a 65% TPOT improvement at batch size 1 and around 8-10% for larger batch sizes.
Here are some results for DeepSeek V2 on 4xB200: roughly a 10% TPOT improvement at batch sizes 1 and 64.
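For context, a minimal way to exercise this model on a single GPU (not the benchmark setup that produced the TPOT numbers above) could look like the following; on an SM100 (B200) device the decode path should be able to use the new CUTLASS MLA kernel, assuming default backend selection.

```python
from vllm import LLM, SamplingParams

# DeepSeek models require trust_remote_code; this one fits on a single B200.
llm = LLM(model="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
          trust_remote_code=True)

outputs = llm.generate(
    ["Write a Python function that reverses a string."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```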