
Commit 942eb97

[BUGFIX] Fix MTP accuracy when temperature is not 0
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
1 parent e1d282d commit 942eb97

File tree

5 files changed: 25 additions, 17 deletions

tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 4 additions & 2 deletions
@@ -114,7 +114,8 @@ def test_mtp_torchair_correctness(
         enforce_eager=False,
         additional_config={
             "torchair_graph_config": {
-                "enabled": True
+                "enabled": True,
+                "graph_batch_size": [256]
             },
             "ascend_scheduler_config": {
                 "enabled": True
@@ -132,7 +133,8 @@ def test_mtp_torchair_correctness(
         },
         additional_config={
             "torchair_graph_config": {
-                "enabled": True
+                "enabled": True,
+                "graph_batch_size": [256]
             },
             "ascend_scheduler_config": {
                 "enabled": True

vllm_ascend/attention/mla_v1.py

Lines changed: 6 additions & 4 deletions
@@ -329,7 +329,7 @@ def build_torchair_graph_dummy(
             num_reqs, block_table)
         num_tokens = num_reqs * self.runner.decode_token_per_req
         seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
-        seq_lens_list = seq_lens.tolist()
+        seq_lens_list = [0] * num_reqs
         input_positions = torch.zeros(num_tokens,
                                       dtype=torch.int32,
                                       device=device).long()
@@ -471,7 +471,7 @@ def build(
         decode_metadata = None
         use_torchair_graph = num_token_pad_size != -1
         if self._num_decodes > 0:
-            actual_seq_q_lens = None
+            actual_seq_q_lens = query_start_loc[1:].tolist()
             max_seq_lens = seq_lens[:self._num_decodes].max().item()
             seq_lens = seq_lens[:self._num_decode_tokens]
             input_positions = input_positions[:self._num_decode_tokens]
@@ -980,11 +980,13 @@ def _forward_decode(
                               self.qk_rope_head_dim)
             input_layout = "BNSD"

-            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
+                if self.enable_kv_nz:
+                    input_layout = "TND_NTD"
+                else:
+                    input_layout = "TND"
                 # [bs * q_seq_len, num_heads_per_rank, dim]
-                input_layout = "TND"
                 q_nope = q_nope.view(num_tokens, self.num_heads, -1)
                 q_pe = q_pe.view(num_tokens, self.num_heads, -1)
                 sparse_mode = 3
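
The actual_seq_q_lens change is the substantive one in the decode metadata: instead of None, the builder now hands the kernel the cumulative per-request query offsets, which under MTP are multiples of decode_token_per_req. A small CPU-only sketch of what query_start_loc[1:].tolist() produces; the values and the way query_start_loc is built here are illustrative, not taken from the builder:

import torch

# Illustrative values: 3 decode requests, each contributing 2 query tokens
# (1 verified token + 1 MTP draft token) when speculative decoding is active.
decode_token_per_req = 2
num_reqs = 3
query_lens = torch.full((num_reqs,), decode_token_per_req, dtype=torch.int32)

query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)

# Dropping the leading zero yields the cumulative end offset of each request's
# query tokens, i.e. how to split the flattened
# [num_reqs * decode_token_per_req, ...] query back into per-request segments.
actual_seq_q_lens = query_start_loc[1:].tolist()
print(actual_seq_q_lens)  # [2, 4, 6]

The seq_lens_list change in the dummy-graph path is equivalent in value: seq_lens is an all-zero tensor there, so [0] * num_reqs builds the same list directly on the host and skips a device-to-host copy.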

vllm_ascend/sample/rejection_sampler.py

Lines changed: 1 addition & 1 deletion
@@ -432,7 +432,7 @@ def sample_recovered_tokens_pytorch(

     if IS_NGRAM:
         draft_token_id = draft_token_ids[token_idx]
-        orig_prob = target_probs[token_idx, draft_token_id]
+        orig_prob = target_probs[token_idx, draft_token_id].item()
         target_probs[token_idx, draft_token_id] = 0
         prob = target_probs[token_idx].clone()
     else:
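
This one-character-looking change is what the commit title is about: plain integer indexing in PyTorch returns a zero-dim view of the same storage rather than a copy, so the "saved" probability was clobbered by the in-place zeroing on the next line and the stale value was used afterwards. A minimal standalone repro of the aliasing, not code from the sampler itself:

import torch

probs = torch.tensor([0.1, 0.7, 0.2])

aliased = probs[1]          # zero-dim view sharing storage with `probs`
snapshot = probs[1].item()  # plain Python float, detached from the storage

probs[1] = 0                # in-place write, as the sampler does before renormalizing

print(aliased)   # tensor(0.) -- the "saved" value silently became 0
print(snapshot)  # 0.7       -- the value that is actually needed later

Recovered-token sampling is only exercised on the non-greedy rejection-sampling path, i.e. when the temperature is non-zero, which is why the accuracy regression showed up exactly there.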

vllm_ascend/worker/model_runner_v1.py

Lines changed: 12 additions & 3 deletions
@@ -1677,7 +1677,7 @@ def _dummy_run(
                     **model_kwargs)
             if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
                 assert isinstance(self.drafter, MtpProposer)
-                self.drafter.dummy_run(num_reqs, with_prefill=with_prefill)
+                self.drafter.dummy_run(num_reqs)
         return hidden_states

     @contextmanager
@@ -2140,16 +2140,25 @@ def check_torchair_graph_batch_sizes(self):
         if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
             self.torchair_graph_batch_sizes.append(self.max_num_reqs)

+        # we need to make sure that we can deal with max_num_req when `self.decode_token_per_req` is not 1
+        self.torchair_graph_batch_sizes = [
+            graph_batch_size * self.decode_token_per_req
+            for graph_batch_size in self.torchair_graph_batch_sizes
+        ]
+
         # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
         tp_size = self.parallel_config.tensor_parallel_size
         if self.parallel_config.enable_expert_parallel:
             new_graph_batch_sizes = []
             for graph_batch_size in self.torchair_graph_batch_sizes:
                 cur_graph_batch_size = (graph_batch_size + tp_size -
                                         1) // tp_size * tp_size
-                # `graph_batch_size` need to be divisible by `self.decode_token_per_req`
-                cur_graph_batch_size = cur_graph_batch_size * self.decode_token_per_req
                 if cur_graph_batch_size not in new_graph_batch_sizes and \
                     cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                     new_graph_batch_sizes.append(cur_graph_batch_size)
+                elif cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
+                    logger.warning(
+                        f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
+                        f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
+                    )
             self.torchair_graph_batch_sizes = new_graph_batch_sizes
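
The reordering scales the captured graph batch sizes by decode_token_per_req once, up front, instead of only inside the expert-parallel loop, so non-EP configurations also capture graphs large enough for MTP decode tokens. A small worked example of the resulting sizes; all numbers are illustrative, not defaults from the repo:

# Illustrative values only.
decode_token_per_req = 2        # e.g. 1 verified token + 1 MTP draft token per request
tp_size = 4
max_num_batched_tokens = 512
torchair_graph_batch_sizes = [1, 8, 24, 256]

# Step 1 (now applies to every configuration): scale by tokens per decode request.
scaled = [b * decode_token_per_req for b in torchair_graph_batch_sizes]  # [2, 16, 48, 512]

# Step 2 (expert parallelism only): round up to a multiple of tp_size, then drop
# duplicates and sizes above max_num_batched_tokens.
final = []
for b in scaled:
    b = (b + tp_size - 1) // tp_size * tp_size
    if b not in final and b <= max_num_batched_tokens:
        final.append(b)

print(final)  # [4, 16, 48, 512]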

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 2 additions & 7 deletions
@@ -307,14 +307,9 @@ def load_model(self) -> None:
     def dummy_run(
             self,
             num_tokens: int,
-            with_prefill: bool = False,
     ) -> None:
-        if self.runner.torchair_graph_enabled and not with_prefill:
-            attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
-        else:
-            attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
+        attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
+            num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
         with set_ascend_forward_context(None,
                                         self.vllm_config,
                                         num_tokens=num_tokens):
