                                               MLAAttentionImpl)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
@@ -93,6 +94,8 @@ class ChunkedContextMetadata:
    max_query_len: int
    max_seq_lens: int
    chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None


@dataclass
@@ -105,6 +108,8 @@ class AscendMLADecodeMetadata:
    max_seq_lens: int
    seq_lens_list: list[int]
    attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None


@dataclass
@@ -215,6 +220,9 @@ def __init__(self,
        )
        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
@@ -333,13 +341,27 @@ def build_torchair_graph_dummy(
                                     -1,
                                     dtype=torch.int32,
                                     device=device)
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
        decode_metadata = AscendMLADecodeMetadata(
            input_positions=input_positions,
            block_table=block_table,
            seq_lens=seq_lens,
            seq_lens_list=seq_lens.tolist(),
            max_seq_lens=1,
-            attn_mask=self.runner.spec_attn_mask)
+            attn_mask=self.runner.spec_attn_mask,
+            sin=sin,
+            cos=cos)
        return self.metadata_cls(  # type: ignore
            num_input_tokens=num_actual_tokens,
            num_actual_tokens=num_actual_tokens,
@@ -388,6 +410,14 @@ def build(
        max_query_len = query_lens.max().item()
        max_seq_lens = seq_lens.max().item()
        query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:
+            self.cos_cache = self.cos_cache.to(self.runner.dtype)
+            self.sin_cache = self.sin_cache.to(self.runner.dtype)

        prefill_metadata = None
        chunked_context_metadata = None
@@ -434,18 +464,24 @@ def build(
                chunk_seq_lens=chunk_seq_lens,
                workspace=self.chunked_prefill_workspace,
            )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[prefill_input_positions].unsqueeze(
+                1).unsqueeze(2)
+            sin = self.sin_cache[prefill_input_positions].unsqueeze(
+                1).unsqueeze(2)
            prefill_metadata = AscendMLAPrefillMetadata(
                attn_mask=self.runner.attn_mask,
                query_lens=query_lens[tokens_start:],
                seq_lens=seq_lens,
                context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                block_table=block_table[reqs_start:, ...],
                max_query_len=max_query_len,
                max_seq_lens=max_seq_lens,
                query_start_loc=prefill_query_start_loc,
                chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
            )

        decode_metadata = None
@@ -486,14 +522,18 @@ def build(
                                        dtype=input_positions.dtype,
                                        device=input_positions.device)
                input_positions = torch.cat([input_positions, padding_0])
+            cos = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
+            sin = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(2)

            decode_metadata = AscendMLADecodeMetadata(
                input_positions=input_positions,
                block_table=block_table,
                seq_lens=seq_lens,
                seq_lens_list=seq_lens.tolist(),
                max_seq_lens=max_seq_lens,
-                attn_mask=self.runner.spec_attn_mask)
+                attn_mask=self.runner.spec_attn_mask,
+                sin=sin,
+                cos=cos)

        return self.metadata_cls(  # type: ignore
            num_actual_tokens=num_actual_tokens,
@@ -1042,9 +1082,7 @@ def forward(
        if attn_metadata is None:
            # Profiling run.
            return output
-        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
-            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
-        ]
+        self.running_in_graph = get_forward_context().running_in_graph
        num_actual_toks = attn_metadata.num_actual_tokens
        if k_pe is None and not self.running_in_graph:
            kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1082,15 +1120,8 @@ def forward(
            decode_k_nope = None
            assert attn_metadata.decode is not None
            if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                # Without explicitly controlling the order, IndexByTensor operations
                # would be placed after `matmul W_KV_T` hindering the overlapping of
                # KvRmsNormRopeCache and SingleRope.
@@ -1125,15 +1156,8 @@ def forward(
            prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
            if self.torchair_graph_enabled:
                num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin

            prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
            prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
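
For reference, below is a minimal standalone sketch of the lookup this diff moves out of forward() and into the metadata builder: gather the cached rotary cos/sin tables by token position once per step and broadcast them to [num_tokens, 1, 1, rope_dim]. The helper name, table size, and tensor values are illustrative assumptions; only the indexing and unsqueeze pattern mirrors the change above.

import torch

def gather_rope_factors(cos_cached: torch.Tensor,  # [max_pos, rope_dim] precomputed table
                        sin_cached: torch.Tensor,  # [max_pos, rope_dim] precomputed table
                        positions: torch.Tensor,   # [num_tokens] integer token positions
                        dtype: torch.dtype):
    # Cast once to the runtime dtype, gather by position, then add the two
    # broadcast axes so the result has shape [num_tokens, 1, 1, rope_dim].
    cos = cos_cached.to(dtype)[positions].unsqueeze(1).unsqueeze(2)
    sin = sin_cached.to(dtype)[positions].unsqueeze(1).unsqueeze(2)
    return cos, sin

# Toy usage: 4 tokens, rope_dim 64, 4096-entry table (sizes are made up).
cos_cached = torch.randn(4096, 64)
sin_cached = torch.randn(4096, 64)
positions = torch.tensor([5, 6, 7, 8])
cos, sin = gather_rope_factors(cos_cached, sin_cached, positions, torch.float16)
print(cos.shape)  # torch.Size([4, 1, 1, 64])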