3 files changed, +14 −7 lines changed

File 1 of 3

```diff
@@ -538,6 +538,11 @@ def build(
             actual_seq_q_lens = query_start_loc[1:].tolist(
             ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                               num_reqs_pad_size]
+            # MTP torchair + PD scenario: the last element of
+            # actual_seq_q_lens must equal the padded token count.
+            num_padded_token_size = slot_mapping.size(0)
+            if actual_seq_q_lens[-1] != num_padded_token_size:
+                actual_seq_q_lens.append(num_padded_token_size)
+                seq_lens_list.append(0)
         else:
             seq_lens_list = seq_lens.tolist()
```
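The added branch encodes a simple invariant: after graph-mode padding, the cumulative query-length list must end exactly at the padded token count, with any padding tokens absorbed by one extra zero-length entry. A minimal standalone sketch of that logic with illustrative numbers (only the variable names mirror the diff; in the real code `num_padded_token_size` comes from `slot_mapping.size(0)`):

```python
actual_seq_q_lens = [4, 8, 12]   # cumulative query token counts per request
seq_lens_list = [16, 16, 16]     # KV sequence length per request
num_padded_token_size = 16       # total tokens after graph padding

# Padding added tokens past the last real request: cover them with one
# extra entry of zero KV length so the two lists stay aligned.
if actual_seq_q_lens[-1] != num_padded_token_size:
    actual_seq_q_lens.append(num_padded_token_size)
    seq_lens_list.append(0)

assert actual_seq_q_lens[-1] == num_padded_token_size
print(actual_seq_q_lens)  # [4, 8, 12, 16]
print(seq_lens_list)      # [16, 16, 16, 0]
```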
File 2 of 3

```diff
@@ -1691,6 +1691,9 @@ def _dummy_run(
                     torch._dynamo.mark_static(
                         get_forward_context().mc2_mask)
                     torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                    if self.speculative_config:
+                        torch._dynamo.mark_static(
+                            attn_metadata.decode.attn_mask)
                     for kv in self.kv_caches:
                         assert isinstance(
                             kv, tuple), "kv_cache must be a tuple"
```
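`torch._dynamo.mark_static` tells the compiler to treat a tensor's dimensions as static rather than dynamic, which keeps a captured graph stable across runs; the new guard is needed because `attn_metadata.decode.attn_mask` only exists when speculative decoding is configured. A rough illustration of the call pattern using plain `torch.compile` (the diff itself targets torchair graph mode on Ascend; the shapes here are made up):

```python
import torch

# Mark the mask's shape static before compilation so torch.compile does
# not install dynamic-shape guards for it.
attn_mask = torch.zeros(4, 1, 128, dtype=torch.bool)
torch._dynamo.mark_static(attn_mask)

@torch.compile
def apply_mask(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Masked positions get -inf so softmax assigns them zero probability.
    return scores.masked_fill(mask, float("-inf")).softmax(dim=-1)

scores = torch.randn(4, 1, 128)
print(apply_mask(scores, attn_mask).shape)  # torch.Size([4, 1, 128])
```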
```diff
@@ -1720,7 +1723,7 @@ def _dummy_run(
                         **model_kwargs)
             if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
                 assert isinstance(self.drafter, MtpProposer)
-                self.drafter.dummy_run(num_reqs)
+                self.drafter.dummy_run(num_reqs, with_prefill=with_prefill)
         return hidden_states

     @contextmanager
```
File 3 of 3

```diff
@@ -211,7 +211,8 @@ def propose(

         with set_ascend_forward_context(attn_metadata,
                                         self.vllm_config,
-                                        num_tokens=num_input_tokens):
+                                        num_tokens=num_input_tokens,
+                                        with_prefill=self.runner.with_prefill):
             with ProfileExecuteDuration().capture_async('mtp_forward'):
                 model_kwargs = {}
                 model_kwargs["attn_metadata"] = attn_metadata
```
```diff
@@ -305,15 +306,13 @@ def load_model(self) -> None:
             ge_cache=False)

     @torch.inference_mode()
-    def dummy_run(
-        self,
-        num_tokens: int,
-    ) -> None:
+    def dummy_run(self, num_tokens: int, with_prefill: bool = False) -> None:
         attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
             num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
         with set_ascend_forward_context(None,
                                         self.vllm_config,
-                                        num_tokens=num_tokens):
+                                        num_tokens=num_tokens,
+                                        with_prefill=with_prefill):
             self.model(input_ids=self.input_ids[:num_tokens],
                        positions=self.positions[:num_tokens],
                        previous_hidden_states=self.hidden_states[:num_tokens],
```
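Giving the new parameter a `False` default keeps the signature backward compatible: callers that still invoke `dummy_run(num_tokens)` get the old decode-only behavior, while the updated `_dummy_run` threads the flag through explicitly. A toy sketch (class and body are hypothetical):

```python
class Drafter:
    def dummy_run(self, num_tokens: int, with_prefill: bool = False) -> None:
        # The default keeps pre-existing call sites on the decode path.
        mode = "prefill" if with_prefill else "decode"
        print(f"dummy run: {num_tokens} tokens ({mode} context)")

drafter = Drafter()
drafter.dummy_run(8)                     # legacy call site
drafter.dummy_run(8, with_prefill=True)  # updated call site
```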