from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
+from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down

from vllm_ascend import envs
@@ -81,6 +82,8 @@ class ChunkedContextMetadata:
    max_query_len: int
    max_seq_lens: int
    chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None


@dataclass
@@ -94,6 +97,9 @@ class AscendMLADecodeMetadata:
    seq_lens_list: list[int]
    actual_seq_q_lens: Optional[list[int]] = None
    attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
+    mc2_mask: Optional[torch.Tensor] = None


@dataclass
@@ -205,6 +211,16 @@ def __init__(self,
        )
        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None
+
+    def generate_activate_mask(self, actual_seqs_num, batch_size):
+        mc2_mask = torch.zeros(batch_size,
+                               dtype=torch.bool,
+                               device=current_platform.device_type)
+        mc2_mask[:actual_seqs_num].fill_(True)
+        return mc2_mask

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
@@ -317,7 +333,7 @@ def build_torchair_graph_dummy(
            num_reqs, block_table)
        num_tokens = num_reqs * self.runner.decode_token_per_req
        seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
-        seq_lens_list = seq_lens.tolist()
+        seq_lens_list = [0] * num_reqs
        input_positions = torch.zeros(num_tokens,
                                      dtype=torch.int32,
                                      device=device).long()
@@ -336,6 +352,19 @@ def build_torchair_graph_dummy(
        else:
            attn_state = AscendAttentionState.DecodeOnly
            num_decode_tokens = 1
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
        decode_metadata = AscendMLADecodeMetadata(
            input_positions=input_positions,
            block_table=block_table,
@@ -344,7 +373,9 @@ def build_torchair_graph_dummy(
            max_seq_lens=1,
            attn_mask=self.runner.spec_attn_mask,
            actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
-        )
+            sin=sin,
+            cos=cos,
+            mc2_mask=mc2_mask)
        return self.metadata_cls(  # type: ignore
            num_input_tokens=num_actual_tokens,
            num_actual_tokens=num_actual_tokens,
@@ -396,6 +427,16 @@ def build(
        max_query_len = query_lens.max().item()
        max_seq_lens = seq_lens.max().item()
        query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+            if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
+                self.cos_cache = self.cos_cache.to(  # type: ignore
+                    self.runner.dtype)  # type: ignore
+                self.sin_cache = self.sin_cache.to(  # type: ignore
+                    self.runner.dtype)  # type: ignore

        prefill_metadata = None
        chunked_context_metadata = None
@@ -442,24 +483,32 @@ def build(
                chunk_seq_lens=chunk_seq_lens,
                workspace=self.chunked_prefill_workspace,
            )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
+            sin = self.sin_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
            prefill_metadata = AscendMLAPrefillMetadata(
                attn_mask=self.runner.attn_mask,
                query_lens=query_lens[tokens_start:],
                seq_lens=seq_lens,
                context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                block_table=block_table[reqs_start:, ...],
                max_query_len=max_query_len,
                max_seq_lens=max_seq_lens,
                query_start_loc=prefill_query_start_loc,
                chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
            )

        decode_metadata = None
        use_torchair_graph = num_token_pad_size != -1
        if self._num_decodes > 0:
-            actual_seq_q_lens = None
+            actual_seq_q_lens = query_start_loc[1:].tolist()
            max_seq_lens = seq_lens[:self._num_decodes].max().item()
            seq_lens = seq_lens[:self._num_decode_tokens]
            input_positions = input_positions[:self._num_decode_tokens]
@@ -498,8 +547,17 @@ def build(
                    actual_seq_q_lens = query_start_loc[1:].tolist(
                    ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                      num_reqs_pad_size]
+                cos = self.cos_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+                sin = self.sin_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
            else:
                seq_lens_list = seq_lens.tolist()
+                cos, sin = None, None
+            mc2_mask = self.generate_activate_mask(
+                num_actual_tokens, num_reqs + num_reqs_pad_size)

            decode_metadata = AscendMLADecodeMetadata(
                input_positions=input_positions,
@@ -509,7 +567,9 @@ def build(
                max_seq_lens=max_seq_lens,
                attn_mask=self.runner.spec_attn_mask,
                actual_seq_q_lens=actual_seq_q_lens,
-            )
+                sin=sin,
+                cos=cos,
+                mc2_mask=mc2_mask)

        return self.metadata_cls(  # type: ignore
            num_actual_tokens=num_actual_tokens,
@@ -968,11 +1028,13 @@ def _forward_decode(
                                 self.qk_rope_head_dim)
            input_layout = "BNSD"

-            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                assert num_tokens % self.spec_token_num == 0
+                if self.enable_kv_nz:
+                    input_layout = "TND_NTD"
+                else:
+                    input_layout = "TND"
                # [bs * q_seq_len, num_heads_per_rank, dim]
-                input_layout = "TND"
                q_nope = q_nope.view(num_tokens, self.num_heads, -1)
                q_pe = q_pe.view(num_tokens, self.num_heads, -1)
                sparse_mode = 3
@@ -1101,15 +1163,8 @@ def forward(
            decode_k_nope = None
            assert attn_metadata.decode is not None
            if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                # Without explicitly controlling the order, IndexByTensor operations
                # would be placed after `matmul W_KV_T` hindering the overlapping of
                # KvRmsNormRopeCache and SingleRope.
@@ -1144,15 +1199,8 @@ def forward(
            prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
            if self.torchair_graph_enabled:
                num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin

            prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
            prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
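
For reference, a minimal self-contained sketch of the two patterns this patch relies on: gathering per-token cos/sin from a precomputed rotary table and reshaping it to [num_tokens, 1, 1, rope_dim] (the layout carried in the new `sin`/`cos` metadata fields), and building a boolean activation mask over a padded batch in the spirit of `generate_activate_mask`. The table sizes, example positions, and the `make_mc2_mask` helper name are illustrative assumptions, not taken from vllm-ascend.

import torch

# Illustrative sizes only; real values come from the model config.
max_position, rope_dim = 4096, 64
cos_cached = torch.randn(max_position, rope_dim)
sin_cached = torch.randn(max_position, rope_dim)

# Gather per-token rotary values by position and expand to
# [num_tokens, 1, 1, rope_dim], as the metadata builder above does.
input_positions = torch.tensor([0, 1, 2, 5, 7])
cos = cos_cached[input_positions].unsqueeze(1).unsqueeze(2)
sin = sin_cached[input_positions].unsqueeze(1).unsqueeze(2)
assert cos.shape == (len(input_positions), 1, 1, rope_dim)

# Hypothetical helper mirroring generate_activate_mask(): mark the first
# `actual_seqs_num` slots of a padded batch as active, the rest as padding.
def make_mc2_mask(actual_seqs_num: int, batch_size: int) -> torch.Tensor:
    mask = torch.zeros(batch_size, dtype=torch.bool)
    mask[:actual_seqs_num] = True
    return mask

print(make_mc2_mask(3, 8))  # tensor([True, True, True, False, False, False, False, False])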