@@ -104,6 +104,7 @@ class AscendMLADecodeMetadata:
    seq_lens: torch.Tensor
    max_seq_lens: int
    seq_lens_list: list[int]
+   actual_seq_q_lens: Optional[list[int]] = None
    attn_mask: Optional[torch.Tensor] = None


@@ -138,6 +139,7 @@ class AscendMLAMetadata:
    num_input_tokens: int = 0  # Number of tokens including padding.

    enable_dbo_across_dp: bool = False
+   is_mtp_model: bool = False

    query_lens: Optional[list[int]] = None
    # The dimension of the attention heads
@@ -313,48 +315,64 @@ def _get_graph_runner_block_tables(
        return graph_block_tables[:num_seqs, :max_blocks]

    def build_torchair_graph_dummy(
-           self, num_reqs: int, num_actual_tokens: int) -> AscendMLAMetadata:
+       self,
+       num_reqs: int,
+       num_actual_tokens: int,
+       is_mtp_model: bool = False,
+   ) -> AscendMLAMetadata:
        device = self.runner.device
        _, max_blocks = self.runner.graph_block_tables.shape
        block_table = torch.zeros((num_reqs, max_blocks),
                                  dtype=torch.int32,
                                  device=device)
        block_table = self._get_graph_runner_block_tables(
            num_reqs, block_table)
-       seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
-       input_positions = torch.zeros(num_reqs,
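+       # Dummy tensors are sized in tokens, not requests: each request
+       # decodes decode_token_per_req tokens (more than one with MTP).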
+       num_tokens = num_reqs * self.runner.decode_token_per_req
+       seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
+       seq_lens_list = seq_lens.tolist()
+       input_positions = torch.zeros(num_tokens,
                                      dtype=torch.int32,
                                      device=device).long()
-       slot_mapping = torch.full((num_reqs, ),
+       slot_mapping = torch.full((num_tokens, ),
                                  PAD_SLOT_ID,
                                  dtype=torch.int32,
                                  device=device)
        query_start_loc = torch.full((num_reqs, ),
                                     -1,
                                     dtype=torch.int32,
                                     device=device)
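+       # With deepseek MTP speculative decoding enabled, the main model's
+       # dummy decode runs in SpecDecoding state with two tokens per request;
+       # the MTP draft model itself (is_mtp_model) stays in DecodeOnly.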
+       if self.runner.speculative_config is not None and \
+               self.runner.speculative_config.method == 'deepseek_mtp' and not is_mtp_model:
+           attn_state = AscendAttentionState.SpecDecoding
+           num_decode_tokens = 2
+       else:
+           attn_state = AscendAttentionState.DecodeOnly
+           num_decode_tokens = 1
        decode_metadata = AscendMLADecodeMetadata(
            input_positions=input_positions,
            block_table=block_table,
            seq_lens=seq_lens,
-           seq_lens_list=seq_lens.tolist(),
+           seq_lens_list=seq_lens_list,
            max_seq_lens=1,
-           attn_mask=self.runner.spec_attn_mask)
+           attn_mask=self.runner.spec_attn_mask,
+           actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
+       )
        return self.metadata_cls(  # type: ignore
            num_input_tokens=num_actual_tokens,
            num_actual_tokens=num_actual_tokens,
            slot_mapping=slot_mapping,
            head_dim=self.runner.model_config.get_head_size(),
            num_decodes=1,
-           num_decode_tokens=1,
+           num_decode_tokens=num_decode_tokens,
            num_prefills=0,
            attn_mask=self.runner.attn_mask,
-           attn_state=AscendAttentionState.DecodeOnly,
+           attn_state=attn_state,
            prefill=None,
            decode=decode_metadata,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            block_tables=block_table,
+           is_mtp_model=is_mtp_model,
        )

    def build(
@@ -364,8 +382,10 @@ def build(
        max_query_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        common_prefix_len: Optional[int] = None,
-       graph_pad_size: int = -1,
+       num_token_pad_size: int = -1,
+       num_reqs_pad_size: int = 0,
        enable_dbo_across_dp: bool = False,
+       is_mtp_model: bool = False,
    ) -> AscendMLAMetadata:
        assert self._num_decodes + self._num_prefills == num_reqs

@@ -449,8 +469,9 @@ def build(
            )

        decode_metadata = None
-       use_torchair_graph = graph_pad_size != -1
+       use_torchair_graph = num_token_pad_size != -1
        if self._num_decodes > 0:
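+           # Filled only on the torchair graph path below.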
+           actual_seq_q_lens = None
            max_seq_lens = seq_lens[:self._num_decodes].max().item()
            seq_lens = seq_lens[:self._num_decode_tokens]
            input_positions = input_positions[:self._num_decode_tokens]
@@ -459,41 +480,48 @@ def build(
                    AscendAttentionState.DecodeOnly,
                    AscendAttentionState.SpecDecoding
            ]:
-               num_seqs = len(seq_lens)
-               if graph_pad_size != 0:
-                   pad_value = 1
-                   padded_seq_lens = seq_lens.tolist() + \
-                       [pad_value] * graph_pad_size
+               if num_token_pad_size != 0:
+                   pad_value = 0
+                   padded_seq_lens = seq_lens.tolist() + \
+                       [pad_value] * num_reqs_pad_size
                else:
                    padded_seq_lens = seq_lens.tolist()

                seq_lens = torch.from_numpy(
                    np.array(padded_seq_lens).astype(np.int32))
-               padding = torch.full((graph_pad_size, ),
+               seq_lens_list = padded_seq_lens
+               padding = torch.full((num_token_pad_size, ),
                                     PAD_SLOT_ID,
                                     dtype=slot_mapping.dtype,
                                     device=slot_mapping.device)
                slot_mapping = torch.cat([slot_mapping, padding])
                block_table_padding = torch.zeros(
-                   (graph_pad_size, ) + block_table.shape[1:],
+                   (num_reqs_pad_size, ) + block_table.shape[1:],
                    dtype=block_table.dtype,
                    device=block_table.device)
                block_table = torch.cat([block_table, block_table_padding],
                                        dim=0)
                block_table = self._get_graph_runner_block_tables(
-                   num_seqs + graph_pad_size, block_table)
-               padding_0 = torch.zeros(graph_pad_size,
+                   num_reqs + num_reqs_pad_size, block_table)
+               padding_0 = torch.zeros(num_token_pad_size,
                                        dtype=input_positions.dtype,
                                        device=input_positions.device)
                input_positions = torch.cat([input_positions, padding_0])
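+               # Cumulative query end offsets per request (for the TND
+               # layout); padded entries come from the runner's precomputed
+               # actual_seq_q_lens.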
+               actual_seq_q_lens = query_start_loc[1:].tolist() + \
+                   self.runner.actual_seq_q_lens[num_reqs:num_reqs +
+                                                 num_reqs_pad_size]
+           else:
+               seq_lens_list = seq_lens.tolist()

            decode_metadata = AscendMLADecodeMetadata(
                input_positions=input_positions,
                block_table=block_table,
                seq_lens=seq_lens,
-               seq_lens_list=seq_lens.tolist(),
+               seq_lens_list=seq_lens_list,
                max_seq_lens=max_seq_lens,
-               attn_mask=self.runner.spec_attn_mask)
+               attn_mask=self.runner.spec_attn_mask,
+               actual_seq_q_lens=actual_seq_q_lens,
+           )

        return self.metadata_cls(  # type: ignore
            num_actual_tokens=num_actual_tokens,
@@ -510,7 +538,9 @@ def build(
            query_start_loc=query_start_loc,
            block_tables=block_table,
            seq_lens=seq_lens,
-           enable_dbo_across_dp=enable_dbo_across_dp)
+           enable_dbo_across_dp=enable_dbo_across_dp,
+           is_mtp_model=is_mtp_model,
+       )


class AscendMLAImpl(MLAAttentionImpl):
@@ -933,31 +963,10 @@ def _forward_decode(
        assert decode_meta is not None
        num_tokens = q_nope.size(0)
        if self.running_in_graph:
-           # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
-           if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
-               assert num_tokens % self.spec_token_num == 0
-               q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
-                                    self.spec_token_num + 1, self.num_heads,
-                                    -1)
-               q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
-                                self.spec_token_num + 1, self.num_heads, -1)
-               if not self.enable_kv_nz:
-                   q_nope = q_nope.transpose(1, 2).contiguous()
-                   q_pe = q_pe.transpose(1, 2).contiguous()
-               sparse_mode = 3
-               spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
-           else:
-               if self.enable_kv_nz:
-                   q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
-                   q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
-               else:
-                   q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-                   q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
-               sparse_mode = 0
-               spec_attn_mask = None
            # shape of knope/k_pe for npu graph mode should be:
            # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
            block_size = kv_c_and_k_pe_cache[0].shape[1]
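+           # Filled only on the spec-decoding (TND) path below.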
+           actual_seq_lengths = None
            if self.enable_kv_nz:
                k_nope = k_nope.view(-1, self.num_kv_heads,
                                     self.kv_lora_rank // 16, block_size, 16)
@@ -971,6 +980,26 @@ def _forward_decode(
                                 self.qk_rope_head_dim)
                input_layout = "BNSD"

+           # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+           if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+               assert num_tokens % self.spec_token_num == 0
+               # [bs * q_seq_len, num_heads_per_rank, dim]
+               input_layout = "TND"
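+               # TND flattens (bs, q_seq_len) into one token axis, so the
+               # fused op needs cumulative query lengths per request to
+               # recover each request's span.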
+               q_nope = q_nope.view(num_tokens, self.num_heads, -1)
+               q_pe = q_pe.view(num_tokens, self.num_heads, -1)
+               sparse_mode = 3
+               spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+               actual_seq_lengths = decode_meta.actual_seq_q_lens
+           else:
+               if self.enable_kv_nz:
+                   q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                   q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+               else:
+                   q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                   q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+               sparse_mode = 0
+               spec_attn_mask = None
+
            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                q_nope,
                k_nope,
@@ -988,7 +1017,7 @@ def _forward_decode(
                block_table=decode_meta.block_table,
                block_size=block_size,
                actual_seq_lengths_kv=decode_meta.seq_lens_list,
-           )
+               actual_seq_lengths=actual_seq_lengths)
        else:
            # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will
            # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become
@@ -1042,6 +1071,8 @@ def forward(
        if attn_metadata is None:
            # Profiling run.
            return output
+       # MTP models are not supported in graph mode yet
+       self.torchair_graph_enabled = self.torchair_graph_enabled and not attn_metadata.is_mtp_model
        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
        ]