
Commit edeadde

mengwei805, XWFAlone, and xuan63 authored
[2/N][CI/UT] enable spec_decode UT (#474)
### What this PR does / why we need it?

Adds e2e UT for the spec_decode feature. Remaining issues:

1. Graph-mode case: pending vllm-ascend support;
2. Preemption-scenario case: pending a bugfix;
3. Quantization-related case: pending vllm-ascend support;
5. test_multistep_correctness.py: pending a bugfix for the case where both chunked prefill and spec decode are enabled;
6. test_mtp_correctness.py: pending bf16 weights;
7. test_eagle_correctness.py: pending availability of Meta/Llama weights via ModelScope;

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

Tested by CI.

Signed-off-by: mengwei805 <mengwei25@huawei.com> Co-authored-by: XWFAlone <xueweifei2@huawei.com> Co-authored-by: xuan63 <huangshixuan@huawei.com>
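For orientation, here is a minimal sketch (not part of this commit) of the kind of speculative-decoding setup the new e2e tests exercise. The target/draft model pair and `num_speculative_tokens` value mirror the added test_compatibility.py; the prompt, output length, and printed field are illustrative assumptions.

```python
# Illustrative sketch only: the spec-decode configuration style the new e2e UT
# drives. Model names and num_speculative_tokens mirror the added
# test_compatibility.py; everything else is an assumption.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # target model
    speculative_model="JackFram/llama-68m",    # small draft model proposing tokens
    num_speculative_tokens=5,                  # draft tokens proposed per step
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```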
1 parent 2b765dc commit edeadde

15 files changed: +3353 -2 lines changed

requirements-dev.txt

Lines changed: 1 addition & 0 deletions

@@ -3,3 +3,4 @@ modelscope<1.23.0
 pytest >= 6.0
 pytest-asyncio
 pybind11
+ray

tests/__init__.py

Whitespace-only changes.

tests/spec_decode/e2e/__init__.py

Whitespace-only changes.

tests/spec_decode/e2e/conftest.py

Lines changed: 321 additions & 0 deletions
@@ -0,0 +1,321 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/conftest.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from itertools import cycle
from typing import List, Optional, Sequence, Tuple, Union

import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import PromptLogprobs, SampleLogprobs

from ...model_utils import (TokensTextLogprobs,
                            TokensTextLogprobsPromptLogprobs,
                            check_logprobs_close, check_outputs_equal)
from ...utils import RemoteOpenAIServer

PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
    "San Francisco is know for its",
    "Facebook was created in 2004 by",
    "Curious George is a",
    "Python 3.11 brings improvements to its",
]


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):

    def generate():
        kwargs = {
            **common_llm_kwargs,
            **per_test_common_llm_kwargs,
            **test_llm_kwargs,
        }

        llm = LLM(**kwargs)

        if seed is not None:
            set_random_seed(seed)

        yield llm

        del llm
        cleanup_dist_env_and_memory()

    return generate


def maybe_assert_ngram_worker(llm):
    # Verify the proposer worker is ngram if ngram is specified.
    if (llm.llm_engine.speculative_config is not None
            and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
        from vllm.spec_decode.ngram_worker import NGramWorker
        assert isinstance(
            llm.llm_engine.model_executor.driver_worker.proposer_worker,
            NGramWorker)


def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]], float]:
    tokens: List[str] = []
    token_ids: List[List[int]] = []
    acceptance_rate: float = -1.0
    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

        token_ids = [output.outputs[0].token_ids for output in outputs]
        tokens = [output.outputs[0].text for output in outputs]

        # Fetch acceptance rate if logging is enabled.
        if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
            stat_logger = stat_loggers["prometheus"]
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())
        del llm

    return tokens, token_ids, acceptance_rate


def check_logprobs_correctness(
    spec_outputs: Sequence[Union[TokensTextLogprobs,
                                 TokensTextLogprobsPromptLogprobs]],
    baseline_outputs: Sequence[Union[TokensTextLogprobs,
                                     TokensTextLogprobsPromptLogprobs]],
    disable_logprobs: bool = False,
):
    """Compare sampled and prompt logprobs between baseline and spec decoding
    """
    if not disable_logprobs:
        return check_logprobs_close(
            outputs_0_lst=baseline_outputs,
            outputs_1_lst=spec_outputs,
            name_0="org",
            name_1="sd",
        )

    # Check correctness when disable_logprobs == True
    for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
        # Check generated token logprobs.
        spec_logprobs = spec_output[2]
        baseline_logprobs = baseline_output[2]
        _check_logprobs_when_output_disabled(spec_logprobs,
                                             baseline_logprobs,
                                             is_prompt_logprobs=False)

        # Check prompt logprobs too, if they exist
        if len(baseline_output) == 4:
            assert len(spec_output) == 4
            spec_prompt_logprobs = spec_output[3]
            baseline_prompt_logprobs = baseline_output[3]
            _check_logprobs_when_output_disabled(spec_prompt_logprobs,
                                                 baseline_prompt_logprobs,
                                                 is_prompt_logprobs=True)


def _check_logprobs_when_output_disabled(
    spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
    baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
    is_prompt_logprobs: bool = False,
):
    # Prompt logprobs are optional
    if is_prompt_logprobs and baseline_logprobs is None:
        assert spec_logprobs is None
        return

    assert spec_logprobs is not None
    assert baseline_logprobs is not None
    assert len(spec_logprobs) == len(baseline_logprobs)

    # For each generated position of the sequence.
    for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
            zip(spec_logprobs, baseline_logprobs)):

        # First prompt logprob is expected to be None
        if is_prompt_logprobs and baseline_pos_logprobs is None:
            assert spec_pos_logprobs is None
            assert pos == 0
            continue

        assert spec_pos_logprobs is not None
        assert baseline_pos_logprobs is not None

        # When disabled, the 1 logprob is returned with dummy values for the
        # score and rank, but the token id should match the baseline model
        assert len(spec_pos_logprobs) == 1
        (spec_pos_logprob_token_id,
         spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
        assert spec_pos_logprob.rank == -1
        assert spec_pos_logprob.logprob == 0.0
        if isinstance(spec_pos_logprob_token_id, torch.Tensor):
            spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
        assert spec_pos_logprob_token_id in baseline_pos_logprobs


def run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size: int,
        max_output_len: int,
        seed: Optional[int] = 0,
        temperature: float = 0.0,
        disable_seed: bool = False,
        ignore_eos: bool = True,
        ensure_all_accepted: bool = False,
        expected_acceptance_rate: Optional[float] = None,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        disable_logprobs: bool = False):

    org_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **baseline_llm_kwargs,
    }

    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

    if disable_seed:
        seed = None

    sampling_params = SamplingParams(temperature=temperature,
                                     max_tokens=max_output_len,
                                     seed=seed,
                                     ignore_eos=ignore_eos,
                                     logprobs=logprobs,
                                     prompt_logprobs=prompt_logprobs)

    with vllm_runner(**org_args) as vllm_model:
        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    with vllm_runner(**sd_args) as vllm_model:
        if ensure_all_accepted or expected_acceptance_rate is not None:
            # Force log interval to be 0 to catch all metrics.
            stat_logger = vllm_model.model.llm_engine.stat_loggers[
                'prometheus']
            stat_logger.local_interval = -100

        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

        if ensure_all_accepted or expected_acceptance_rate is not None:
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())

            if ensure_all_accepted:
                assert True
                # FIXME: ci fails to log acceptance rate.
                # It works locally.
                # assert acceptance_rate == 1.0

            if expected_acceptance_rate is not None:
                assert acceptance_rate >= expected_acceptance_rate - 1e-2

    # Only pass token entries, not the logprobs
    check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
                        outputs_1_lst=[out[0:2] for out in sd_outputs],
                        name_0="org",
                        name_1="sd")

    # Check logprobs if requested
    if logprobs is not None or prompt_logprobs is not None:
        check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)


def run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size: int,
                                     max_output_len: int,
                                     seed: int = 0,
                                     temperature: float = 0.0,
                                     logprobs: Optional[int] = None):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
    the same when temperature is zero.
    """
    arg1 = common_llm_kwargs + per_test_common_llm_kwargs + baseline_llm_kwargs
    arg2 = common_llm_kwargs + per_test_common_llm_kwargs + test_llm_kwargs
    env1 = env2 = None

    max_wait_seconds = 240
    results = []

    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
    for args, env in ((arg1, env1), (arg2, env2)):
        with RemoteOpenAIServer(model,
                                args,
                                env_dict=env,
                                max_wait_seconds=max_wait_seconds) as server:
            client = server.get_client()

            completion = client.completions.create(model=model,
                                                   prompt=prompts,
                                                   max_tokens=max_output_len,
                                                   seed=seed,
                                                   temperature=temperature,
                                                   logprobs=logprobs)

            results.append({
                "test":
                "seeded_sampling",
                "text": [choice.text for choice in completion.choices],
                "logprobs": [choice.logprobs for choice in completion.choices],
                "finish_reason":
                [choice.finish_reason for choice in completion.choices],
                "usage":
                completion.usage,
            })

    n = len(results) // 2
    arg1_results = results[:n]
    arg2_results = results[n:]
    # Separate logprobs to avoid asserting exact equality.
    arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
    arg2_logprobs = [r.pop("logprobs") for r in arg2_results]

    for arg1_result, arg2_result in zip(arg1_results, arg2_results):
        assert arg1_result == arg2_result, (
            f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
            f"{arg1_result=} != {arg2_result=}")
    if logprobs:
        for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
            for l1, l2 in zip(logs1, logs2):
                assert l1.tokens == l2.tokens
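
A hypothetical test built on the helpers above might look like the sketch below. It follows the parametrize pattern used by the added test_compatibility.py, but the test itself, the fixture values, and the kwarg names forwarded to the external `vllm_runner` fixture are assumptions rather than part of this commit.

```python
# Hypothetical usage sketch of run_equality_correctness_test (see lead-in for
# assumptions). It checks that greedy outputs with spec decode enabled match
# the baseline outputs exactly.
import pytest

from .conftest import run_equality_correctness_test


@pytest.mark.parametrize("common_llm_kwargs", [{
    "model_name": "JackFram/llama-68m",  # assumed kwarg name for vllm_runner
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_equality_sketch(vllm_runner, common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs, test_llm_kwargs,
                                     batch_size, seed):
    # Baseline run uses baseline_llm_kwargs, spec-decode run uses
    # test_llm_kwargs; outputs must match token-for-token at temperature 0.
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=32,
                                  seed=seed,
                                  temperature=0.0)
```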
tests/spec_decode/e2e/test_compatibility.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_compatibility.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest
from vllm import SamplingParams

from .conftest import get_output_from_llm_generator


@pytest.mark.parametrize("common_llm_kwargs", [{
    "model": "meta-llama/Llama-3.2-1B-Instruct",
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": 5,
}])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            # Speculative max model len > overridden max model len should raise.
            "max_model_len": 128,
            "speculative_max_model_len": 129,
        },
        {
            # Speculative max model len > draft max model len should raise.
            # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
            "speculative_max_model_len": 2048 + 1,
        },
        {
            # Speculative max model len > target max model len should raise.
            # https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
            "speculative_max_model_len": 131072 + 1,
        },
    ])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
    """Verify that speculative decoding validates speculative_max_model_len.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(ValueError, match="cannot be larger than"):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)
