
Commit abc0c86

JC-ut0wanghanqingLYT authored and committed
[BUGFIX] Fix MTP accuracy when temperature is not 0 (vllm-project#1632)
### What this PR does / why we need it?
1. [BUGFIX] Fix MTP accuracy when temperature is not 0
2. [BUGFIX] Fix MTP when multi DP is enabled

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
vllm-ascend/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
1 parent f392ebf commit abc0c86

File tree: 5 files changed (+26, -17 lines)

tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py (4 additions, 2 deletions)

@@ -114,7 +114,8 @@ def test_mtp_torchair_correctness(
         enforce_eager=False,
         additional_config={
             "torchair_graph_config": {
-                "enabled": True
+                "enabled": True,
+                "graph_batch_size": [256]
             },
             "ascend_scheduler_config": {
                 "enabled": True
@@ -132,7 +133,8 @@ def test_mtp_torchair_correctness(
         },
         additional_config={
             "torchair_graph_config": {
-                "enabled": True
+                "enabled": True,
+                "graph_batch_size": [256]
             },
             "ascend_scheduler_config": {
                 "enabled": True

vllm_ascend/attention/mla_v1.py (6 additions, 4 deletions)

@@ -324,7 +324,7 @@ def build_torchair_graph_dummy(
                                               num_reqs, block_table)
         num_tokens = num_reqs * self.runner.decode_token_per_req
         seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
-        seq_lens_list = seq_lens.tolist()
+        seq_lens_list = [0] * num_reqs
         input_positions = torch.zeros(num_tokens,
                                       dtype=torch.int32,
                                       device=device).long()
@@ -497,7 +497,7 @@ def build(
         decode_metadata = None
         use_torchair_graph = num_token_pad_size != -1
         if self._num_decodes > 0:
-            actual_seq_q_lens = None
+            actual_seq_q_lens = query_start_loc[1:].tolist()
             max_seq_lens = seq_lens[:self._num_decodes].max().item()
             seq_lens = seq_lens[:self._num_decode_tokens]
             input_positions = input_positions[:self._num_decode_tokens]
@@ -1014,11 +1014,13 @@ def _forward_decode(
                                   self.qk_rope_head_dim)
             input_layout = "BNSD"

-            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
+                if self.enable_kv_nz:
+                    input_layout = "TND_NTD"
+                else:
+                    input_layout = "TND"
                 # [bs * q_seq_len, num_heads_per_rank, dim]
-                input_layout = "TND"
                 q_nope = q_nope.view(num_tokens, self.num_heads, -1)
                 q_pe = q_pe.view(num_tokens, self.num_heads, -1)
             sparse_mode = 3
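The `actual_seq_q_lens` change matters once MTP makes each decode request carry more than one query token: `query_start_loc` already holds the cumulative query-token offsets per request, so its tail appears to be exactly the per-request boundary information the spec-decode (TND) layout needs, whereas the previous `None` left it unset. A toy standalone reconstruction with assumed values, not repo code:

    import torch

    decode_token_per_req = 2   # e.g. one target token plus one MTP draft token
    num_reqs = 3

    # Cumulative query-token offsets per request: [0, 2, 4, 6]
    query_start_loc = torch.arange(num_reqs + 1, dtype=torch.int32) * decode_token_per_req

    # Tail of the prefix sums, consumed by the spec-decode attention path.
    actual_seq_q_lens = query_start_loc[1:].tolist()
    print(actual_seq_q_lens)   # [2, 4, 6]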

vllm_ascend/sample/rejection_sampler.py (1 addition, 1 deletion)

@@ -432,7 +432,7 @@ def sample_recovered_tokens_pytorch(

     if IS_NGRAM:
         draft_token_id = draft_token_ids[token_idx]
-        orig_prob = target_probs[token_idx, draft_token_id]
+        orig_prob = target_probs[token_idx, draft_token_id].item()
         target_probs[token_idx, draft_token_id] = 0
         prob = target_probs[token_idx].clone()
     else:
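The `.item()` call looks cosmetic but is the accuracy fix: basic integer indexing on a PyTorch tensor returns a 0-dim view that shares storage, so the subsequent in-place zeroing also wiped the saved probability, which presumably skewed recovered-token sampling whenever temperature is non-zero. A standalone repro of the aliasing, not repo code:

    import torch

    target_probs = torch.tensor([[0.1, 0.6, 0.3]])
    token_idx, draft_token_id = 0, 1

    aliased = target_probs[token_idx, draft_token_id]          # 0-dim view, shares storage
    snapshot = target_probs[token_idx, draft_token_id].item()  # plain Python float

    target_probs[token_idx, draft_token_id] = 0                # the in-place zeroing
    print(aliased.item(), snapshot)                            # 0.0 vs 0.6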

vllm_ascend/worker/model_runner_v1.py (13 additions, 3 deletions)

@@ -1728,7 +1728,7 @@ def _dummy_run(
             self.eplb_updator.forward_end()
         if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
             assert isinstance(self.drafter, MtpProposer)
-            self.drafter.dummy_run(num_reqs, with_prefill=with_prefill)
+            self.drafter.dummy_run(num_reqs)
         return hidden_states

     @contextmanager
@@ -2207,16 +2207,26 @@ def check_torchair_graph_batch_sizes(self):
         if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
             self.torchair_graph_batch_sizes.append(self.max_num_reqs)

+        # we need to make sure that we can deal with max_num_req when `self.decode_token_per_req` is not 1
+        self.torchair_graph_batch_sizes = [
+            graph_batch_size * self.decode_token_per_req
+            for graph_batch_size in self.torchair_graph_batch_sizes
+        ]
+
         # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
         tp_size = self.parallel_config.tensor_parallel_size
         if self.parallel_config.enable_expert_parallel:
             new_graph_batch_sizes = []
             for graph_batch_size in self.torchair_graph_batch_sizes:
                 cur_graph_batch_size = (graph_batch_size + tp_size -
                                         1) // tp_size * tp_size
-                # `graph_batch_size` need to be divisible by `self.decode_token_per_req`
-                cur_graph_batch_size = cur_graph_batch_size * self.decode_token_per_req
                 if cur_graph_batch_size not in new_graph_batch_sizes and \
                     cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                     new_graph_batch_sizes.append(cur_graph_batch_size)
+                elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
+                        and self.decode_token_per_req > 1:
+                    logger.warning(
+                        f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
+                        f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
+                    )
             self.torchair_graph_batch_sizes = new_graph_batch_sizes
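Scaling the captured batch sizes by `decode_token_per_req` up front, instead of after the expert-parallel rounding as before, keeps every size aligned with both constraints and lets oversized entries be skipped. A simplified standalone sketch of the resulting computation, with an assumed helper name and without the warning and max_num_reqs details of the real method:

    def adjust_graph_batch_sizes(sizes, decode_token_per_req, tp_size,
                                 max_num_batched_tokens):
        # Scale first so every captured graph covers whole requests under MTP.
        sizes = [s * decode_token_per_req for s in sizes]
        kept = []
        for s in sizes:
            # Round up to a tp_size multiple for expert parallelism.
            s = (s + tp_size - 1) // tp_size * tp_size
            if s <= max_num_batched_tokens and s not in kept:
                kept.append(s)
        return kept

    # MTP with 2 decode tokens per request, tp_size=4, 2048-token budget:
    print(adjust_graph_batch_sizes([1, 8, 256, 1024], 2, 4, 2048))
    # -> [4, 16, 512, 2048]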

vllm_ascend/worker/mtp_proposer_v1.py (2 additions, 7 deletions)

@@ -308,14 +308,9 @@ def load_model(self) -> None:
     def dummy_run(
         self,
         num_tokens: int,
-        with_prefill: bool = False,
     ) -> None:
-        if self.runner.torchair_graph_enabled and not with_prefill:
-            attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
-        else:
-            attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
+        attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
+            num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
         with set_ascend_forward_context(None,
                                         self.vllm_config,
                                         num_tokens=num_tokens):
