
SM100 Cutlass MLA decode with unrestricted num_heads (< 128) for DeepSeek TP #20769


Merged: 2 commits merged into main from mla_fi_prefill_and_decode on Jul 15, 2025

Conversation

@alexm-redhat (Collaborator) commented Jul 10, 2025

This PR ports the SGLang changes that remove the num_heads == 128 restriction in the CUTLASS MLA decode kernel to vLLM.

Here are performance results for deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct on a single B200 (this model has fewer than 128 num_heads, so it could not run with this kernel before this PR): a 65% TPOT improvement at batch size 1 and around 8-10% at larger batch sizes.

[benchmark results image]

Here are some results for DeepSeek V2 on 4xB200: roughly a 10% TPOT improvement at both batch size 1 and batch size 64.

[benchmark results image]


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@alexm-redhat requested a review from mgoin July 10, 2025 16:18
@mergify bot added the documentation, ci/build, and v1 labels Jul 10, 2025
mergify bot commented Jul 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @alexm-redhat.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Jul 10, 2025
@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @alexm-redhat, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a highly optimized CUTLASS MLA (Multi-head Latent Attention) decode kernel tailored for NVIDIA's SM100 (Blackwell) architecture. The core objective is to eliminate the previous num_heads == 128 constraint, enabling more flexible head counts and more efficient tensor parallelism. The changes encompass new CUDA kernels, a dedicated reduction kernel for split-KV scenarios, and a tile scheduler, all integrated into the vLLM attention backend system, including conditional integration with FlashInfer for prefill operations.

Highlights

  • New SM100-Optimized Kernel: Introduced a new CUTLASS MLA decode kernel designed specifically for NVIDIA's SM100 (Blackwell) architecture, leveraging TMA (Tensor Memory Accelerator) and warp specialization for improved performance.
  • Unrestricted Head Count: Removed the previous num_heads == 128 restriction, allowing head counts below 128. This is achieved by internally padding the query head count to 128 for kernel execution (see the sketch after this list).
  • Enhanced Tensor Parallelism Support: The new flexibility in num_heads directly improves the efficiency and applicability of tensor-parallel (TP) configurations, which shard query heads across GPUs.
  • Significant Performance Gains: Demonstrated a 65% TPOT improvement for DeepSeek-Coder-V2-Lite-Instruct at batch size 1 and roughly 10% for DeepSeek V2, both on B200 GPUs.
  • Split-KV and Reduction Logic: Implemented a mechanism for splitting the key-value (KV) computation across multiple blocks/SMs, plus a corresponding reduction kernel to correctly combine the partial results.
  • FlashInfer Integration for Prefill: Integrated with FlashInfer for optimized prefill on SM100a, leveraging its efficient attention kernels.
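
As an illustration of the head-count padding mentioned above, here is a minimal, hypothetical wrapper-level sketch. The actual kernel handles this internally; the q_nope/q_pe split and the 512/64 head sizes are assumptions based on DeepSeek's MLA layout, not code from this PR.

import torch
import torch.nn.functional as F

KERNEL_NUM_HEADS = 128  # head count the SM100 MLA decode tile expects

def pad_heads_for_decode(q_nope: torch.Tensor,
                         q_pe: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Zero-pad the head axis of [batch, num_heads, head_dim] tensors to 128.

    Returns the padded tensors plus the original head count so the caller
    can slice the kernel output back down afterwards.
    """
    num_heads = q_nope.shape[1]
    pad = KERNEL_NUM_HEADS - num_heads
    if pad > 0:
        # F.pad pads trailing dims first; (0, 0, 0, pad) pads only the head axis.
        q_nope = F.pad(q_nope, (0, 0, 0, pad))
        q_pe = F.pad(q_pe, (0, 0, 0, pad))
    return q_nope, q_pe, num_heads

# Example: DeepSeek V2 with TP=4 leaves 128 / 4 = 32 local query heads per rank.
q_nope = torch.randn(2, 32, 512)  # latent ("nope") part of the query
q_pe = torch.randn(2, 32, 64)     # rotary ("pe") part of the query
q_nope, q_pe, orig_heads = pad_heads_for_decode(q_nope, q_pe)

out = torch.empty(2, KERNEL_NUM_HEADS, 512)  # the kernel writes all 128 heads
# ... run the MLA decode kernel on the padded queries ...
out = out[:, :orig_heads]  # drop the padded heads before returning
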
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a new attention backend for SM100 GPUs using CUTLASS MLA kernels, which is a significant feature enabling unrestricted head dimensions. The changes are comprehensive, including new CUDA/C++ kernels, Python backend implementation, and integration into the vLLM framework. The code appears well-structured and follows modern high-performance computing practices. My review has identified a potential issue with workspace management in the Python backend implementation that should be addressed to ensure correctness and robustness.

Comment on lines 74 to 91
bs = 128
PAGE_SIZE = 128
max_seqlen_pad = 131072
sm_count = 0

workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
    max_seqlen_pad * PAGE_SIZE, bs, sm_count, num_kv_splits=1)

# workspace_size = 1024 * 5 * 1024 * 1024
# self.workspace = torch.empty(workspace_size,
#                              device="cuda",
#                              dtype=torch.uint8)

self.workspace = g_workspace
@gemini-code-assist bot (Contributor) commented (severity: high):

Using a global variable g_workspace (defined at lines 29-30) for the workspace buffer is not thread-safe and can lead to memory management issues. It's allocated once when the module is imported and never freed, and it would be shared across all instances of Sm100CutlassMLAImpl, which is not safe.

The workspace should be allocated per-instance inside __init__. The dynamic size calculation that is currently commented out is a good approach. Also, torch.empty should be used instead of torch.zeros for efficiency, as zero-initialization is not necessary for a workspace buffer.

Here is a suggested implementation. This change also implies that the global g_workspace at lines 29-30 should be removed.

Suggested change

Original:

bs = 128
PAGE_SIZE = 128
max_seqlen_pad = 131072
sm_count = 0
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
    max_seqlen_pad * PAGE_SIZE, bs, sm_count, num_kv_splits=1)
# workspace_size = 1024 * 5 * 1024 * 1024
# self.workspace = torch.empty(workspace_size,
#                              device="cuda",
#                              dtype=torch.uint8)
self.workspace = g_workspace

Suggested replacement:

# TODO: The workspace size calculation depends on parameters that might
# not be fixed. This should be revisited to ensure the workspace is
# always large enough. For now, using a large pre-calculated value.
bs = 128
PAGE_SIZE = 128
max_seqlen_pad = 131072
sm_count = 0
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
    max_seqlen_pad * PAGE_SIZE, bs, sm_count, num_kv_splits=1)
self.workspace = torch.empty(workspace_size,
                             device="cuda",
                             dtype=torch.uint8)

@alexm-redhat force-pushed the mla_fi_prefill_and_decode branch from 31ac2ac to 528d3c8 on July 10, 2025 16:44
@mergify bot removed the needs-rebase label Jul 10, 2025
@alexm-redhat force-pushed the mla_fi_prefill_and_decode branch 2 times, most recently from 8a55a13 to 2d6fce1 on July 10, 2025 17:19
@alexm-redhat force-pushed the mla_fi_prefill_and_decode branch 4 times, most recently from 9c9f633 to 1bcb4ef on July 10, 2025 20:05
@alexm-redhat changed the title from "Cutlass MLA decode with unrestricted head_dim (can be < 128) which allows TP as well" to "Cutlass MLA decode with unrestricted num_heads (can be < 128) which allows TP as well" Jul 10, 2025
mergify bot commented Jul 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @alexm-redhat.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Jul 11, 2025
@alexm-redhat force-pushed the mla_fi_prefill_and_decode branch from 1bcb4ef to 32e4481 on July 11, 2025 17:39
@mergify bot removed the needs-rebase label Jul 11, 2025
@alexm-redhat force-pushed the mla_fi_prefill_and_decode branch 2 times, most recently from 70870d9 to 798037c on July 11, 2025 17:41
@alexm-redhat self-assigned this Jul 11, 2025
@alexm-redhat force-pushed the mla_fi_prefill_and_decode branch 4 times, most recently from 6082478 to cdad4f4 on July 14, 2025 17:55
@mgoin (Member) left a comment

This seems reasonable to me now, thanks for combining the backends. I still hope we can remove the old kernel ASAP to prevent this duplication, as that is my only issue with the PR.

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
@alexm-redhat force-pushed the mla_fi_prefill_and_decode branch from cdad4f4 to 44da059 on July 14, 2025 19:10
@mgoin changed the title from "Cutlass MLA decode with unrestricted num_heads (can be < 128) which allows TP as well" to "SM100 Cutlass MLA decode with unrestricted num_heads (< 128) for DeepSeek TP" Jul 14, 2025
@mgoin added the deepseek, performance, and ready labels Jul 14, 2025
@alexm-redhat enabled auto-merge (squash) July 14, 2025 19:37
@alexm-redhat merged commit 8cdc371 into main Jul 15, 2025
101 checks passed
@alexm-redhat deleted the mla_fi_prefill_and_decode branch July 15, 2025 01:06
def ensure_size(self, attn_metadata: MLACommonMetadata,
                num_kv_splits: int):
    batch_size = attn_metadata.num_reqs
    max_seq_len = attn_metadata.max_query_len
Collaborator comment:

this overestimates the workspace, since num_reqs and max_query_len are for the whole batch and not just the decode portion; we should use max_query_len = 1 and batch_size = attn_metadata.num_decodes

nit: max_seq_len generally refers to the max KV length, but here it's referring to max_query_len; we should try to keep the naming consistent
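
For reference, a minimal sketch of what the suggested change could look like (illustrative only; apart from num_decodes, which is named in the comment above, the details are assumptions rather than the PR's code):

def ensure_size(self, attn_metadata: MLACommonMetadata,
                num_kv_splits: int):
    # Only the decode portion of the batch reaches this kernel, and decode
    # queries always have length 1, so size the workspace for num_decodes
    # requests rather than for the whole (prefill + decode) batch.
    batch_size = attn_metadata.num_decodes
    max_query_len = 1
    ...  # rest of the workspace sizing logic unchanged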

@alexm-redhat (Collaborator, Author) replied:

I see, I will fix it in a follow-up PR

@@ -716,6 +719,8 @@ def build(self, common_prefix_len: int,
)

attn_metadata = self.metadata_cls(
    num_reqs=common_attn_metadata.num_reqs,
    max_query_len=common_attn_metadata.max_query_len,
Collaborator comment:

nit: I don't think these batch-wise stats are needed; see: https://github.com/vllm-project/vllm/pull/20769/files#r2206197677

@LucasWilkinson (Collaborator) left a comment

Apologies for being slow on the review! Left some comments we may want to fix up in a future PR

patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Jul 15, 2025
…Seek TP (vllm-project#20769)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
@jeejeelee (Collaborator) commented:
After building the latest main branch on an A800 device, I encountered the following error:

  File "/vllm/vllm/platforms/cuda.py", line 18, in <module>
    import vllm._C  # noqa
ImportError: /vllm/vllm/_C.abi3.so: undefined symbol: _Z36sm100_cutlass_mla_get_workspace_sizellll

It looks like this PR is the root cause. cc @mgoin @LucasWilkinson

@zou3519 (Collaborator) commented Jul 15, 2025

> After building the latest main branch on an A800 device, I encountered the following error:

I'm seeing this as well; I'm on an H100. Since there are multiple independent reports, could we revert this please? @alexm-redhat @mgoin @LucasWilkinson

@mgoin (Member) commented Jul 16, 2025

How are you triggering this? I just installed from main on H100 with

VLLM_USE_PRECOMPILED=1 uv pip install -U -e . --torch-backend=auto

and running python -c "import vllm._C" works fine.

Are you sure you have updated your binary?
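
For anyone hitting this, one hypothetical way to check whether the installed binary is stale, i.e. whether _C.abi3.so still references the new symbol without defining it (assumes a Linux environment with nm on PATH; this snippet is not part of vLLM):

import importlib.util
import pathlib
import subprocess

# Locate the installed vllm package without importing it (importing vllm would
# itself trigger the failing `import vllm._C` shown in the traceback above).
spec = importlib.util.find_spec("vllm")
so_path = next(pathlib.Path(spec.origin).parent.glob("_C.abi3.so"), None)
if so_path is None:
    print("no _C.abi3.so found next to the vllm package")
else:
    symbols = subprocess.run(["nm", "-D", str(so_path)],
                             capture_output=True, text=True,
                             check=True).stdout
    for line in symbols.splitlines():
        if "sm100_cutlass_mla_get_workspace_size" in line:
            # "U" = referenced but undefined (stale or mismatched build);
            # "T" = defined and exported by this binary.
            print(line)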

@LucasWilkinson (Collaborator) commented:
testing a fix now: #21020
