Commit da2d5ac

[BUGFIX] [v0.9.1] Fix mtp with disaggregated-prefill (#1694)
### What this PR does / why we need it?
[BUGFIX] [v0.9.1] Fix MTP (multi-token prediction) with disaggregated prefill.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
1 parent 33dbe57 commit da2d5ac

File tree: 3 files changed (+14, −7 lines)

vllm_ascend/attention/mla_v1.py

Lines changed: 5 additions & 0 deletions
@@ -538,6 +538,11 @@ def build(
             actual_seq_q_lens = query_start_loc[1:].tolist(
             ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                               num_reqs_pad_size]
+            # mtp torchair + PD scenario: last element of actual_seq_q_lens must equal num_padded_token_size
+            num_padded_token_size = slot_mapping.size(0)
+            if actual_seq_q_lens[-1] != num_padded_token_size:
+                actual_seq_q_lens.append(num_padded_token_size)
+                seq_lens_list.append(0)
         else:
             seq_lens_list = seq_lens.tolist()
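Taken on its own, the new branch enforces a simple invariant: in the torchair graph with disaggregated prefill (PD), the batch metadata is padded to a fixed token count, and the cumulative `actual_seq_q_lens` must end exactly at that padded size, so a zero-length filler entry is appended when it does not. Below is a minimal standalone sketch of that invariant; the helper name and tensor shapes are illustrative assumptions, not the real metadata builder.

```python
import torch

# Hypothetical helper mirroring the hunk above; not the real builder code.
def pad_actual_seq_q_lens(actual_seq_q_lens: list,
                          seq_lens_list: list,
                          slot_mapping: torch.Tensor) -> None:
    # slot_mapping is padded to the graph's fixed token count, so its
    # length is the total padded token size the attention op will see.
    num_padded_token_size = slot_mapping.size(0)
    # actual_seq_q_lens is cumulative: its last entry must cover every
    # padded token, otherwise the kernel's bookkeeping falls short.
    if actual_seq_q_lens[-1] != num_padded_token_size:
        actual_seq_q_lens.append(num_padded_token_size)
        seq_lens_list.append(0)  # the filler "request" has no KV history

# Example: two real requests covering 4 tokens, padded to 6 slots.
seq_q, seq_lens = [2, 4], [16, 32]
pad_actual_seq_q_lens(seq_q, seq_lens, torch.zeros(6, dtype=torch.int32))
assert seq_q == [2, 4, 6] and seq_lens == [16, 32, 0]
```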

vllm_ascend/worker/model_runner_v1.py

Lines changed: 4 additions & 1 deletion
@@ -1691,6 +1691,9 @@ def _dummy_run(
                     torch._dynamo.mark_static(
                         get_forward_context().mc2_mask)
                 torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                if self.speculative_config:
+                    torch._dynamo.mark_static(
+                        attn_metadata.decode.attn_mask)
                 for kv in self.kv_caches:
                     assert isinstance(
                         kv, tuple), "kv_cache must be a tuple"
@@ -1720,7 +1723,7 @@ def _dummy_run(
                     **model_kwargs)
             if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
                 assert isinstance(self.drafter, MtpProposer)
-                self.drafter.dummy_run(num_reqs)
+                self.drafter.dummy_run(num_reqs, with_prefill=with_prefill)
         return hidden_states

     @contextmanager
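For context on the `mark_static` call this hunk adds: `torch._dynamo.mark_static` pins a tensor's dimensions as static so the compiled graph does not guard on them as dynamic shapes and trigger recompilation during warm-up. A minimal sketch of the pattern, assuming a toy attention function rather than the real torchair model:

```python
import torch

def toy_attention(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Additive mask, as in a decode-time attention kernel.
    return (scores + mask).softmax(dim=-1)

mask = torch.zeros(1, 128, 128)
torch._dynamo.mark_static(mask)  # pin every dim of `mask` as static

compiled = torch.compile(toy_attention)
out = compiled(torch.randn(1, 128, 128), mask)
```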

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 5 additions & 6 deletions
@@ -211,7 +211,8 @@ def propose(

         with set_ascend_forward_context(attn_metadata,
                                         self.vllm_config,
-                                        num_tokens=num_input_tokens):
+                                        num_tokens=num_input_tokens,
+                                        with_prefill=self.runner.with_prefill):
             with ProfileExecuteDuration().capture_async('mtp_forward'):
                 model_kwargs = {}
                 model_kwargs["attn_metadata"] = attn_metadata
@@ -305,15 +306,13 @@ def load_model(self) -> None:
                            ge_cache=False)

     @torch.inference_mode()
-    def dummy_run(
-        self,
-        num_tokens: int,
-    ) -> None:
+    def dummy_run(self, num_tokens: int, with_prefill: bool = False) -> None:
         attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
             num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
         with set_ascend_forward_context(None,
                                         self.vllm_config,
-                                        num_tokens=num_tokens):
+                                        num_tokens=num_tokens,
+                                        with_prefill=with_prefill):
             self.model(input_ids=self.input_ids[:num_tokens],
                        positions=self.positions[:num_tokens],
                        previous_hidden_states=self.hidden_states[:num_tokens],
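The signature change matters because the warm-up path has to mirror the real execution path: `set_ascend_forward_context` selects prefill- or decode-style execution, so a dummy run that silently defaults to `with_prefill=False` warms up a different graph than the one used later. A condensed sketch of the call chain after this fix, with stub classes standing in for the real runner, proposer, and forward context:

```python
from contextlib import contextmanager

@contextmanager
def forward_context_stub(num_tokens: int, with_prefill: bool):
    # Stand-in for set_ascend_forward_context: the flag decides which
    # execution path (prefill or decode) the warm-up exercises.
    print(f"warm up {num_tokens} tokens, prefill={with_prefill}")
    yield

class DrafterStub:
    def dummy_run(self, num_tokens: int, with_prefill: bool = False) -> None:
        with forward_context_stub(num_tokens, with_prefill):
            pass  # run the draft model on dummy inputs here

class RunnerStub:
    def __init__(self) -> None:
        self.drafter = DrafterStub()

    def _dummy_run(self, num_reqs: int, with_prefill: bool) -> None:
        # The fix: propagate the runner's prefill flag instead of letting
        # dummy_run fall back to its decode-only default.
        self.drafter.dummy_run(num_reqs, with_prefill=with_prefill)

RunnerStub()._dummy_run(num_reqs=4, with_prefill=True)
```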
