Skip to content

Commit da92d38

Browse files
review comments
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 8518c02 commit da92d38

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

vllm/v1/attention/backends/flex_attention.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""Attention layer with FlashAttention."""
44
from collections import defaultdict
55
from dataclasses import dataclass
6-
from typing import TYPE_CHECKING, Any, Optional
6+
from typing import Any, Optional
77

88
import torch
99
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
@@ -23,9 +23,6 @@
2323

2424
logger = init_logger(__name__)
2525

26-
if TYPE_CHECKING:
27-
pass
28-
2926
create_block_mask_compiled = torch.compile(create_block_mask,
3027
fullgraph=True,
3128
mode="reduce-overhead")

vllm/v1/worker/gpu_model_runner.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,9 @@ def _prepare_inputs(
684684
for kv_cache_group_id, kv_cache_group_spec in enumerate(
685685
self.kv_cache_config.kv_cache_groups):
686686

687+
blk_table = self.input_batch.block_table[kv_cache_group_id]
688+
blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
689+
slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
687690
common_attn_metadata = CommonAttentionMetadata(
688691
query_start_loc=self.query_start_loc[:num_reqs + 1],
689692
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
@@ -694,10 +697,8 @@ def _prepare_inputs(
694697
num_reqs=num_reqs,
695698
num_actual_tokens=total_num_scheduled_tokens,
696699
max_query_len=max_num_scheduled_tokens,
697-
block_table_tensor=self.input_batch.
698-
block_table[kv_cache_group_id].get_device_tensor()[:num_reqs],
699-
slot_mapping=self.input_batch.block_table[kv_cache_group_id].
700-
slot_mapping[:total_num_scheduled_tokens],
700+
block_table_tensor=blk_table_tensor,
701+
slot_mapping=slot_mapping,
701702
)
702703

703704
if self.speculative_config and \

0 commit comments

Comments (0)