
Commit ee9cbca

LucasWilkinson authored and hj-mistral committed
[BugFix] Fix potential cuda-graph IMA (vllm-project#21196)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Himanshu Jaju <hj@mistral.ai>
1 parent ee2ae09 commit ee9cbca

File tree

2 files changed: +6 -6 lines changed


vllm/v1/attention/backends/utils.py

Lines changed: 0 additions & 5 deletions
@@ -59,11 +59,6 @@ class CommonAttentionMetadata:
     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
 
-    def __post_init__(self):
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        self.slot_mapping[self.num_actual_tokens:].fill_(-1)
-
 
 M = TypeVar("M")
 

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 1 deletion
@@ -684,7 +684,7 @@ def _prepare_inputs(
         self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                        non_blocking=True)
 
-        # Fill unused with -1. Needed for reshape_and_cache
+        # Fill unused with 0 for full cuda graph mode.
         self.seq_lens[num_reqs:].fill_(0)
         # Note: pad query_start_loc to be non-decreasing, as kernels
         # like FlashAttention requires that
@@ -704,6 +704,11 @@ def _prepare_inputs(
             blk_table = self.input_batch.block_table[kv_cache_group_id]
             blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
             slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
+
+            # Fill unused with -1. Needed for reshape_and_cache in full cuda
+            # graph mode.
+            blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=self.query_start_loc[:num_reqs + 1],
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
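
For readers of the diff, here is a minimal sketch of the idea behind the padding, assuming a fixed-size launch that always walks the whole padded slot_mapping buffer (the way a captured full cuda graph replays). write_kv_cache, MAX_TOKENS, and NUM_SLOTS below are hypothetical stand-ins for illustration, not vLLM APIs:

import torch

# Hypothetical sizes: the "capture size" of a padded graph launch and the
# total number of KV-cache slots.
MAX_TOKENS = 8
NUM_SLOTS = 16

def write_kv_cache(values, slot_mapping, kv_cache):
    # Stand-in for a fixed-size cache-write kernel: it always visits all
    # MAX_TOKENS entries, like a captured graph replaying with padding.
    # Negative slots are treated as padding and skipped, matching the
    # "fill unused with -1" convention in the diff's comments.
    for i in range(slot_mapping.numel()):
        slot = int(slot_mapping[i])
        if slot < 0:
            continue  # padded token: nothing to write
        # A stale/garbage slot here would be an out-of-bounds write
        # (the GPU-side illegal memory access the commit title refers to).
        kv_cache[slot] = values[i]

# Persistent buffer reused across steps; the tail may hold stale slot ids
# from an earlier, larger batch (999 stands in for such garbage).
slot_mapping = torch.full((MAX_TOKENS,), 999, dtype=torch.long)
kv_cache = torch.zeros(NUM_SLOTS, 4)
values = torch.randn(MAX_TOKENS, 4)

num_scheduled_tokens = 3
slot_mapping[:num_scheduled_tokens] = torch.tensor([0, 1, 2])

# The fix: clear the unused tail of the *persistent* buffer with -1 before
# the fixed-size launch, so padded entries can never index the cache.
slot_mapping[num_scheduled_tokens:].fill_(-1)

write_kv_cache(values, slot_mapping, kv_cache)

As the diff shows, the new fill in _prepare_inputs operates on the full persistent blk_table.slot_mapping buffer beyond total_num_scheduled_tokens, whereas the removed __post_init__ only touched the view that had already been sliced to total_num_scheduled_tokens, so the padded tail used during full-cuda-graph replay could retain stale slot values.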
