                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -93,6 +94,8 @@ class ChunkedContextMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -106,6 +109,8 @@ class AscendMLADecodeMetadata:
     seq_lens_list: list[int]
     actual_seq_q_lens: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -217,6 +222,9 @@ def __init__(self,
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -348,6 +356,18 @@ def build_torchair_graph_dummy(
         else:
             attn_state = AscendAttentionState.DecodeOnly
             num_decode_tokens = 1
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -356,7 +376,8 @@ def build_torchair_graph_dummy(
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
-        )
+            sin=sin,
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -408,6 +429,14 @@ def build(
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
         query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:
+            self.cos_cache = self.cos_cache.to(self.runner.dtype)
+            self.sin_cache = self.sin_cache.to(self.runner.dtype)
 
         prefill_metadata = None
         chunked_context_metadata = None
@@ -454,18 +483,24 @@ def build(
                 chunk_seq_lens=chunk_seq_lens,
                 workspace=self.chunked_prefill_workspace,
             )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[prefill_input_positions].unsqueeze(
+                1).unsqueeze(2)
+            sin = self.sin_cache[prefill_input_positions].unsqueeze(
+                1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )
 
         decode_metadata = None
@@ -510,8 +545,11 @@ def build(
                 actual_seq_q_lens = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                   num_reqs_pad_size]
+                cos = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
+                sin = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(2)
             else:
                 seq_lens_list = seq_lens.tolist()
+                cos, sin = None, None
 
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -521,7 +559,8 @@ def build(
                 max_seq_lens=max_seq_lens,
                 attn_mask=self.runner.spec_attn_mask,
                 actual_seq_q_lens=actual_seq_q_lens,
-            )
+                sin=sin,
+                cos=cos)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -1113,15 +1152,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1156,15 +1188,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin
 
                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
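
Note on the pattern above: the diff hoists the per-position cos/sin gathering for rotary embeddings out of `forward` and into the metadata builder, so the lookup happens once per batch and the broadcast-ready tensors ride along in `AscendMLAPrefillMetadata`/`AscendMLADecodeMetadata`. A minimal standalone sketch of that lookup-and-broadcast step, assuming the rotary caches are 2-D tensors of shape [max_position, rope_dim] (function and variable names here are illustrative, not from the diff):

import torch

def gather_rope_coefs(cos_cached: torch.Tensor, sin_cached: torch.Tensor,
                      positions: torch.Tensor, dtype: torch.dtype):
    # Index the per-position rows once, then add two broadcast dims so the
    # result is shaped [num_tokens, 1, 1, rope_dim] for the RoPE multiply.
    cos = cos_cached.to(dtype)[positions].unsqueeze(1).unsqueeze(2)
    sin = sin_cached.to(dtype)[positions].unsqueeze(1).unsqueeze(2)
    return cos, sin

# Example: 4 tokens, rope_dim 64 (cache contents are dummy values).
cos_cached = torch.randn(8192, 64)
sin_cached = torch.randn(8192, 64)
positions = torch.tensor([0, 1, 5, 9])
cos, sin = gather_rope_coefs(cos_cached, sin_cached, positions, torch.float16)
assert cos.shape == (4, 1, 1, 64)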