                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -93,6 +94,8 @@ class ChunkedContextMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -105,6 +108,8 @@ class AscendMLADecodeMetadata:
     max_seq_lens: int
     seq_lens_list: list[int]
     attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -215,6 +220,9 @@ def __init__(self,
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -333,13 +341,21 @@ def build_torchair_graph_dummy(
                                   -1,
                                   dtype=torch.int32,
                                   device=device)
+        sin = torch.ones(num_reqs, 1, 1, self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs, 1, 1, self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_seq_lens=1,
-            attn_mask=self.runner.spec_attn_mask)
+            attn_mask=self.runner.spec_attn_mask,
+            sin=sin,
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
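The dummy-build path above only needs to reproduce the shape contract of the real tables: one (num_reqs, 1, 1, rope_dim) tensor each for sin and cos, in the runner's dtype, so the captured torchair graph sees the same layout that the real build() later supplies. A minimal sketch of that contract (helper name hypothetical, not part of the PR):

import torch

def dummy_rope_tables(num_reqs: int, rope_dim: int,
                      dtype: torch.dtype, device: torch.device):
    # Values are irrelevant for graph capture; only the
    # (num_reqs, 1, 1, rope_dim) layout and dtype must match the real path.
    sin = torch.ones(num_reqs, 1, 1, rope_dim, dtype=dtype, device=device)
    cos = torch.ones(num_reqs, 1, 1, rope_dim, dtype=dtype, device=device)
    return sin, cos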
@@ -388,6 +404,12 @@ def build(
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
         query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model().model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model().model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:
+            self.cos_cache = self.cos_cache.to(self.runner.dtype)
+            self.sin_cache = self.sin_cache.to(self.runner.dtype)
 
         prefill_metadata = None
         chunked_context_metadata = None
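The builder now snapshots the model's precomputed rotary tables once, on the first build() call, and casts them to the runner dtype a single time; later steps only index into the cached tensors. A rough sketch of the same lazy-caching logic, factored into a hypothetical helper method:

def _ensure_rope_cache(self) -> None:
    # Fetch the rotary cos/sin tables from the first decoder layer once and
    # reuse them for every subsequent metadata build.
    if self.cos_cache is None:
        rope = self.runner.get_model().model.layers[0].self_attn.rotary_emb
        self.cos_cache = rope.cos_cached
        self.sin_cache = rope.sin_cached
    if self.cos_cache.dtype != self.runner.dtype:
        # Convert once up front instead of on every forward pass.
        self.cos_cache = self.cos_cache.to(self.runner.dtype)
        self.sin_cache = self.sin_cache.to(self.runner.dtype)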
@@ -434,18 +456,22 @@ def build(
                     chunk_seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[prefill_input_positions].unsqueeze(1).unsqueeze(2)
+            sin = self.sin_cache[prefill_input_positions].unsqueeze(1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )
 
         decode_metadata = None
@@ -486,14 +512,18 @@ def build(
                                         dtype=input_positions.dtype,
                                         device=input_positions.device)
                 input_positions = torch.cat([input_positions, padding_0])
+            cos = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
+            sin = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(2)
 
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
                 block_table=block_table,
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
                 max_seq_lens=max_seq_lens,
-                attn_mask=self.runner.spec_attn_mask)
+                attn_mask=self.runner.spec_attn_mask,
+                sin=sin,
+                cos=cos)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
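Both the prefill and decode branches gather per-token cos/sin rows from the cached tables by input position and unsqueeze them to rank 4 so they broadcast over the head dimensions later. A standalone sketch of that gather (function name hypothetical):

import torch

def gather_rope(cos_cache: torch.Tensor, sin_cache: torch.Tensor,
                positions: torch.Tensor):
    # cos_cache / sin_cache: (max_position, rope_dim)
    # positions:             (num_tokens,)
    # Returns two tensors of shape (num_tokens, 1, 1, rope_dim).
    cos = cos_cache[positions].unsqueeze(1).unsqueeze(2)
    sin = sin_cache[positions].unsqueeze(1).unsqueeze(2)
    return cos, sin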
@@ -1042,9 +1072,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
-        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
-            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
-        ]
+        self.running_in_graph = get_forward_context().running_in_graph
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
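Instead of re-deriving the graph-mode flag from attn_state in every attention layer, forward() now reads it from the per-step forward context. The sketch below assumes, as the diff does, that vllm-ascend attaches a running_in_graph attribute to the object returned by get_forward_context(); that attribute is not part of upstream vLLM's ForwardContext.

from vllm.forward_context import get_forward_context

def in_graph_mode() -> bool:
    # Assumes the Ascend runner sets `running_in_graph` on the forward
    # context once per step, so each layer can read it cheaply instead of
    # recomputing it from attn_metadata.attn_state.
    return get_forward_context().running_in_graph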
@@ -1082,15 +1110,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1125,15 +1146,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin
 
                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
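With the tables precomputed in build(), the graph-mode branches simply pass attn_metadata.decode.cos/.sin and attn_metadata.prefill.cos/.sin into rope_single. The snippet below is a generic rotate-half RoPE application, shown only to illustrate how the (num_tokens, 1, 1, rope_dim) tables broadcast against the query's rope slice; the actual rope_single path may use a fused Ascend kernel with a different rotation layout.

import torch

def apply_rope(x: torch.Tensor, cos: torch.Tensor,
               sin: torch.Tensor) -> torch.Tensor:
    # x:       (num_tokens, num_heads, rope_dim)
    # cos/sin: (num_tokens, 1, 1, rope_dim)
    num_tokens, num_heads, rope_dim = x.shape
    x = x.view(num_tokens, num_heads, 1, rope_dim)
    x1, x2 = torch.chunk(x, 2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    out = x * cos + rotated * sin
    return out.view(num_tokens, num_heads, rope_dim)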