Commit 339d689

[CI/UT][bugfix] fix v0 spec decode (#1321)
### What this PR does / why we need it?

1. [PR913](#913) introduced an error that caused the V0 spec decode function to fail. [PR1109](#1109) attempted to fix this problem, but unfortunately that fix broke the ngram function. This PR fixes the ngram function.
   **PS**: Q: Why did the ngram breakage go unnoticed when PR1109 was merged? A: The newly introduced problem only appears when tp>1, and the CI use cases all run with tp=1.
2. In versions after 0.7.3, vllm-ascend deleted some spec decode UTs to keep CI time down, including the eagle speculative UTs, so CI could no longer cover the eagle function. This PR adds them back (`test_eagle_correctness.py`).
3. Because of the gap described in 2, the current version of eagle is broken. I located and fixed the problem: vllm's `draft_model_runner.py` had changed and vllm-ascend was not synchronized in time.
4. Currently the UTs of v0 and v1 are mixed in the spec_decode directory. This PR splits them into two directories: spec_decode_v0 and spec_decode_v1.
5. `vllm.spec_decode.multi_step_worker.MultiStepWorker.set_include_gpu_probs_tensor` and `vllm.spec_decode.multi_step_worker.MultiStepWorker.set_should_modify_greedy_probs_inplace` have changed in vllm, so this PR removes them.

### Does this PR introduce _any_ user-facing change?

This PR fixes the ngram and eagle spec decode functions in the v0 engine.

### How was this patch tested?

Tested by CI.

Signed-off-by: mengwei805 <mengwei25@huawei.com>
1 parent 7e6efbf commit 339d689

22 files changed (+384, -58 lines)
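For orientation, the V0 spec decode features this PR repairs are driven through vLLM's `speculative_config`. Below is a minimal, illustrative sketch of enabling the ngram drafter; the config keys and model name are assumptions (they are not part of this commit), so check the vLLM docs for the exact spelling in your version.

```python
# Illustrative sketch only: shows the speculative_config mechanism that the
# fixed ngram path is driven by. Keys and model name are assumptions, not
# part of this commit; consult the vLLM docs for your version.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",      # small target model, as in the tests below
    speculative_config={
        "method": "ngram",           # draft tokens via prompt n-gram lookup
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 4,
    },
    enforce_eager=True,
)

out = llm.generate(["The capital of France is"],
                   SamplingParams(temperature=0.0, max_tokens=16))
print(out[0].outputs[0].text)
```

The eagle path exercised by the new `test_eagle_correctness.py` uses the same mechanism, but with a draft model instead of `"method": "ngram"` (see the test file added below).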

.github/workflows/vllm_ascend_test_long_term.yaml

Lines changed: 8 additions & 5 deletions
```diff
@@ -97,13 +97,16 @@ jobs:
       - name: Run vllm-project/vllm-ascend long term test
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
-            # spec decode test
-            VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
+            # v0 spec decode test
+            VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py # it needs a clean process
+            pytest -sv tests/e2e/long_term/spec_decode_v0 --ignore=tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py
+            # v1 spec decode test
+            VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode_v1/test_v1_mtp_correctness.py
             # TODO: revert me when test_v1_spec_decode.py::test_ngram_correctness is fixed
-            VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py
-            VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process
-            pytest -sv tests/e2e/long_term/spec_decode --ignore=tests/e2e/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py --ignore=tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
+            VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode_v1/test_v1_spec_decode.py
+            # accuracy test single card
             pytest -sv tests/e2e/long_term/test_accuracy.py
           else
+            # accuracy test multi card
             VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py
           fi
```
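A rough sketch of reproducing the v0 portion of this CI step locally is shown below. The paths and the `VLLM_USE_MODELSCOPE` variable come from the workflow above; the Python wrapper itself is not part of this commit and is only illustrative.

```python
# Illustrative local runner mirroring the single-card v0 spec decode commands
# from the workflow above. Paths and VLLM_USE_MODELSCOPE are taken from the
# workflow; this wrapper script is not part of the commit.
import os
import subprocess
import sys

env = dict(os.environ, VLLM_USE_MODELSCOPE="True")

# test_mtp_correctness.py needs a clean process, so it gets its own pytest run.
commands = [
    ["pytest", "-sv",
     "tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py"],
    ["pytest", "-sv", "tests/e2e/long_term/spec_decode_v0",
     "--ignore=tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py"],
    ["pytest", "-sv",
     "tests/e2e/long_term/spec_decode_v1/test_v1_mtp_correctness.py"],
]

for cmd in commands:
    result = subprocess.run(cmd, env=env)
    if result.returncode != 0:
        sys.exit(result.returncode)
```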
tests/e2e/long_term/spec_decode_v0/e2e/test_eagle_correctness.py (new file)

Lines changed: 344 additions & 0 deletions
```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_eagle_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify below scenario could be passed:
    * Batch size 1 greedy equality
    * Batch size >1 greedy equality
    * Test greedy equality under preemption
    * Test greedy equality under various number of speculative tokens.

With those tests, we can say at least, EAGLE would not break the
correctness for the target model outputs.
"""

import pytest

from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \
    run_equality_correctness_test

# main model
MAIN_MODEL = "JackFram/llama-68m"

# speculative model
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"

# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 4

# precision
# TODO The vLLM here uses float32, but some op on the vllm-ascend
# do not support float32, such as ROPE, When it is fixed, it is
# recommended to change this to float32.
PRECISION = "float16"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                      per_test_common_llm_kwargs,
                                      baseline_llm_kwargs, test_llm_kwargs,
                                      batch_size: int, output_len: int,
                                      seed: int):

    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs": False,
    },
}, {
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs": True,
    },
}])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                   per_test_common_llm_kwargs,
                                   baseline_llm_kwargs, test_llm_kwargs,
                                   batch_size: int, output_len: int, seed: int,
                                   logprobs: int):

    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        logprobs=logprobs,
        prompt_logprobs=logprobs,
        disable_logprobs=test_llm_kwargs["speculative_config"]
        ["disable_logprobs"])


@pytest.mark.skipif(True, reason="Open it when graph mode ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": False,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_cuda_graph(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality with cuda graph enabled and different
    batch sizes."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.skipif(True, reason="Open it when preempt ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        128,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_config": {
                "model": SPEC_MODEL,
                "num_speculative_tokens": k,
            },
        }
        # Try a range of num. speculative tokens
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int):
    """Verify that eagle speculative decoding produces exact equality
    to without spec decode with different values of num_speculative_tokens.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_by_batch_size": 4,
    },
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
                             per_test_common_llm_kwargs, baseline_llm_kwargs,
                             test_llm_kwargs, batch_size: int, output_len: int,
                             seed: int):
    """Verify that eagle speculative decoding produces exact equality
    to without spec decode when speculation is disabled for large
    batch sizes.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])
```
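The docstring above describes the "greedy equality" methodology these tests rely on. As a rough conceptual sketch of that idea (not the project's `run_equality_correctness_test` helper, whose fixtures do the actual work), the comparison amounts to something like the following; the model names reuse the constants from the test file, while the direct `LLM` usage is an assumption about vLLM's public API.

```python
# Conceptual sketch of the "greedy equality" check, for illustration only.
# It is NOT the run_equality_correctness_test helper from conftest.py.
from vllm import LLM, SamplingParams

prompts = ["Hello, my name is"]
greedy = SamplingParams(temperature=0.0, max_tokens=32)

# Baseline: ordinary greedy decoding with the target model.
baseline = LLM(model="JackFram/llama-68m", enforce_eager=True)
baseline_out = [o.outputs[0].text for o in baseline.generate(prompts, greedy)]

# Test: the same target model with an EAGLE draft model attached.
spec = LLM(model="JackFram/llama-68m",
           enforce_eager=True,
           speculative_config={
               "model": "abhigoyal/vllm-eagle-llama-68m-random",
               "num_speculative_tokens": 4,
           })
spec_out = [o.outputs[0].text for o in spec.generate(prompts, greedy)]

# Speculative decoding must not change greedy outputs.
assert baseline_out == spec_out
```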

tests/e2e/long_term/spec_decode/e2e/test_medusa_correctness.py renamed to tests/e2e/long_term/spec_decode_v0/e2e/test_medusa_correctness.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -41,9 +41,10 @@
 
 import pytest
 
-from tests.e2e.long_term.spec_decode.e2e.conftest import \
+from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \
     run_equality_correctness_test
-from tests.e2e.long_term.spec_decode.utils import maybe_enable_chunked_prefill
+from tests.e2e.long_term.spec_decode_v0.utils import \
+    maybe_enable_chunked_prefill
 
 # main model
 # lmsys/vicuna-7b-v1.3 was to be used but it's causing
```

tests/e2e/long_term/spec_decode/e2e/test_mlp_correctness.py renamed to tests/e2e/long_term/spec_decode_v0/e2e/test_mlp_correctness.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -41,9 +41,10 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import \
     pad_vocab_size  # noqa: F401
 
-from tests.e2e.long_term.spec_decode.e2e.conftest import \
+from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \
     run_equality_correctness_test
-from tests.e2e.long_term.spec_decode.utils import maybe_enable_chunked_prefill
+from tests.e2e.long_term.spec_decode_v0.utils import \
+    maybe_enable_chunked_prefill
 
 # main model
 MAIN_MODEL = "JackFram/llama-160m"
```
