[BUGFIX] FIX mtp accuracy when temperature is not 0 #1632

Merged
1 commit merged on Jul 8, 2025
6 changes: 4 additions & 2 deletions tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py
@@ -114,7 +114,8 @@ def test_mtp_torchair_correctness(
enforce_eager=False,
additional_config={
"torchair_graph_config": {
"enabled": True
"enabled": True,
"graph_batch_size": [256]
},
"ascend_scheduler_config": {
"enabled": True
@@ -132,7 +133,8 @@ def test_mtp_torchair_correctness(
},
additional_config={
"torchair_graph_config": {
"enabled": True
"enabled": True,
"graph_batch_size": [256]
},
"ascend_scheduler_config": {
"enabled": True
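Note: the test now pins `graph_batch_size` to `[256]` inside `torchair_graph_config`. For readers who want to reproduce the setup outside the test harness, here is a minimal offline sketch, not code from this PR: the model name, prompt, and sampling values are placeholders, and `speculative_config`/`additional_config` are passed the same way the test does.

```python
# Minimal reproduction sketch (placeholders, not PR code).
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder MTP-capable checkpoint
    enforce_eager=False,
    speculative_config={
        "method": "deepseek_mtp",      # matches the check in model_runner_v1.py
        "num_speculative_tokens": 1,
    },
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "graph_batch_size": [256],
        },
        "ascend_scheduler_config": {
            "enabled": True,
        },
    },
)

# A non-zero temperature exercises the rejection-sampler path this PR fixes.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=1.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```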
10 changes: 6 additions & 4 deletions vllm_ascend/attention/mla_v1.py
@@ -329,7 +329,7 @@ def build_torchair_graph_dummy(
num_reqs, block_table)
num_tokens = num_reqs * self.runner.decode_token_per_req
seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
seq_lens_list = seq_lens.tolist()
seq_lens_list = [0] * num_reqs
input_positions = torch.zeros(num_tokens,
dtype=torch.int32,
device=device).long()
@@ -471,7 +471,7 @@ def build(
decode_metadata = None
use_torchair_graph = num_token_pad_size != -1
if self._num_decodes > 0:
actual_seq_q_lens = None
actual_seq_q_lens = query_start_loc[1:].tolist()
max_seq_lens = seq_lens[:self._num_decodes].max().item()
seq_lens = seq_lens[:self._num_decode_tokens]
input_positions = input_positions[:self._num_decode_tokens]
@@ -980,11 +980,13 @@ def _forward_decode(
self.qk_rope_head_dim)
input_layout = "BNSD"

# TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
if self.enable_kv_nz:
input_layout = "TND_NTD"
else:
input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim]
input_layout = "TND"
q_nope = q_nope.view(num_tokens, self.num_heads, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
sparse_mode = 3
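Note: the substantive change in `build` is that the decode path now fills `actual_seq_q_lens` from `query_start_loc` instead of leaving it `None`, so the SpecDecoding attention path sees the real cumulative query lengths when each decode request carries more than one token under MTP. A standalone illustration with made-up sizes (not runner state; `decode_token_per_req = 2` assumes one decode token plus one draft token):

```python
# Illustration only: what `query_start_loc[1:].tolist()` yields for the decode path.
import itertools
import torch

decode_token_per_req = 2                       # assumed: 1 decode + 1 MTP draft token
num_reqs = 4

# Each decode request contributes decode_token_per_req query tokens.
query_lens = [decode_token_per_req] * num_reqs                  # [2, 2, 2, 2]
query_start_loc = torch.tensor([0] + list(itertools.accumulate(query_lens)),
                               dtype=torch.int32)               # [0, 2, 4, 6, 8]

# Dropping the leading zero gives the cumulative per-request query lengths.
actual_seq_q_lens = query_start_loc[1:].tolist()
print(actual_seq_q_lens)                                        # [2, 4, 6, 8]
```

Similarly, `build_torchair_graph_dummy` now builds `seq_lens_list` as `[0] * num_reqs` directly rather than round-tripping through `seq_lens.tolist()`, which produces the same zeros without a device-to-host copy.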
2 changes: 1 addition & 1 deletion vllm_ascend/sample/rejection_sampler.py
@@ -432,7 +432,7 @@ def sample_recovered_tokens_pytorch(

if IS_NGRAM:
draft_token_id = draft_token_ids[token_idx]
orig_prob = target_probs[token_idx, draft_token_id]
orig_prob = target_probs[token_idx, draft_token_id].item()
target_probs[token_idx, draft_token_id] = 0
prob = target_probs[token_idx].clone()
else:
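Note on the `.item()` fix: with plain indexing, `target_probs[token_idx, draft_token_id]` is a 0-dim tensor that shares storage with `target_probs`, so the in-place write on the next line also zeroes the saved value, and whatever later consumes `orig_prob` sees 0 instead of the true probability. Calling `.item()` snapshots the value as a Python float first. A self-contained demonstration of the aliasing (toy numbers, not PR code):

```python
# Toy demonstration of why the snapshot matters.
import torch

target_probs = torch.tensor([[0.1, 0.7, 0.2]])
token_idx, draft_token_id = 0, 1

aliased = target_probs[token_idx, draft_token_id]          # 0-dim view into target_probs
snapshot = target_probs[token_idx, draft_token_id].item()  # plain Python float copy

target_probs[token_idx, draft_token_id] = 0                # in-place zeroing

print(aliased)   # tensor(0.) -- the "saved" probability was silently clobbered
print(snapshot)  # 0.7        -- survives and can be used afterwards
```

Recovered-token sampling is only reached on the random-sampling path, i.e. with non-zero temperature, which is consistent with the accuracy issue in the PR title.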
16 changes: 13 additions & 3 deletions vllm_ascend/worker/model_runner_v1.py
@@ -1677,7 +1677,7 @@ def _dummy_run(
**model_kwargs)
if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
assert isinstance(self.drafter, MtpProposer)
self.drafter.dummy_run(num_reqs, with_prefill=with_prefill)
self.drafter.dummy_run(num_reqs)
return hidden_states

@contextmanager
@@ -2140,16 +2140,26 @@ def check_torchair_graph_batch_sizes(self):
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
self.torchair_graph_batch_sizes.append(self.max_num_reqs)

# we need to make sure that we can deal with max_num_req when `self.decode_token_per_req` is not 1
self.torchair_graph_batch_sizes = [
graph_batch_size * self.decode_token_per_req
for graph_batch_size in self.torchair_graph_batch_sizes
]

# NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
tp_size = self.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel:
new_graph_batch_sizes = []
for graph_batch_size in self.torchair_graph_batch_sizes:
cur_graph_batch_size = (graph_batch_size + tp_size -
1) // tp_size * tp_size
# `graph_batch_size` need to be divisible by `self.decode_token_per_req`
cur_graph_batch_size = cur_graph_batch_size * self.decode_token_per_req
if cur_graph_batch_size not in new_graph_batch_sizes and \
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
new_graph_batch_sizes.append(cur_graph_batch_size)
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
and self.decode_token_per_req > 1:
logger.warning(
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
)
self.torchair_graph_batch_sizes = new_graph_batch_sizes
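Note: a standalone worked example (made-up numbers, not the runner's real config) of the two transformations visible in the new code: scaling request-level graph batch sizes by `decode_token_per_req`, and, under expert parallelism, rounding each size up to a multiple of `tp_size` while skipping anything above `max_num_batched_tokens`.

```python
# Made-up configuration for illustration only.
decode_token_per_req = 2          # e.g. 1 decode token + 1 MTP speculative token
tp_size = 4
max_num_batched_tokens = 256

graph_batch_sizes = [8, 16, 120, 160]                       # in requests
graph_batch_sizes = [b * decode_token_per_req               # in tokens: [16, 32, 240, 320]
                     for b in graph_batch_sizes]

padded = []
for size in graph_batch_sizes:
    size = (size + tp_size - 1) // tp_size * tp_size        # round up to a tp_size multiple
    if size not in padded and size <= max_num_batched_tokens:
        padded.append(size)
    elif size > max_num_batched_tokens:
        print(f"skipping graph batch size {size}: exceeds "
              f"max_num_batched_tokens={max_num_batched_tokens}")

print(padded)                                               # [16, 32, 240]
```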
9 changes: 2 additions & 7 deletions vllm_ascend/worker/mtp_proposer_v1.py
@@ -307,14 +307,9 @@ def load_model(self) -> None:
def dummy_run(
self,
num_tokens: int,
with_prefill: bool = False,
) -> None:
if self.runner.torchair_graph_enabled and not with_prefill:
attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
else:
attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True)
with set_ascend_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):