# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

+import os
import random
from typing import Any

import pytest
+
from vllm import LLM, SamplingParams

+os.environ["VLLM_USE_MODELSCOPE"] = "True"
+

@pytest.fixture
def test_prompts():
@@ -43,18 +47,20 @@ def test_prompts():

@pytest.fixture
def sampling_config():
-    # Only support greedy for now
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


@pytest.fixture
def model_name():
-    return "meta-llama/Meta-Llama-3-8B-Instruct"
+    return "LLM-Research/Llama-3.1-8B-Instruct"


-@pytest.fixture
def eagle_model_name():
-    return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
+    return "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B"
+
+
+def eagle3_model_name():
+    return "vllm-ascend/EAGLE3-LLaMA3.1-Instruct-8B"


def test_ngram_correctness(
@@ -97,37 +103,42 @@ def test_ngram_correctness(

        # Heuristic: expect at least 70% of the prompts to match exactly
        # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches > int(0.6 * len(ref_outputs))
+        assert matches > int(0.7 * len(ref_outputs))
        del spec_llm


+@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    test_prompts: list[list[dict[str, Any]]],
    sampling_config: SamplingParams,
    model_name: str,
-    eagle_model_name: str,
+    use_eagle3: bool,
):
-    pytest.skip("Not currently supported.")
    '''
    Compare the outputs of an original LLM and a speculative LLM;
    they should be the same when using eagle speculative decoding.
    '''
+    pytest.skip("Not currently supported.")
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

-        ref_llm = LLM(model=model_name, max_model_len=1024)
+        ref_llm = LLM(model=model_name, max_model_len=2048)
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm

+        spec_model_name = eagle3_model_name(
+        ) if use_eagle3 else eagle_model_name()
        spec_llm = LLM(
            model=model_name,
+            trust_remote_code=True,
            speculative_config={
-                "method": "eagle",
-                "model": eagle_model_name,
+                "method": "eagle3" if use_eagle3 else "eagle",
+                "model": spec_model_name,
                "num_speculative_tokens": 3,
+                "max_model_len": 2048,
            },
-            max_model_len=1024,
+            max_model_len=2048,
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
@@ -140,7 +151,7 @@ def test_eagle_correctness(
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

-        # Heuristic: expect at least 70% of the prompts to match exactly
+        # Heuristic: expect at least 66% of the prompts to match exactly
        # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches > int(0.7 * len(ref_outputs))
+        assert matches > int(0.66 * len(ref_outputs))
        del spec_llm