#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_eagle_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal
non-speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify that the scenarios below pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Greedy equality under preemption
* Greedy equality under various numbers of speculative tokens

With these tests, we can say that, at a minimum, EAGLE does not break the
correctness of the target model's outputs.
"""

import pytest

from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \
    run_equality_correctness_test

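
# For illustration only: a minimal sketch of the "greedy equality" check
# described in the module docstring. The real comparison (which also covers
# logprobs, as exercised below) is implemented by
# run_equality_correctness_test above; this helper is hypothetical and is
# not used by any test in this file.
def _greedy_equality_sketch(baseline_outputs, spec_decode_outputs):
    # Under greedy sampling (temperature=0), speculative decoding with
    # rejection sampling must reproduce the baseline tokens exactly,
    # sequence by sequence.
    assert len(baseline_outputs) == len(spec_decode_outputs)
    for baseline, spec in zip(baseline_outputs, spec_decode_outputs):
        assert baseline == spec, (
            f"speculative output diverged from baseline: "
            f"{spec!r} != {baseline!r}")

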
# main model
MAIN_MODEL = "JackFram/llama-68m"

# speculative model
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"

# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 4

# precision
# TODO: Upstream vLLM uses float32 here, but some ops in vllm-ascend
# (such as RoPE) do not yet support float32. Once that is fixed, this
# should be changed back to float32.
PRECISION = "float16"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                      per_test_common_llm_kwargs,
                                      baseline_llm_kwargs, test_llm_kwargs,
                                      batch_size: int, output_len: int,
                                      seed: int):
    """Verify greedy equality with different batch sizes."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs": False,
    },
}, {
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs": True,
    },
}])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                   per_test_common_llm_kwargs,
                                   baseline_llm_kwargs, test_llm_kwargs,
                                   batch_size: int, output_len: int, seed: int,
                                   logprobs: int):
    """Verify greedy equality of logprobs with logprobs enabled and
    disabled in the speculative config."""
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        logprobs=logprobs,
        prompt_logprobs=logprobs,
        disable_logprobs=test_llm_kwargs["speculative_config"]
        ["disable_logprobs"])


@pytest.mark.skipif(True, reason="Enable when graph mode is ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": False,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_cuda_graph(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality with cuda graph enabled and different
    batch sizes."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.skipif(True, reason="Enable when preemption is ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        128,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality, even when some sequences are preempted
    mid-generation.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_config": {
                "model": SPEC_MODEL,
                "num_speculative_tokens": k,
            },
        }
        # Try a range of num. speculative tokens
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int):
    """Verify that EAGLE speculative decoding produces exact equality with
    non-speculative decoding for different values of num_speculative_tokens.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_by_batch_size": 4,
    },
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
                             per_test_common_llm_kwargs, baseline_llm_kwargs,
                             test_llm_kwargs, batch_size: int, output_len: int,
                             seed: int):
    """Verify that EAGLE speculative decoding produces exact equality with
    non-speculative decoding when speculation is disabled for large batch
    sizes.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


if __name__ == "__main__":
    pytest.main([__file__])