2 files changed: +6 −8 lines changed

@@ -3,7 +3,7 @@
 """Attention layer with FlashAttention."""

 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import Any, Optional

 import torch
 from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
@@ -23,9 +23,6 @@

 logger = init_logger(__name__)

-if TYPE_CHECKING:
-    pass
-
 create_block_mask_compiled = torch.compile(create_block_mask,
                                            fullgraph=True,
                                            mode="reduce-overhead")
@@ -684,6 +684,9 @@ def _prepare_inputs(
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):

+            blk_table = self.input_batch.block_table[kv_cache_group_id]
+            blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
+            slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=self.query_start_loc[:num_reqs + 1],
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
@@ -694,10 +697,8 @@ def _prepare_inputs(
                 num_reqs=num_reqs,
                 num_actual_tokens=total_num_scheduled_tokens,
                 max_query_len=max_num_scheduled_tokens,
-                block_table_tensor=self.input_batch.
-                block_table[kv_cache_group_id].get_device_tensor()[:num_reqs],
-                slot_mapping=self.input_batch.block_table[kv_cache_group_id].
-                slot_mapping[:total_num_scheduled_tokens],
+                block_table_tensor=blk_table_tensor,
+                slot_mapping=slot_mapping,
             )

             if self.speculative_config and \
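This second hunk is a pure refactor: the repeated self.input_batch.block_table[kv_cache_group_id] lookup chain is hoisted into locals before CommonAttentionMetadata is built, which also removes the awkward mid-attribute line wraps. A minimal sketch of the same pattern follows; BlockTable and CommonAttentionMetadata here are simplified stand-ins, not vLLM's real classes:

from dataclasses import dataclass

import torch


@dataclass
class BlockTable:
    # Toy stand-in for vLLM's per-group block table wrapper.
    device_tensor: torch.Tensor
    slot_mapping: torch.Tensor

    def get_device_tensor(self) -> torch.Tensor:
        return self.device_tensor


@dataclass
class CommonAttentionMetadata:
    # Only the two fields touched by the diff.
    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor


def prepare(block_tables: list, group_id: int, num_reqs: int,
            total_tokens: int) -> CommonAttentionMetadata:
    # After the refactor: one local per value, instead of repeating the
    # block_tables[group_id] lookup chain inside the constructor call.
    blk_table = block_tables[group_id]
    blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
    slot_mapping = blk_table.slot_mapping[:total_tokens]
    return CommonAttentionMetadata(block_table_tensor=blk_table_tensor,
                                   slot_mapping=slot_mapping)


meta = prepare([BlockTable(torch.zeros(4, 16), torch.arange(64))],
               group_id=0, num_reqs=2, total_tokens=32)
print(meta.block_table_tensor.shape, meta.slot_mapping.shape)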