Skip to content

[BugFix] Fix full cuda graph slot_mapping #21228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 19, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2042,7 +2042,7 @@ def _dummy_run(
block_table_tensor=self.input_batch.block_table[
kv_cache_group_id].get_device_tensor()[:num_reqs],
slot_mapping=self.input_batch.
block_table[kv_cache_group_id].slot_mapping[:num_reqs])
block_table[kv_cache_group_id].slot_mapping[:num_tokens])
Comment on lines 2044 to +2045
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This is a great catch and a critical fix. The slot_mapping tensor maps tokens to their slots in the KV cache, so its length must be proportional to the number of tokens, not the number of requests.

Previously, slicing with [:num_reqs] was incorrect, especially when num_tokens (from cudagraph_capture_sizes) was greater than max_num_seqs. In that scenario, num_reqs would be capped at max_num_seqs, leading to an undersized slot_mapping being captured in the CUDA graph. This could cause memory corruption or incorrect outputs during graph replay, as you've pointed out.

Using [:num_tokens] correctly sizes the slot_mapping for the dummy run, ensuring the CUDA graph is captured with the correct metadata. This prevents the weird output artifacts and ensures correctness.


attn_metadata_i = self.attn_metadata_builders[
kv_cache_group_id].build_for_cudagraph_capture(
Expand Down