Commit e99d232

[BUGFIX] Fix MTP accuracy when temperature is not 0 (#1632)
### What this PR does / why we need it?
1. [BUGFIX] Fix MTP accuracy when temperature is not 0
2. [BUGFIX] Fix MTP when multi DP is enabled

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
vllm-ascend/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
1 parent ea3dc31 commit e99d232

File tree: 5 files changed, +26 −17 lines

tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 4 additions & 2 deletions

@@ -114,7 +114,8 @@ def test_mtp_torchair_correctness(
         enforce_eager=False,
         additional_config={
             "torchair_graph_config": {
-                "enabled": True
+                "enabled": True,
+                "graph_batch_size": [256]
             },
             "ascend_scheduler_config": {
                 "enabled": True
@@ -132,7 +133,8 @@ def test_mtp_torchair_correctness(
         },
         additional_config={
             "torchair_graph_config": {
-                "enabled": True
+                "enabled": True,
+                "graph_batch_size": [256]
             },
             "ascend_scheduler_config": {
                 "enabled": True

vllm_ascend/attention/mla_v1.py

Lines changed: 6 additions & 4 deletions

@@ -324,7 +324,7 @@ def build_torchair_graph_dummy(
                                   num_reqs, block_table)
         num_tokens = num_reqs * self.runner.decode_token_per_req
         seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
-        seq_lens_list = seq_lens.tolist()
+        seq_lens_list = [0] * num_reqs
         input_positions = torch.zeros(num_tokens,
                                       dtype=torch.int32,
                                       device=device).long()
@@ -497,7 +497,7 @@ def build(
         decode_metadata = None
         use_torchair_graph = num_token_pad_size != -1
         if self._num_decodes > 0:
-            actual_seq_q_lens = None
+            actual_seq_q_lens = query_start_loc[1:].tolist()
             max_seq_lens = seq_lens[:self._num_decodes].max().item()
             seq_lens = seq_lens[:self._num_decode_tokens]
             input_positions = input_positions[:self._num_decode_tokens]
@@ -1014,11 +1014,13 @@ def _forward_decode(
                                  self.qk_rope_head_dim)
             input_layout = "BNSD"

-            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
+                if self.enable_kv_nz:
+                    input_layout = "TND_NTD"
+                else:
+                    input_layout = "TND"
                 # [bs * q_seq_len, num_heads_per_rank, dim]
-                input_layout = "TND"
                 q_nope = q_nope.view(num_tokens, self.num_heads, -1)
                 q_pe = q_pe.view(num_tokens, self.num_heads, -1)
                 sparse_mode = 3

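A note on the `build` change above: in the vLLM v1 attention metadata, `query_start_loc` holds cumulative query-token offsets per request, so `query_start_loc[1:]` is the list of running end offsets. With MTP speculative decoding, each decode request contributes more than one query token, and a concrete `actual_seq_q_lens` (instead of `None`) tells the TND-layout kernel how to split the flattened tokens back into per-request segments. A minimal standalone sketch under that reading (illustrative only, not the vllm-ascend code):

import torch

# Hypothetical example: 3 decode requests, each carrying 1 target token + 1 MTP draft token.
decode_token_per_req = 2
num_reqs = 3

# query_start_loc is the cumulative sum of per-request query lengths: [0, 2, 4, 6]
query_lens = torch.full((num_reqs,), decode_token_per_req, dtype=torch.int32)
query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)

# The fix replaces `actual_seq_q_lens = None` with the cumulative end offsets,
# i.e. where each request's query segment ends in the flattened token stream.
actual_seq_q_lens = query_start_loc[1:].tolist()
print(actual_seq_q_lens)  # [2, 4, 6]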
vllm_ascend/sample/rejection_sampler.py

Lines changed: 1 addition & 1 deletion

@@ -432,7 +432,7 @@ def sample_recovered_tokens_pytorch(

     if IS_NGRAM:
         draft_token_id = draft_token_ids[token_idx]
-        orig_prob = target_probs[token_idx, draft_token_id]
+        orig_prob = target_probs[token_idx, draft_token_id].item()
         target_probs[token_idx, draft_token_id] = 0
         prob = target_probs[token_idx].clone()
     else:

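On the `.item()` change: the sampler saves the draft token's probability and then zeroes that entry of `target_probs` in place. Reading the value with `.item()` snapshots it as a plain Python float, so the subsequent in-place write cannot change the saved value if the indexed read would otherwise alias the tensor's storage. A minimal standalone illustration of that pitfall (assumed rationale; this snippet is not the vLLM code):

import torch

target_probs = torch.tensor([[0.1, 0.7, 0.2]])
token_idx, draft_token_id = 0, 1

# Without .item(): basic indexing can return a 0-dim tensor that shares storage
# with target_probs, so zeroing the entry below also changes the "saved" value.
orig_prob_view = target_probs[token_idx][draft_token_id]
# With .item(): the saved value is a detached Python float.
orig_prob_value = target_probs[token_idx, draft_token_id].item()

target_probs[token_idx, draft_token_id] = 0

print(float(orig_prob_view))  # 0.0 -- the saved probability was clobbered
print(orig_prob_value)        # 0.7 -- the snapshot is unaffected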
vllm_ascend/worker/model_runner_v1.py

Lines changed: 13 additions & 3 deletions

@@ -1698,7 +1698,7 @@ def _dummy_run(
                     **model_kwargs)
         if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
             assert isinstance(self.drafter, MtpProposer)
-            self.drafter.dummy_run(num_reqs, with_prefill=with_prefill)
+            self.drafter.dummy_run(num_reqs)
         return hidden_states

     @contextmanager
@@ -2163,16 +2163,26 @@ def check_torchair_graph_batch_sizes(self):
         if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
             self.torchair_graph_batch_sizes.append(self.max_num_reqs)

+        # we need to make sure that we can deal with max_num_req when `self.decode_token_per_req` is not 1
+        self.torchair_graph_batch_sizes = [
+            graph_batch_size * self.decode_token_per_req
+            for graph_batch_size in self.torchair_graph_batch_sizes
+        ]
+
         # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
         tp_size = self.parallel_config.tensor_parallel_size
         if self.parallel_config.enable_expert_parallel:
             new_graph_batch_sizes = []
             for graph_batch_size in self.torchair_graph_batch_sizes:
                 cur_graph_batch_size = (graph_batch_size + tp_size -
                                         1) // tp_size * tp_size
-                # `graph_batch_size` need to be divisible by `self.decode_token_per_req`
-                cur_graph_batch_size = cur_graph_batch_size * self.decode_token_per_req
                 if cur_graph_batch_size not in new_graph_batch_sizes and \
                     cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                     new_graph_batch_sizes.append(cur_graph_batch_size)
+                elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
+                        and self.decode_token_per_req > 1:
+                    logger.warning(
+                        f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
+                        f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
+                    )
             self.torchair_graph_batch_sizes = new_graph_batch_sizes

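On the reworked `check_torchair_graph_batch_sizes`: the scaling by `decode_token_per_req` now happens once, up front, for every captured size (so MTP's extra decode tokens per request still fit a captured graph), and the expert-parallel pass only rounds each size up to a multiple of `tp_size`, dropping (with a warning) any size that exceeds `max_num_batched_tokens`. A simplified standalone sketch of that ordering (illustrative only, not the runner code):

def adjust_graph_batch_sizes(batch_sizes, decode_token_per_req, tp_size,
                             max_num_batched_tokens, enable_expert_parallel):
    # Scale first: with MTP, each decode request contributes
    # decode_token_per_req tokens, so the captured graph must cover them all.
    sizes = [b * decode_token_per_req for b in batch_sizes]
    if not enable_expert_parallel:
        return sizes
    adjusted = []
    for size in sizes:
        # Round up to a multiple of tp_size so expert-parallel shapes line up.
        size = (size + tp_size - 1) // tp_size * tp_size
        # Sizes above max_num_batched_tokens are skipped (the runner logs a warning).
        if size <= max_num_batched_tokens and size not in adjusted:
            adjusted.append(size)
    return adjusted

# Example: MTP with 2 tokens per decode request, TP=4, token budget 512.
print(adjust_graph_batch_sizes([8, 64, 256], 2, 4, 512, True))  # [16, 128, 512]
print(adjust_graph_batch_sizes([8, 64, 512], 2, 4, 512, True))  # [16, 128] -- 1024 skipped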
vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 2 additions & 7 deletions

@@ -308,14 +308,9 @@ def load_model(self) -> None:
     def dummy_run(
         self,
         num_tokens: int,
-        with_prefill: bool = False,
     ) -> None:
-        if self.runner.torchair_graph_enabled and not with_prefill:
-            attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
-        else:
-            attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
+        attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
+            num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
         with set_ascend_forward_context(None,
                                         self.vllm_config,
                                         num_tokens=num_tokens):
