 from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 
 from vllm_ascend import envs
@@ -82,6 +81,8 @@ class ChunkedContextMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -95,7 +96,8 @@ class AscendMLADecodeMetadata:
     seq_lens_list: list[int]
     actual_seq_q_lens: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
-    mc2_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
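The new `sin`/`cos` fields hold per-token rotary values that the builder gathers from the model's rope tables and broadcasts to `[num_tokens, 1, 1, qk_rope_head_dim]`, matching the `unsqueeze(1).unsqueeze(2)` calls further down in this diff. A minimal sketch of that gather, with made-up sizes:

```python
import torch

# Illustrative sizes only; the real values come from the model config.
max_position, rope_dim, num_tokens = 4096, 64, 8

# Rotary table laid out as [max_position, rope_dim], as exposed by
# rotary_emb.cos_cached / sin_cached in the code removed from forward().
cos_cache = torch.randn(max_position, rope_dim)
positions = torch.randint(0, max_position, (num_tokens,))

# Gather per-token entries and add broadcast dims, mirroring
# self.cos_cache[positions].unsqueeze(1).unsqueeze(2) in the builder.
cos = cos_cache[positions].unsqueeze(1).unsqueeze(2)
assert cos.shape == (num_tokens, 1, 1, rope_dim)
```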
@@ -207,13 +209,9 @@ def __init__(self,
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-
-    def generate_activate_mask(self, actual_seqs_num, batch_size):
-        mc2_mask = torch.zeros(batch_size,
-                               dtype=torch.bool,
-                               device=current_platform.device_type)
-        mc2_mask[:actual_seqs_num].fill_(True)
-        return mc2_mask
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
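The per-call `generate_activate_mask` helper disappears; the builder instead remembers the rope head dim and lazily fills `cos_cache`/`sin_cache` the first time `build()` runs. A condensed, self-contained sketch of that lazy-cache pattern; `RopeCacheHolder`, `ensure` and `get_rotary_emb` are hypothetical names, with `get_rotary_emb` standing in for the real `runner.get_model().model.layers[0].self_attn.rotary_emb` lookup:

```python
from types import SimpleNamespace

import torch


class RopeCacheHolder:
    """Sketch of the lazy cos/sin caching done in build(); not the real builder."""

    def __init__(self, target_dtype: torch.dtype):
        self.target_dtype = target_dtype
        self.cos_cache = None
        self.sin_cache = None

    def ensure(self, get_rotary_emb):
        # Fetch the model's rope tables once, then cast them to the runner
        # dtype so the gathered cos/sin match the activation dtype.
        if self.cos_cache is None:
            rotary_emb = get_rotary_emb()
            self.cos_cache = rotary_emb.cos_cached
            self.sin_cache = rotary_emb.sin_cached
            if self.cos_cache.dtype != self.target_dtype:
                self.cos_cache = self.cos_cache.to(self.target_dtype)
                self.sin_cache = self.sin_cache.to(self.target_dtype)
        return self.cos_cache, self.sin_cache


# Toy usage with a fake rotary embedding object.
holder = RopeCacheHolder(torch.float16)
fake_rope = SimpleNamespace(cos_cached=torch.randn(16, 8),
                            sin_cached=torch.randn(16, 8))
cos, sin = holder.ensure(lambda: fake_rope)
assert cos.dtype == torch.float16 and sin.dtype == torch.float16
```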
@@ -345,7 +343,18 @@ def build_torchair_graph_dummy(
         else:
             attn_state = AscendAttentionState.DecodeOnly
             num_decode_tokens = 1
-        mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -354,8 +363,8 @@ def build_torchair_graph_dummy(
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
-            mc2_mask=mc2_mask,
-        )
+            sin=sin,
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
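For the dummy metadata built during torchair graph capture, only the shape and dtype of `sin`/`cos` matter, so constant tensors shaped `[num_reqs, 1, 1, rope_dim]` are sufficient. A standalone sketch with assumed sizes (in the real builder they come from `num_reqs`, `self.rope_dim`, `self.runner.dtype` and the runner device):

```python
import torch

# Assumed sizes for illustration.
num_reqs, rope_dim = 4, 64
dtype, device = torch.float16, "cpu"

# Placeholder rope values for graph capture: the values are irrelevant,
# only shape and dtype must match what build() produces at runtime.
sin = torch.ones(num_reqs, 1, 1, rope_dim, dtype=dtype, device=device)
cos = torch.ones(num_reqs, 1, 1, rope_dim, dtype=dtype, device=device)
assert cos.shape == (num_reqs, 1, 1, rope_dim)
```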
@@ -407,6 +416,16 @@ def build(
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
         query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+            if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
+                self.cos_cache = self.cos_cache.to(  # type: ignore
+                    self.runner.dtype)  # type: ignore
+                self.sin_cache = self.sin_cache.to(  # type: ignore
+                    self.runner.dtype)  # type: ignore
 
         prefill_metadata = None
         chunked_context_metadata = None
@@ -453,18 +472,26 @@ def build(
                     chunk_seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
+            sin = self.sin_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )
 
         decode_metadata = None
@@ -509,10 +536,15 @@ def build(
                 actual_seq_q_lens = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                   num_reqs_pad_size]
+                cos = self.cos_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+                sin = self.sin_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
             else:
                 seq_lens_list = seq_lens.tolist()
-                mc2_mask = self.generate_activate_mask(
-                    num_actual_tokens, num_reqs + num_reqs_pad_size)
+                cos, sin = None, None
 
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -522,8 +554,8 @@ def build(
                 max_seq_lens=max_seq_lens,
                 attn_mask=self.runner.spec_attn_mask,
                 actual_seq_q_lens=actual_seq_q_lens,
-                mc2_mask=mc2_mask,
-            )
+                sin=sin,
+                cos=cos)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -1115,15 +1147,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1158,15 +1183,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin
 
                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
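With the gather moved into the metadata builder, `forward` now just reads the precomputed `cos`/`sin` and hands them to `rope_single`/`exec_kv_prefill`. As a rough reference for how such precomputed tensors are typically consumed, here is a minimal rotate-half application; the diff does not show `rope_single`'s actual convention on Ascend, so `apply_rope` below is purely an illustration under that assumption:

```python
import torch


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate-half RoPE given precomputed cos/sin.

    x:        [num_tokens, 1, num_heads, rope_dim]
    cos, sin: [num_tokens, 1, 1, rope_dim]  (broadcast over heads)
    """
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin


# Toy usage with assumed sizes.
num_tokens, num_heads, rope_dim = 8, 16, 64
q_pe = torch.randn(num_tokens, 1, num_heads, rope_dim)
cos = torch.randn(num_tokens, 1, 1, rope_dim)
sin = torch.randn(num_tokens, 1, 1, rope_dim)
assert apply_rope(q_pe, cos, sin).shape == q_pe.shape
```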