Skip to content

[BugFix] Fix import error on non-blackwell machines #21020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 16, 2025

Conversation

LucasWilkinson
Copy link
Collaborator

@LucasWilkinson LucasWilkinson commented Jul 16, 2025

FIX #20769 (comment)

Checked vllm serve runs when built on hopper
Checked can run VLLM_ATTENTION_BACKEND=CUTLASS_MLA_VLLM_V1 lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,trust_remote_code=true --tasks gsm 8k --batch_size auto

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request fixes an import error on non-Blackwell machines by moving the operator implementation registration into a conditionally compiled file. The dispatch key for the workspace calculation function has been corrected to improve the robustness of the solution.

Comment on lines 274 to 277
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The function sm100_cutlass_mla_get_workspace_size is a host-side function that calculates a workspace size and does not involve any GPU operations. Registering it only for the CUDA dispatch key is incorrect and can lead to runtime errors if called in a context where the PyTorch dispatcher selects a different backend. Host-only functions like this should be registered for the CPU dispatch key to ensure they can be called correctly regardless of the context.

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
}

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CPU, m) {
  m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
}

@mgoin mgoin added bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed labels Jul 16, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@LucasWilkinson
Copy link
Collaborator Author

LucasWilkinson commented Jul 16, 2025

GTG:

Checked vllm serve runs when built on hopper
Checked can run VLLM_ATTENTION_BACKEND=CUTLASS_MLA_VLLM_V1 lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,trust_remote_code=true --tasks gsm 8k --batch_size auto on Blackwell

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) July 16, 2025 02:20
Copy link
Collaborator

@jeejeelee jeejeelee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have tested this PR locally, and it can fix #20769 (comment), thank you

@vllm-bot vllm-bot merged commit d31a647 into vllm-project:main Jul 16, 2025
87 of 91 checks passed
nadathurv pushed a commit to nadathurv/vllm that referenced this pull request Jul 16, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants