
Commit 10df64c

[0.9.1][bugfix] fix chunked_prefill_mla input for MTP (#1473)
### What this PR does / why we need it?
fix chunked_prefill_mla output for MTP

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

---------

Signed-off-by: underfituu <hzhucong@163.com>
1 parent e878d56 commit 10df64c

File tree

4 files changed, +44 -29 lines changed


.github/workflows/vllm_ascend_test_long_term.yaml

Lines changed: 3 additions & 3 deletions
@@ -95,11 +95,11 @@ jobs:
       run: |
         if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
           # v0 spec decode test
-          VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode_v0/e2e/test_mtp_correctness.py # it needs a clean process
-          pytest -sv tests/long_term/spec_decode_v0 --ignore=tests/long_term/spec_decode_v0/e2e/test_mtp_correctness.py
+          # VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode_v0/e2e/test_mtp_correctness.py # it needs a clean process
+          # pytest -sv tests/long_term/spec_decode_v0 --ignore=tests/long_term/spec_decode_v0/e2e/test_mtp_correctness.py
           # v1 spec decode test
           # TODO: revert me when test_v1_mtp_correctness.py is fixed
-          # VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py
           # TODO: revert me when test_v1_spec_decode.py::test_ngram_correctness is fixed
           # VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode_v1/test_v1_spec_decode.py
           # accuracy test single card

tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 5 additions & 1 deletion
@@ -63,7 +63,10 @@ def test_mtp_correctness(
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

-        ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True)
+        ref_llm = LLM(model=model_name,
+                      max_model_len=256,
+                      gpu_memory_utilization=0.8,
+                      enforce_eager=True)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm

@@ -74,6 +77,7 @@ def test_mtp_correctness(
                 "num_speculative_tokens": 1,
             },
             max_model_len=256,
+            gpu_memory_utilization=0.8,
             enforce_eager=True)
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 0 deletions
@@ -2041,6 +2041,7 @@ def _generate_mtp_token_ids(
         cu_num_tokens, token_indices = self.drafter.prepare_inputs(
             attn_metadata.query_start_loc,
             num_rejected_tokens,
+            force_one_token=True,
         )
         target_token_ids = self.input_ids[token_indices]
         target_positions = positions[token_indices]

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 35 additions & 25 deletions
@@ -64,10 +64,11 @@ def __init__(

     @staticmethod
     def prepare_inputs(
-        # [batch_size + 1]
-        cu_target_query_lens: torch.Tensor,
-        # [batch_size]
-        num_rejected_tokens: torch.Tensor,
+            # [batch_size + 1]
+            cu_target_query_lens: torch.Tensor,
+            # [batch_size]
+            num_rejected_tokens: torch.Tensor,
+            force_one_token: bool = False
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # cu_target_query_lens: [0, a, a + b, a + b + c]
         # num_rejected_tokens: [n1, n2, n3]
@@ -76,32 +77,39 @@ def prepare_inputs(
         # token_indices: [0, 1, ..., a - n1 - 1,
         #                 a, a + 1, ..., a + b - n2 - 1,
         #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
-
         # [0, a, a + b, a + b + c] -> [a, b, c]
         query_len_per_req = (cu_target_query_lens[1:] -
                              cu_target_query_lens[:-1])
         # [a, b, c] -> [a - n1, b - n2, c - n3]
         num_tokens_per_req = query_len_per_req - num_rejected_tokens
+        if force_one_token:
+            # enable force_one_token means we only focus on the last token position of each request
+            # token_indices: [batch_size]
+            cu_num_tokens = torch.arange(cu_target_query_lens.size(0),
+                                         device=cu_target_query_lens.device,
+                                         dtype=torch.int32)
+            relative_index = query_len_per_req - num_rejected_tokens - 1
+            token_indices = cu_target_query_lens[:-1] + relative_index
+        else:
+            cu_num_tokens = torch.empty_like(cu_target_query_lens)
+            torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
+            cu_num_tokens[0] = 0
+
+            # FIXME(woosuk): Avoid synchronization.
+            num_tokens = cu_num_tokens[-1].item()
+            token_indices = torch.empty(
+                num_tokens,
+                dtype=torch.int32,
+                device=cu_num_tokens.device,
+            )

-        cu_num_tokens = torch.empty_like(cu_target_query_lens)
-        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-        cu_num_tokens[0] = 0
-
-        # FIXME(woosuk): Avoid synchronization.
-        num_tokens = cu_num_tokens[-1].item()
-        token_indices = torch.empty(
-            num_tokens,
-            dtype=torch.int32,
-            device=cu_num_tokens.device,
-        )
-
-        BLOCK_SIZE = 1024
-        prepare_input_kernel(
-            token_indices,
-            cu_target_query_lens,
-            cu_num_tokens,
-            block_size=BLOCK_SIZE,
-        )
+            BLOCK_SIZE = 1024
+            prepare_input_kernel(
+                token_indices,
+                cu_target_query_lens,
+                cu_num_tokens,
+                block_size=BLOCK_SIZE,
+            )
         return cu_num_tokens, token_indices

     def propose(
@@ -160,7 +168,9 @@ def propose(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
         )
-
+        # When proposing, we set the prefill query_lens to 1.
+        if attn_metadata.prefill is not None:
+            attn_metadata.prefill.query_lens[:] = 1
         with set_ascend_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
                 input_ids=input_ids,
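To make the new force_one_token path easier to follow, here is a small standalone sketch (not part of the commit) that reproduces its index arithmetic with concrete, invented numbers; only the variable names mirror prepare_inputs, and the values are chosen purely for illustration.

import torch

# Three requests with query lengths a=3, b=2, c=4 and rejected-token
# counts n1=0, n2=1, n3=2 (made-up values for this example).
cu_target_query_lens = torch.tensor([0, 3, 5, 9], dtype=torch.int32)  # [0, a, a+b, a+b+c]
num_rejected_tokens = torch.tensor([0, 1, 2], dtype=torch.int32)      # [n1, n2, n3]

# [0, a, a+b, a+b+c] -> [a, b, c]
query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]  # [3, 2, 4]

# force_one_token=True branch: keep only the last accepted token of each request.
relative_index = query_len_per_req - num_rejected_tokens - 1   # [2, 0, 1]
token_indices = cu_target_query_lens[:-1] + relative_index     # [2, 3, 6]

# With exactly one token kept per request, the cumulative counts are just 0..batch_size.
cu_num_tokens = torch.arange(cu_target_query_lens.size(0), dtype=torch.int32)  # [0, 1, 2, 3]

print(token_indices.tolist())  # [2, 3, 6]
print(cu_num_tokens.tolist())  # [0, 1, 2, 3]

So the drafter receives exactly one input position per request, the last accepted token, rather than the full chunked-prefill query. The companion change in propose() that pins prefill query_lens to 1 appears intended to keep the attention metadata consistent with this single-token-per-request input.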
