@@ -6,8 +6,10 @@
 from typing import Any
 
 import pytest
+import torch
 
 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
 
 
 @pytest.fixture
@@ -53,14 +55,6 @@ def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
 
 
-def eagle_model_name():
-    return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
-
-
-def eagle3_model_name():
-    return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
-
-
 def test_ngram_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],
@@ -77,6 +71,8 @@ def test_ngram_correctness(
         ref_llm = LLM(model=model_name, max_model_len=1024)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
 
         spec_llm = LLM(
             model=model_name,
@@ -103,34 +99,48 @@ def test_ngram_correctness(
         # Upon failure, inspect the outputs to check for inaccuracy.
         assert matches > int(0.7 * len(ref_outputs))
         del spec_llm
-
-
-@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize("model_setup", [
+    ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+     "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
+    ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
+    ("eagle", "/home/zhiweiz/local/models/scout_base_HF_20250605_201140",
+     "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+],
+                         ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
-    model_name: str,
-    use_eagle3: bool,
+    model_setup: tuple[str, str, str, int],
 ):
     '''
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
+    model_setup: (method, model_name, eagle_model_name, tp_size)
     '''
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        method, model_name, spec_model_name, tp_size = model_setup
 
-        ref_llm = LLM(model=model_name, max_model_len=2048)
+        ref_llm = LLM(model=model_name,
+                      max_model_len=2048,
+                      tensor_parallel_size=tp_size)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
 
-        spec_model_name = eagle3_model_name(
-        ) if use_eagle3 else eagle_model_name()
         spec_llm = LLM(
             model=model_name,
             trust_remote_code=True,
+            tensor_parallel_size=tp_size,
             speculative_config={
-                "method": "eagle3" if use_eagle3 else "eagle",
+                "method": method,
                 "model": spec_model_name,
                 "num_speculative_tokens": 3,
                 "max_model_len": 2048,
@@ -152,3 +162,5 @@ def test_eagle_correctness(
         # Upon failure, inspect the outputs to check for inaccuracy.
         assert matches > int(0.66 * len(ref_outputs))
         del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
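Read on its own, the cleanup pattern this diff inserts between the reference and speculative runs looks roughly like the sketch below. This is a minimal illustration rather than the test itself; it assumes a CUDA-capable vLLM install, reuses the model name and the cleanup calls visible in the diff, and uses a placeholder one-message prompt.

    import torch
    from vllm import LLM, SamplingParams
    from vllm.distributed import cleanup_dist_env_and_memory

    # Build the reference engine and generate once.
    ref_llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", max_model_len=2048)
    ref_outputs = ref_llm.chat([[{"role": "user", "content": "Hello"}]],
                               SamplingParams(temperature=0, max_tokens=16))

    # Free the engine, release cached CUDA allocations, and tear down vLLM's
    # distributed environment before constructing a second LLM in the same
    # process (e.g. the speculative-decoding engine in the test above).
    del ref_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()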