Skip to content

[BugFix] Fix full cuda graph slot_mapping #21228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 19, 2025

Conversation

fhl2000
Copy link
Contributor

@fhl2000 fhl2000 commented Jul 19, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Fix a bug related to #20466 and #21196, where full cudagraph may capture the slot_mapping of insufficient length when max_num_seqs < max_capture_size. This could lead to weird output contents that may be irrelevant to the prompts.

Test Plan

No further test plan. It works fine on my local after the fix.

Test Result

(Optional) Documentation Update

Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@fhl2000 fhl2000 changed the title [BugFix] Fix full cuda-graph slot_mapping [BugFix] Fix full cuda graph slot_mapping Jul 19, 2025
@mergify mergify bot added the v1 label Jul 19, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly fixes a critical bug in the CUDA graph capturing logic for slot_mapping. The change from slicing by num_reqs to num_tokens is accurate, as slot_mapping is a per-token data structure. This prevents potential memory corruption and incorrect model outputs when using full CUDA graph capturing. The fix is well-reasoned and I approve of the change.

Comment on lines 2044 to +2045
slot_mapping=self.input_batch.
block_table[kv_cache_group_id].slot_mapping[:num_reqs])
block_table[kv_cache_group_id].slot_mapping[:num_tokens])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This is a great catch and a critical fix. The slot_mapping tensor maps tokens to their slots in the KV cache, so its length must be proportional to the number of tokens, not the number of requests.

Previously, slicing with [:num_reqs] was incorrect, especially when num_tokens (from cudagraph_capture_sizes) was greater than max_num_seqs. In that scenario, num_reqs would be capped at max_num_seqs, leading to an undersized slot_mapping being captured in the CUDA graph. This could cause memory corruption or incorrect outputs during graph replay, as you've pointed out.

Using [:num_tokens] correctly sizes the slot_mapping for the dummy run, ensuring the CUDA graph is captured with the correct metadata. This prevents the weird output artifacts and ensures correctness.

@fhl2000
Copy link
Contributor Author

fhl2000 commented Jul 19, 2025

CC @mgoin

@mgoin mgoin added this to the v0.10.0 milestone Jul 19, 2025
@mgoin mgoin added the bug Something isn't working label Jul 19, 2025
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great find, let's get this in asap!

@mgoin mgoin enabled auto-merge (squash) July 19, 2025 16:11
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 19, 2025
@simon-mo simon-mo disabled auto-merge July 19, 2025 21:13
@simon-mo simon-mo merged commit 2e8cbb5 into vllm-project:main Jul 19, 2025
62 of 66 checks passed
@fhl2000 fhl2000 deleted the fix_full_cuda_graph_IMA branch July 20, 2025 04:44
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants