@@ -82,6 +82,8 @@ class ChunkedContextMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -95,6 +97,8 @@ class AscendMLADecodeMetadata:
     seq_lens_list: list[int]
     actual_seq_q_lens: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
     mc2_mask: Optional[torch.Tensor] = None
 
 
@@ -207,9 +211,14 @@ def __init__(self,
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-
-    def generate_active_mask(self, actual_seqs_num, batch_size):
-        mc2_mask = torch.zeros(batch_size, dtype=torch.bool, device=current_platform.device_type)
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None
+
+    def generate_activate_mask(self, actual_seqs_num, batch_size):
+        mc2_mask = torch.zeros(batch_size,
+                               dtype=torch.bool,
+                               device=current_platform.device_type)
         mc2_mask[:actual_seqs_num].fill_(True)
         return mc2_mask
 
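For reference, a minimal CPU-only sketch (not the PR's code verbatim; it drops the Ascend `current_platform.device_type` placement) of what the renamed `generate_activate_mask` helper returns: a boolean mask marking the first `actual_seqs_num` slots of a possibly padded batch as active.

```python
import torch

def generate_activate_mask(actual_seqs_num: int, batch_size: int) -> torch.Tensor:
    # True for real requests, False for padding slots appended to reach the
    # (graph-captured) batch size.
    mc2_mask = torch.zeros(batch_size, dtype=torch.bool)
    mc2_mask[:actual_seqs_num].fill_(True)
    return mc2_mask

print(generate_activate_mask(3, 8))
# tensor([ True,  True,  True, False, False, False, False, False])
```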
@@ -343,7 +352,19 @@ def build_torchair_graph_dummy(
         else:
             attn_state = AscendAttentionState.DecodeOnly
             num_decode_tokens = 1
-        mc2_mask = self.generate_active_mask(num_actual_tokens, num_reqs)
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -352,8 +373,9 @@ def build_torchair_graph_dummy(
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
-            mc2_mask=mc2_mask,
-        )
+            sin=sin,
+            cos=cos,
+            mc2_mask=mc2_mask)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -405,6 +427,16 @@ def build(
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
         query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
+            self.cos_cache = self.cos_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
+            self.sin_cache = self.sin_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
 
         prefill_metadata = None
         chunked_context_metadata = None
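The two added `if` blocks lazily capture the model's rotary-embedding tables on the first `build()` call and cast them to the runner dtype once, instead of slicing and converting them per layer inside `forward()`. A rough standalone sketch of that caching pattern, with `rotary_emb` and `runner_dtype` as stand-ins for the real runner/model attributes:

```python
import torch

class _RopeCacheHolder:
    """Illustrative only: mirrors the lazy cos/sin caching added to the metadata builder."""

    def __init__(self, runner_dtype: torch.dtype):
        self.runner_dtype = runner_dtype
        self.cos_cache = None
        self.sin_cache = None

    def ensure_caches(self, rotary_emb) -> None:
        # Fetch the precomputed tables once, on the first build() call...
        if self.cos_cache is None:
            self.cos_cache = rotary_emb.cos_cached
            self.sin_cache = rotary_emb.sin_cached
        # ...and convert them a single time if the runner dtype differs.
        if self.cos_cache.dtype != self.runner_dtype:
            self.cos_cache = self.cos_cache.to(self.runner_dtype)
            self.sin_cache = self.sin_cache.to(self.runner_dtype)
```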
@@ -451,18 +483,26 @@ def build(
                     chunk_seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
+            sin = self.sin_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )
 
         decode_metadata = None
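The `unsqueeze(1).unsqueeze(2)` chain reproduces the `[:, None, None, :]` indexing that `forward()` used to perform per layer (see the last two hunks): indexing the cache with the token positions yields `[num_tokens, rope_dim]`, and the two unsqueezes expand it to `[num_tokens, 1, 1, rope_dim]` so it broadcasts over the query's head dimension. A small self-contained check with made-up sizes (names are stand-ins, not the PR's objects):

```python
import torch

rope_dim, max_pos = 64, 4096
cos_cached = torch.randn(max_pos, rope_dim)   # stand-in for rotary_emb.cos_cached
positions = torch.tensor([5, 17, 42])         # stand-in for input_positions

old_style = cos_cached[positions][:, None, None, :]           # removed forward() path
new_style = cos_cached[positions].unsqueeze(1).unsqueeze(2)   # builder's cached lookup

assert new_style.shape == (3, 1, 1, rope_dim)
assert torch.equal(old_style, new_style)
```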
@@ -507,9 +547,17 @@ def build(
                 actual_seq_q_lens = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                   num_reqs_pad_size]
+                cos = self.cos_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+                sin = self.sin_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
             else:
                 seq_lens_list = seq_lens.tolist()
-            mc2_mask = self.generate_active_mask(num_actual_tokens, num_reqs)
+                cos, sin = None, None
+            mc2_mask = self.generate_activate_mask(
+                num_actual_tokens, num_reqs + num_reqs_pad_size)
 
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -519,8 +567,9 @@ def build(
                 max_seq_lens=max_seq_lens,
                 attn_mask=self.runner.spec_attn_mask,
                 actual_seq_q_lens=actual_seq_q_lens,
-                mc2_mask=mc2_mask,
-            )
+                sin=sin,
+                cos=cos,
+                mc2_mask=mc2_mask)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -1112,15 +1161,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1155,15 +1197,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin
 
                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(