@@ -81,6 +81,8 @@ class ChunkedContextMetadata:
81
81
max_query_len : int
82
82
max_seq_lens : int
83
83
chunked_context : Optional [ChunkedContextMetadata ] = None
84
+ sin : torch .Tensor = None
85
+ cos : torch .Tensor = None
84
86
85
87
86
88
@dataclass
@@ -94,6 +96,8 @@ class AscendMLADecodeMetadata:
94
96
seq_lens_list : list [int ]
95
97
actual_seq_q_lens : Optional [list [int ]] = None
96
98
attn_mask : Optional [torch .Tensor ] = None
99
+ sin : torch .Tensor = None
100
+ cos : torch .Tensor = None
97
101
98
102
99
103
@dataclass
@@ -205,6 +209,9 @@ def __init__(self,
205
209
)
206
210
ascend_config = get_ascend_config ()
207
211
self .torchair_graph_enabled = ascend_config .torchair_graph_config .enabled
212
+ self .rope_dim = self .runner .model_config .hf_text_config .qk_rope_head_dim
213
+ self .cos_cache = None
214
+ self .sin_cache = None
208
215
209
216
def reorder_batch (self , input_batch : "InputBatch" ,
210
217
scheduler_output : "SchedulerOutput" ) -> bool :
@@ -336,6 +343,18 @@ def build_torchair_graph_dummy(
336
343
else :
337
344
attn_state = AscendAttentionState .DecodeOnly
338
345
num_decode_tokens = 1
346
+ sin = torch .ones (num_reqs ,
347
+ 1 ,
348
+ 1 ,
349
+ self .rope_dim ,
350
+ dtype = self .runner .dtype ,
351
+ device = device )
352
+ cos = torch .ones (num_reqs ,
353
+ 1 ,
354
+ 1 ,
355
+ self .rope_dim ,
356
+ dtype = self .runner .dtype ,
357
+ device = device )
339
358
decode_metadata = AscendMLADecodeMetadata (
340
359
input_positions = input_positions ,
341
360
block_table = block_table ,
@@ -344,7 +363,8 @@ def build_torchair_graph_dummy(
344
363
max_seq_lens = 1 ,
345
364
attn_mask = self .runner .spec_attn_mask ,
346
365
actual_seq_q_lens = self .runner .actual_seq_q_lens [:num_reqs ],
347
- )
366
+ sin = sin ,
367
+ cos = cos )
348
368
return self .metadata_cls ( # type: ignore
349
369
num_input_tokens = num_actual_tokens ,
350
370
num_actual_tokens = num_actual_tokens ,
@@ -396,6 +416,16 @@ def build(
396
416
max_query_len = query_lens .max ().item ()
397
417
max_seq_lens = seq_lens .max ().item ()
398
418
query_start_loc = common_attn_metadata .query_start_loc
419
+ if self .cos_cache is None :
420
+ self .cos_cache = self .runner .get_model (
421
+ ).model .layers [0 ].self_attn .rotary_emb .cos_cached
422
+ self .sin_cache = self .runner .get_model (
423
+ ).model .layers [0 ].self_attn .rotary_emb .sin_cached
424
+ if self .cos_cache .dtype != self .runner .dtype : # type: ignore
425
+ self .cos_cache = self .cos_cache .to ( # type: ignore
426
+ self .runner .dtype ) # type: ignore
427
+ self .sin_cache = self .sin_cache .to ( # type: ignore
428
+ self .runner .dtype ) # type: ignore
399
429
400
430
prefill_metadata = None
401
431
chunked_context_metadata = None
@@ -442,18 +472,26 @@ def build(
442
472
chunk_seq_lens = chunk_seq_lens ,
443
473
workspace = self .chunked_prefill_workspace ,
444
474
)
445
-
475
+ prefill_input_positions = input_positions [tokens_start :]
476
+ cos = self .cos_cache [
477
+ prefill_input_positions ].unsqueeze ( # type: ignore
478
+ 1 ).unsqueeze (2 )
479
+ sin = self .sin_cache [
480
+ prefill_input_positions ].unsqueeze ( # type: ignore
481
+ 1 ).unsqueeze (2 )
446
482
prefill_metadata = AscendMLAPrefillMetadata (
447
483
attn_mask = self .runner .attn_mask ,
448
484
query_lens = query_lens [tokens_start :],
449
485
seq_lens = seq_lens ,
450
486
context_lens = seq_lens [tokens_start :],
451
- input_positions = input_positions [ tokens_start :] ,
487
+ input_positions = prefill_input_positions ,
452
488
block_table = block_table [reqs_start :, ...],
453
489
max_query_len = max_query_len ,
454
490
max_seq_lens = max_seq_lens ,
455
491
query_start_loc = prefill_query_start_loc ,
456
492
chunked_context = chunked_context_metadata ,
493
+ sin = sin ,
494
+ cos = cos ,
457
495
)
458
496
459
497
decode_metadata = None
@@ -498,8 +536,15 @@ def build(
498
536
actual_seq_q_lens = query_start_loc [1 :].tolist (
499
537
) + self .runner .actual_seq_q_lens [num_reqs :num_reqs +
500
538
num_reqs_pad_size ]
539
+ cos = self .cos_cache [
540
+ input_positions ].unsqueeze ( # type: ignore
541
+ 1 ).unsqueeze (2 )
542
+ sin = self .sin_cache [
543
+ input_positions ].unsqueeze ( # type: ignore
544
+ 1 ).unsqueeze (2 )
501
545
else :
502
546
seq_lens_list = seq_lens .tolist ()
547
+ cos , sin = None , None
503
548
504
549
decode_metadata = AscendMLADecodeMetadata (
505
550
input_positions = input_positions ,
@@ -509,7 +554,8 @@ def build(
509
554
max_seq_lens = max_seq_lens ,
510
555
attn_mask = self .runner .spec_attn_mask ,
511
556
actual_seq_q_lens = actual_seq_q_lens ,
512
- )
557
+ sin = sin ,
558
+ cos = cos )
513
559
514
560
return self .metadata_cls ( # type: ignore
515
561
num_actual_tokens = num_actual_tokens ,
@@ -1101,15 +1147,8 @@ def forward(
1101
1147
decode_k_nope = None
1102
1148
assert attn_metadata .decode is not None
1103
1149
if self .running_in_graph :
1104
- seq_len = self .rotary_emb .max_position_embeddings * self .rotary_emb .scaling_factor
1105
- cos = self .rotary_emb .cos_cached [:seq_len ].to (
1106
- dtype = decode_hs_or_q_c .dtype )
1107
- sin = self .rotary_emb .sin_cached [:seq_len ].to (
1108
- dtype = decode_hs_or_q_c .dtype )
1109
- cos = cos [attn_metadata .decode .input_positions ]
1110
- sin = sin [attn_metadata .decode .input_positions ]
1111
- cos = cos [:, None , None , :]
1112
- sin = sin [:, None , None , :]
1150
+ cos = attn_metadata .decode .cos
1151
+ sin = attn_metadata .decode .sin
1113
1152
# Without explicitly controlling the order, IndexByTensor operations
1114
1153
# would be placed after `matmul W_KV_T` hindering the overlapping of
1115
1154
# KvRmsNormRopeCache and SingleRope.
@@ -1144,15 +1183,8 @@ def forward(
1144
1183
prefill_q_nope = prefill_q [..., :self .qk_nope_head_dim ]
1145
1184
if self .torchair_graph_enabled :
1146
1185
num_tokens = prefill_hs_or_q_c .shape [0 ]
1147
- seq_len = self .rotary_emb .max_position_embeddings * self .rotary_emb .scaling_factor
1148
- cos = self .rotary_emb .cos_cached [:seq_len ].to (
1149
- dtype = prefill_q_pe .dtype )
1150
- sin = self .rotary_emb .sin_cached [:seq_len ].to (
1151
- dtype = prefill_q_pe .dtype )
1152
- cos = cos [attn_metadata .prefill .input_positions ]
1153
- sin = sin [attn_metadata .prefill .input_positions ]
1154
- cos = cos [:, None , None , :]
1155
- sin = sin [:, None , None , :]
1186
+ cos = attn_metadata .prefill .cos
1187
+ sin = attn_metadata .prefill .sin
1156
1188
1157
1189
prefill_q_pe = self .rope_single (prefill_q_pe , cos , sin )
1158
1190
prefill_k_pe , prefill_k_nope = self .exec_kv_prefill (
0 commit comments