@@ -93,6 +93,8 @@ class ChunkedContextMetadata:
93
93
max_query_len : int
94
94
max_seq_lens : int
95
95
chunked_context : Optional [ChunkedContextMetadata ] = None
96
+ sin : torch .Tensor = None
97
+ cos : torch .Tensor = None
96
98
97
99
98
100
@dataclass
@@ -106,6 +108,8 @@ class AscendMLADecodeMetadata:
106
108
seq_lens_list : list [int ]
107
109
actual_seq_q_lens : Optional [list [int ]] = None
108
110
attn_mask : Optional [torch .Tensor ] = None
111
+ sin : torch .Tensor = None
112
+ cos : torch .Tensor = None
109
113
110
114
111
115
@dataclass
@@ -217,6 +221,9 @@ def __init__(self,
217
221
)
218
222
ascend_config = get_ascend_config ()
219
223
self .torchair_graph_enabled = ascend_config .torchair_graph_config .enabled
224
+ self .rope_dim = self .runner .model_config .hf_text_config .qk_rope_head_dim
225
+ self .cos_cache = None
226
+ self .sin_cache = None
220
227
221
228
def reorder_batch (self , input_batch : "InputBatch" ,
222
229
scheduler_output : "SchedulerOutput" ) -> bool :
@@ -348,6 +355,18 @@ def build_torchair_graph_dummy(
348
355
else :
349
356
attn_state = AscendAttentionState .DecodeOnly
350
357
num_decode_tokens = 1
358
+ sin = torch .ones (num_reqs ,
359
+ 1 ,
360
+ 1 ,
361
+ self .rope_dim ,
362
+ dtype = self .runner .dtype ,
363
+ device = device )
364
+ cos = torch .ones (num_reqs ,
365
+ 1 ,
366
+ 1 ,
367
+ self .rope_dim ,
368
+ dtype = self .runner .dtype ,
369
+ device = device )
351
370
decode_metadata = AscendMLADecodeMetadata (
352
371
input_positions = input_positions ,
353
372
block_table = block_table ,
@@ -356,7 +375,8 @@ def build_torchair_graph_dummy(
356
375
max_seq_lens = 1 ,
357
376
attn_mask = self .runner .spec_attn_mask ,
358
377
actual_seq_q_lens = self .runner .actual_seq_q_lens [:num_reqs ],
359
- )
378
+ sin = sin ,
379
+ cos = cos )
360
380
return self .metadata_cls ( # type: ignore
361
381
num_input_tokens = num_actual_tokens ,
362
382
num_actual_tokens = num_actual_tokens ,
@@ -408,6 +428,16 @@ def build(
408
428
max_query_len = query_lens .max ().item ()
409
429
max_seq_lens = seq_lens .max ().item ()
410
430
query_start_loc = common_attn_metadata .query_start_loc
431
+ if self .cos_cache is None :
432
+ self .cos_cache = self .runner .get_model (
433
+ ).model .layers [0 ].self_attn .rotary_emb .cos_cached
434
+ self .sin_cache = self .runner .get_model (
435
+ ).model .layers [0 ].self_attn .rotary_emb .sin_cached
436
+ if self .cos_cache .dtype != self .runner .dtype : # type: ignore
437
+ self .cos_cache = self .cos_cache .to (
438
+ self .runner .dtype ) # type: ignore
439
+ self .sin_cache = self .sin_cache .to (
440
+ self .runner .dtype ) # type: ignore
411
441
412
442
prefill_metadata = None
413
443
chunked_context_metadata = None
@@ -454,18 +484,26 @@ def build(
454
484
chunk_seq_lens = chunk_seq_lens ,
455
485
workspace = self .chunked_prefill_workspace ,
456
486
)
457
-
487
+ prefill_input_positions = input_positions [tokens_start :]
488
+ cos = self .cos_cache [
489
+ prefill_input_positions ].unsqueeze ( # type: ignore
490
+ 1 ).unsqueeze (2 )
491
+ sin = self .sin_cache [
492
+ prefill_input_positions ].unsqueeze ( # type: ignore
493
+ 1 ).unsqueeze (2 )
458
494
prefill_metadata = AscendMLAPrefillMetadata (
459
495
attn_mask = self .runner .attn_mask ,
460
496
query_lens = query_lens [tokens_start :],
461
497
seq_lens = seq_lens ,
462
498
context_lens = seq_lens [tokens_start :],
463
- input_positions = input_positions [ tokens_start :] ,
499
+ input_positions = prefill_input_positions ,
464
500
block_table = block_table [reqs_start :, ...],
465
501
max_query_len = max_query_len ,
466
502
max_seq_lens = max_seq_lens ,
467
503
query_start_loc = prefill_query_start_loc ,
468
504
chunked_context = chunked_context_metadata ,
505
+ sin = sin ,
506
+ cos = cos ,
469
507
)
470
508
471
509
decode_metadata = None
@@ -510,8 +548,15 @@ def build(
510
548
actual_seq_q_lens = query_start_loc [1 :].tolist (
511
549
) + self .runner .actual_seq_q_lens [num_reqs :num_reqs +
512
550
num_reqs_pad_size ]
551
+ cos = self .cos_cache [
552
+ input_positions ].unsqueeze ( # type: ignore
553
+ 1 ).unsqueeze (2 )
554
+ sin = self .sin_cache [
555
+ input_positions ].unsqueeze ( # type: ignore
556
+ 1 ).unsqueeze (2 )
513
557
else :
514
558
seq_lens_list = seq_lens .tolist ()
559
+ cos , sin = None , None
515
560
516
561
decode_metadata = AscendMLADecodeMetadata (
517
562
input_positions = input_positions ,
@@ -521,7 +566,8 @@ def build(
521
566
max_seq_lens = max_seq_lens ,
522
567
attn_mask = self .runner .spec_attn_mask ,
523
568
actual_seq_q_lens = actual_seq_q_lens ,
524
- )
569
+ sin = sin ,
570
+ cos = cos )
525
571
526
572
return self .metadata_cls ( # type: ignore
527
573
num_actual_tokens = num_actual_tokens ,
@@ -1113,15 +1159,8 @@ def forward(
1113
1159
decode_k_nope = None
1114
1160
assert attn_metadata .decode is not None
1115
1161
if self .running_in_graph :
1116
- seq_len = self .rotary_emb .max_position_embeddings * self .rotary_emb .scaling_factor
1117
- cos = self .rotary_emb .cos_cached [:seq_len ].to (
1118
- dtype = decode_hs_or_q_c .dtype )
1119
- sin = self .rotary_emb .sin_cached [:seq_len ].to (
1120
- dtype = decode_hs_or_q_c .dtype )
1121
- cos = cos [attn_metadata .decode .input_positions ]
1122
- sin = sin [attn_metadata .decode .input_positions ]
1123
- cos = cos [:, None , None , :]
1124
- sin = sin [:, None , None , :]
1162
+ cos = attn_metadata .decode .cos
1163
+ sin = attn_metadata .decode .sin
1125
1164
# Without explicitly controlling the order, IndexByTensor operations
1126
1165
# would be placed after `matmul W_KV_T` hindering the overlapping of
1127
1166
# KvRmsNormRopeCache and SingleRope.
@@ -1156,15 +1195,8 @@ def forward(
1156
1195
prefill_q_nope = prefill_q [..., :self .qk_nope_head_dim ]
1157
1196
if self .torchair_graph_enabled :
1158
1197
num_tokens = prefill_hs_or_q_c .shape [0 ]
1159
- seq_len = self .rotary_emb .max_position_embeddings * self .rotary_emb .scaling_factor
1160
- cos = self .rotary_emb .cos_cached [:seq_len ].to (
1161
- dtype = prefill_q_pe .dtype )
1162
- sin = self .rotary_emb .sin_cached [:seq_len ].to (
1163
- dtype = prefill_q_pe .dtype )
1164
- cos = cos [attn_metadata .prefill .input_positions ]
1165
- sin = sin [attn_metadata .prefill .input_positions ]
1166
- cos = cos [:, None , None , :]
1167
- sin = sin [:, None , None , :]
1198
+ cos = attn_metadata .prefill .cos
1199
+ sin = attn_metadata .prefill .sin
1168
1200
1169
1201
prefill_q_pe = self .rope_single (prefill_q_pe , cos , sin )
1170
1202
prefill_k_pe , prefill_k_nope = self .exec_kv_prefill (
0 commit comments