Commit 0ae9ee0

[BUGFIX] main-sd-bugfix && [UT] add mtp UT (#593)
### What this PR does / why we need it?
This PR fixes several bugs in spec decode / MTP and adds an MTP e2e UT, `test_mtp_correctness.py`.

**vllm_ascend/attention/attention.py**
1. Add support for `self.attn_mask_cache` having only 1 element, to cover the scenario where both spec decode and chunked prefill are enabled.

**vllm_ascend/distributed/parallel_state.py**
1. Remove 2 asserts, because the spec decode worker calls init_worker twice.

**vllm_ascend/models/deepseek_mtp.py**
1. Remove unused params.
2. Add w8a8 support in `CustomDeepSeekMTP`.

**vllm_ascend/quantization/quant_config.py**
1. Use `AscendUnquantizedFusedMoEMethod` instead of `UnquantizedFusedMoEMethod`.

**other**
1. Replace `from vllm.logger import init_logger` with `from vllm.logger import logger` across the whole vllm-ascend project (a small before/after sketch follows below).

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: mengwei805 <mengwei25@huawei.com>
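A minimal sketch of the project-wide logger change listed under **other**; the log message and call site below are made up for illustration and do not appear in the diff:

```python
# Before this commit, each module created its own logger instance:
#   from vllm.logger import init_logger
#   logger = init_logger(__name__)
#
# After this commit, modules import vLLM's shared module-level logger directly:
from vllm.logger import logger

logger.info("Ascend attention backend initialized")  # hypothetical call site
```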
1 parent 5442b46 commit 0ae9ee0

File tree

10 files changed: +375 −31 lines

.github/workflows/vllm_ascend_test.yaml

Lines changed: 2 additions & 1 deletion
@@ -161,7 +161,8 @@ jobs:
         if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true'
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
-            pytest -sv tests/singlecard/spec_decode
+            pytest -sv tests/singlecard/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process
+            pytest -sv tests/singlecard/spec_decode --ignore=tests/singlecard/spec_decode/e2e/test_mtp_correctness.py
           fi

       - name: Run vllm-project/vllm test for V0 Engine
tests/singlecard/spec_decode/e2e/test_mtp_correctness.py (new file)

Lines changed: 355 additions & 0 deletions
@@ -0,0 +1,355 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_mtp_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify that the scenarios below pass:
    * Batch size 1 greedy equality
    * Batch size >1 greedy equality
    * Greedy equality under preemption
    * Greedy equality under various numbers of speculative tokens.

With those tests, we can say that, at least, MTP does not break the
correctness of the target model outputs.
"""

import pytest

from .conftest import run_equality_correctness_test

# main model
# NOTE: vLLM uses an fp8 model, vllm-ascend uses a bf16 model
MAIN_MODEL = "wemaster/deepseek_mtp_main_random_bf16"

# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1

# precision
PRECISION = "bfloat16"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                    per_test_common_llm_kwargs,
                                    baseline_llm_kwargs, test_llm_kwargs,
                                    batch_size: int, output_len: int,
                                    seed: int):

    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "num_speculative_tokens": MAX_SPEC_TOKENS,
            "disable_logprobs": False,
        },
    },
    {
        "speculative_config": {
            "num_speculative_tokens": MAX_SPEC_TOKENS,
            "disable_logprobs": True,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs, test_llm_kwargs,
                                 batch_size: int, output_len: int, seed: int,
                                 logprobs: int):

    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        logprobs=logprobs,
        prompt_logprobs=logprobs,
        disable_logprobs=test_llm_kwargs["speculative_config"]
        ["disable_logprobs"])


@pytest.mark.skipif(
    True,
    reason=
    "Enable this test once vllm-ascend supports graph mode, i.e. running the model with enforce_eager=False"
)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": False,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
                                               per_test_common_llm_kwargs,
                                               baseline_llm_kwargs,
                                               test_llm_kwargs,
                                               batch_size: int,
                                               output_len: int, seed: int):
    """Verify greedy equality with cuda graph enabled and different
    batch sizes."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        128,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_with_preemption(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_config": {
                "num_speculative_tokens": k,
            },
        }
        # Try a range of num. speculative tokens
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
                         per_test_common_llm_kwargs, baseline_llm_kwargs,
                         test_llm_kwargs, batch_size: int, output_len: int,
                         seed: int):
    """Verify that mtp speculative decoding produces exact equality
    to without spec decode with different values of num_speculative_tokens.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,

        # GPU memory utilization
        "gpu_memory_utilization": 0.9
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_by_batch_size": 4
    },
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int):
    """Verify that mtp speculative decoding produces exact equality
    to without spec decode when speculation is disabled for large
    batch sizes.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])

vllm_ascend/attention/attention.py

Lines changed: 2 additions & 1 deletion
@@ -113,7 +113,8 @@ def get_splitfuse_attn_mask(
             self.update_attn_cache(max_seq_len, dtype, device)
             # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
             # is not the same. Fix this in the future when kernel is ready.
-            if self.attn_mask_cache[0][1] > 0:
+            if self.attn_mask_cache.numel(
+            ) > 1 and self.attn_mask_cache[0][1] > 0:
                 attn_mask = self.get_attn_mask(  # type: ignore
                     max_seq_len, dtype, device)
                 attn_mask *= -10000
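A minimal standalone sketch of why the extra `numel()` guard matters; the tensor shapes and values below are assumptions for illustration, not the real mask contents:

```python
import torch


def mask_needs_rescale(attn_mask_cache: torch.Tensor) -> bool:
    # The old check read attn_mask_cache[0][1] directly, which assumes the
    # cached mask has at least two elements. When spec decode and chunked
    # prefill are enabled together, the cache can shrink to a single element,
    # and that index goes out of range.
    return bool(attn_mask_cache.numel() > 1 and attn_mask_cache[0][1] > 0)


print(mask_needs_rescale(torch.ones(4, 4)))       # True: mask would be rescaled (*= -10000)
print(mask_needs_rescale(torch.tensor([[1.0]])))  # False: guard short-circuits, no IndexError
```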

vllm_ascend/attention/mla_v1.py

Lines changed: 0 additions & 3 deletions
@@ -6,7 +6,6 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
-from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -21,8 +20,6 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch

-logger = init_logger(__name__)
-

 class AscendMLABackend(AttentionBackend):
