                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
@@ -86,6 +87,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
+    attn_mask: Optional[torch.Tensor] = None


 @dataclass
@@ -169,6 +171,8 @@ def __init__(self,
         self.runner = runner
         scheduler_config = runner.scheduler_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -185,16 +189,24 @@ def reorder_batch(self, input_batch: "InputBatch",

         for i, req_id in enumerate(input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            # for now treat 1 scheduled token as "decode" even if its not,
-            # we should update this to something like < 8 in the future but
-            # currently the TritonMLA._forward_decode only supports
-            # num_tokens = 1
-            if num_tokens == 1:
-                decodes.append(i)
-                num_decode_tokens += num_tokens
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            # For torch air graph mode we treat spec decoding as decode.
+            if self.torchair_graph_enabled:
+                if num_tokens - num_spec_tokens == 1:
+                    decodes.append(i)
+                    num_decode_tokens += num_tokens
+                else:
+                    prefills.append(i)
+                    num_prefill_tokens += num_tokens
+            # For eager mode we treat spec decoding as chunked prefill.
             else:
-                prefills.append(i)
-                num_prefill_tokens += num_tokens
+                if num_tokens == 1:
+                    decodes.append(i)
+                    num_decode_tokens += num_tokens
+                else:
+                    prefills.append(i)
+                    num_prefill_tokens += num_tokens

         # We hope that this is fairly minimal since decodes
         # should be around for a number of iterations so hopefully they are
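A minimal sketch of the routing rule this hunk introduces, assuming the same meaning of num_tokens (all scheduled tokens for the request) and num_spec_tokens (scheduled draft tokens); the helper name is hypothetical and not part of the change:

    # Hypothetical helper mirroring the reorder_batch branch above.
    def is_decode(num_tokens: int, num_spec_tokens: int,
                  torchair_graph_enabled: bool) -> bool:
        if torchair_graph_enabled:
            # Graph mode: a spec-decode step (one verified token plus N drafts)
            # still goes down the decode path.
            return num_tokens - num_spec_tokens == 1
        # Eager mode: spec decoding falls through to chunked prefill.
        return num_tokens == 1

    # Example: one verified token plus two draft tokens scheduled.
    assert is_decode(3, 2, torchair_graph_enabled=True)
    assert not is_decode(3, 2, torchair_graph_enabled=False)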
@@ -284,7 +296,8 @@ def build_dummy(self, num_reqs: int,
             block_table=block_table,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
-            max_seq_lens=1)
+            max_seq_lens=1,
+            attn_mask=self.runner.spec_attn_mask)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -332,7 +345,7 @@ def build(
         seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
-        query_start_loc = None
+        query_start_loc = common_attn_metadata.query_start_loc

         prefill_metadata = None
         if self._num_prefills > 0:
@@ -397,7 +410,8 @@ def build(
                 block_table=block_table,
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
-                max_seq_lens=max_seq_lens)
+                max_seq_lens=max_seq_lens,
+                attn_mask=self.runner.spec_attn_mask)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -461,6 +475,11 @@ def __init__(

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # Adapt torch air graph mode with spec decoding.
+        speculative_config = get_current_vllm_config().speculative_config
+        if speculative_config is not None:
+            self.spec_token_num = speculative_config.num_speculative_tokens
+            assert self.spec_token_num > 0

     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
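For context, num_speculative_tokens comes from vLLM's speculative-decoding config; with drafts enabled, each graph-mode decode step carries spec_token_num + 1 query tokens per request (the draft tokens plus one verified token), which is the per-request sequence length assumed by the reshape in _forward_decode below. A small illustration under that assumption:

    # Illustration only (values are made up), showing the relation the decode
    # path relies on: num_tokens = batch_size * (spec_token_num + 1).
    spec_token_num = 2                  # speculative_config.num_speculative_tokens
    q_seq_len = spec_token_num + 1      # drafts + one verified token
    batch_size = 4
    num_tokens = batch_size * q_seq_len
    assert num_tokens // q_seq_len == batch_size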
@@ -550,7 +569,10 @@ def _forward_prefill(
         num_tokens = query.size(0)
         attn_output = None
         # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
-        if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding
+        ]:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads * self.v_head_dim,
                                       dtype=query.dtype,
@@ -597,7 +619,7 @@ def _forward_prefill(
             attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
+                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
             )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
@@ -670,9 +692,28 @@ def _forward_decode(
                                       dtype=q.dtype,
                                       device=q.device)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+                assert num_tokens % self.spec_token_num == 0
+                q_nope = (q_nope.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+                q_pe = (q_pe.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+                sparse_mode = 3
+                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                sparse_mode = 0
+                spec_attn_mask = None
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
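The view/transpose above converts the flat query of shape [num_tokens, num_heads, dim] into TorchAir's [bs, num_heads_per_rank, q_seq_len, dim] layout, with q_seq_len = spec_token_num + 1. A standalone sketch of the same transformation (tensor sizes are illustrative, not taken from the change):

    import torch

    spec_token_num, num_heads, dim = 2, 8, 64
    num_tokens = 4 * (spec_token_num + 1)    # batch of 4 spec-decode requests
    q_nope = torch.randn(num_tokens, num_heads, dim)

    # [num_tokens, H, D] -> [B, S, H, D] -> [B, H, S, D], S = spec_token_num + 1
    q_bnsd = (q_nope.view(num_tokens // (spec_token_num + 1),
                          spec_token_num + 1, num_heads, dim)
              .transpose(1, 2).contiguous())
    assert q_bnsd.shape == (4, num_heads, spec_token_num + 1, dim)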
@@ -690,7 +731,8 @@ def _forward_decode(
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.attn_mask,
+                atten_mask=spec_attn_mask,
+                sparse_mode=sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,
@@ -732,7 +774,9 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
-        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
+        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(