File tree Expand file tree Collapse file tree 2 files changed +6
-8
lines changed Expand file tree Collapse file tree 2 files changed +6
-8
lines changed Original file line number Diff line number Diff line change 3
3
"""Attention layer with FlashAttention."""
4
4
from collections import defaultdict
5
5
from dataclasses import dataclass
6
- from typing import TYPE_CHECKING , Any , Optional
6
+ from typing import Any , Optional
7
7
8
8
import torch
9
9
from torch .nn .attention .flex_attention import (BlockMask , _mask_mod_signature ,
23
23
24
24
logger = init_logger (__name__ )
25
25
26
- if TYPE_CHECKING :
27
- pass
28
-
29
26
create_block_mask_compiled = torch .compile (create_block_mask ,
30
27
fullgraph = True ,
31
28
mode = "reduce-overhead" )
Original file line number Diff line number Diff line change @@ -685,6 +685,9 @@ def _prepare_inputs(
685
685
for kv_cache_group_id , kv_cache_group_spec in enumerate (
686
686
self .kv_cache_config .kv_cache_groups ):
687
687
688
+ blk_table = self .input_batch .block_table [kv_cache_group_id ]
689
+ blk_table_tensor = blk_table .get_device_tensor ()[:num_reqs ]
690
+ slot_mapping = blk_table .slot_mapping [:total_num_scheduled_tokens ]
688
691
common_attn_metadata = CommonAttentionMetadata (
689
692
query_start_loc = self .query_start_loc [:num_reqs + 1 ],
690
693
query_start_loc_cpu = self .query_start_loc_cpu [:num_reqs + 1 ],
@@ -695,10 +698,8 @@ def _prepare_inputs(
695
698
num_reqs = num_reqs ,
696
699
num_actual_tokens = total_num_scheduled_tokens ,
697
700
max_query_len = max_num_scheduled_tokens ,
698
- block_table_tensor = self .input_batch .
699
- block_table [kv_cache_group_id ].get_device_tensor ()[:num_reqs ],
700
- slot_mapping = self .input_batch .block_table [kv_cache_group_id ].
701
- slot_mapping [:total_num_scheduled_tokens ],
701
+ block_table_tensor = blk_table_tensor ,
702
+ slot_mapping = slot_mapping ,
702
703
)
703
704
704
705
if self .speculative_config and \
You can’t perform that action at this time.
0 commit comments