diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index b6a06b17bca2..7057023d6684 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -59,11 +59,6 @@ class CommonAttentionMetadata:
     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
 
-    def __post_init__(self):
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        self.slot_mapping[self.num_actual_tokens:].fill_(-1)
-
 
 M = TypeVar("M")
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c3eeb6c2e390..0dd2b78e0a9b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -684,7 +684,7 @@ def _prepare_inputs(
         self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                        non_blocking=True)
 
-        # Fill unused with -1. Needed for reshape_and_cache
+        # Fill unused with 0 for full cuda graph mode.
         self.seq_lens[num_reqs:].fill_(0)
         # Note: pad query_start_loc to be non-decreasing, as kernels
         # like FlashAttention requires that
@@ -704,6 +704,11 @@ def _prepare_inputs(
             blk_table = self.input_batch.block_table[kv_cache_group_id]
             blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
             slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
+
+            # Fill unused with -1. Needed for reshape_and_cache in full cuda
+            # graph mode.
+            blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=self.query_start_loc[:num_reqs + 1],
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
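
For context on what the moved fill does: when the full CUDA graph is captured with a padded token count, the tail of slot_mapping beyond the actually scheduled tokens must not point at real KV-cache slots, so it is filled with -1 before the cache-write kernel (reshape_and_cache) runs. Below is a minimal sketch of that padding pattern; the buffer sizes, slot values, and variable names are illustrative assumptions, not code from vLLM.

import torch

# Illustrative sizes only (assumptions, not values from the diff).
padded_num_tokens = 8        # token count the full CUDA graph was captured with
num_scheduled_tokens = 5     # tokens actually scheduled this step

# Persistent slot_mapping buffer, analogous in role to blk_table.slot_mapping.
slot_mapping = torch.empty(padded_num_tokens, dtype=torch.int64)
slot_mapping[:num_scheduled_tokens] = torch.tensor([12, 13, 14, 40, 41])

# Pad the unused tail with -1 so a cache-write kernel can treat those
# entries as no-op slots when the captured graph replays with padding.
slot_mapping[num_scheduled_tokens:].fill_(-1)

print(slot_mapping)  # tensor([12, 13, 14, 40, 41, -1, -1, -1])

One apparent effect of moving the fill from CommonAttentionMetadata.__post_init__ into _prepare_inputs is that the persistent per-group buffer is padded once per step, before the metadata object is constructed, rather than on every metadata instantiation; treat that reading as an inference from the diff, not a stated rationale.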