
Commit 3797db8

Update test_spec_decode.py
1 parent 8f7ffc4 commit 3797db8

File tree

1 file changed (+24, -13 lines)

tests/singlecard/spec_decode/e2e/test_spec_decode.py

Lines changed: 24 additions & 13 deletions
@@ -1,12 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations

+import os
 import random
 from typing import Any

 import pytest
+
 from vllm import LLM, SamplingParams

+os.environ["VLLM_USE_MODELSCOPE"] = "True"
+

 @pytest.fixture
 def test_prompts():
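
Note on the new module-level assignment: with VLLM_USE_MODELSCOPE set to "True", vLLM resolves model IDs such as "LLM-Research/Llama-3.1-8B-Instruct" against ModelScope rather than the Hugging Face Hub, which is why it sits before anything that loads weights. A minimal sketch of the intended effect, assuming network access to ModelScope:

import os

# Set before any engine is constructed so vLLM sees it when it
# resolves the model ID against ModelScope instead of the HF Hub.
os.environ["VLLM_USE_MODELSCOPE"] = "True"

from vllm import LLM

# Model ID taken from the model_name fixture in this diff.
llm = LLM(model="LLM-Research/Llama-3.1-8B-Instruct", max_model_len=2048)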
@@ -43,18 +47,20 @@ def test_prompts():

 @pytest.fixture
 def sampling_config():
-    # Only support greedy for now
     return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


 @pytest.fixture
 def model_name():
-    return "meta-llama/Meta-Llama-3-8B-Instruct"
+    return "LLM-Research/Llama-3.1-8B-Instruct"


-@pytest.fixture
 def eagle_model_name():
-    return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
+    return "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B"
+
+
+def eagle3_model_name():
+    return "vllm-ascend/EAGLE3-LLaMA3.1-Instruct-8B"


 def test_ngram_correctness(
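
Note that the @pytest.fixture decorator is dropped from eagle_model_name, so both draft-model helpers are now plain functions the test calls directly instead of receiving them via fixture injection. A small sketch of the selection pattern this enables; pick_spec_model is an illustrative helper, not part of the commit:

def eagle_model_name() -> str:
    return "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B"


def eagle3_model_name() -> str:
    return "vllm-ascend/EAGLE3-LLaMA3.1-Instruct-8B"


def pick_spec_model(use_eagle3: bool) -> str:
    # Mirrors the runtime selection added to test_eagle_correctness below.
    return eagle3_model_name() if use_eagle3 else eagle_model_name()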
@@ -97,37 +103,42 @@ def test_ngram_correctness(

         # Heuristic: expect at least 70% of the prompts to match exactly
         # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches > int(0.6 * len(ref_outputs))
+        assert matches > int(0.7 * len(ref_outputs))
         del spec_llm


+@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_name: str,
-    eagle_model_name: str,
+    use_eagle3: bool,
 ):
-    pytest.skip("Not current support for the test.")
     '''
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
     '''
+    pytest.skip("Not current support for the test.")
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

-        ref_llm = LLM(model=model_name, max_model_len=1024)
+        ref_llm = LLM(model=model_name, max_model_len=2048)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm

+        spec_model_name = eagle3_model_name(
+        ) if use_eagle3 else eagle_model_name()
         spec_llm = LLM(
             model=model_name,
+            trust_remote_code=True,
             speculative_config={
-                "method": "eagle",
-                "model": eagle_model_name,
+                "method": "eagle3" if use_eagle3 else "eagle",
+                "model": spec_model_name,
                 "num_speculative_tokens": 3,
+                "max_model_len": 2048,
             },
-            max_model_len=1024,
+            max_model_len=2048,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0
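
Condensed sketch of the engine construction this hunk produces once pytest supplies use_eagle3 (fixture values inlined for illustration; not a standalone test):

from vllm import LLM

use_eagle3 = True  # parametrized by pytest with ids "eagle" / "eagle3"

spec_model_name = ("vllm-ascend/EAGLE3-LLaMA3.1-Instruct-8B" if use_eagle3
                   else "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B")

spec_llm = LLM(
    model="LLM-Research/Llama-3.1-8B-Instruct",
    trust_remote_code=True,
    speculative_config={
        "method": "eagle3" if use_eagle3 else "eagle",
        "model": spec_model_name,
        "num_speculative_tokens": 3,  # draft tokens proposed per step
        "max_model_len": 2048,        # kept in sync with the target engine
    },
    max_model_len=2048,
)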
@@ -140,7 +151,7 @@ def test_eagle_correctness(
                 print(f"ref_output: {ref_output.outputs[0].text}")
                 print(f"spec_output: {spec_output.outputs[0].text}")

-        # Heuristic: expect at least 70% of the prompts to match exactly
+        # Heuristic: expect at least 66% of the prompts to match exactly
         # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches > int(0.7 * len(ref_outputs))
+        assert matches > int(0.66 * len(ref_outputs))
         del spec_llm
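
Both tests share the same exact-match heuristic; this hunk relaxes the eagle threshold from 70% to 66%. A sketch of the check, assuming ref_outputs and spec_outputs are the lists returned by LLM.chat (the counting loop itself is outside this diff):

# Count prompts whose speculative output exactly matches the reference.
matches = sum(
    ref.outputs[0].text == spec.outputs[0].text
    for ref, spec in zip(ref_outputs, spec_outputs)
)

# Greedy speculative decoding should reproduce the target model's output,
# but numerics can diverge, so only >66% exact agreement is required.
assert matches > int(0.66 * len(ref_outputs))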
