Commit 2f8fe49

Add llama4 eagle to e2e test for spec decode

1 parent 4178543

File tree

2 files changed: +32 -21 lines

  tests/v1/e2e/test_spec_decode.py
  vllm/model_executor/models/llama4_eagle.py
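
For reference, the new Llama 4 case is selected by its parametrize id, llama4_eagle (see the test diff below). Assuming a checkout with this commit, access to the listed checkpoints (note that the Llama 4 Scout base model is referenced by a local path in the diff), and enough GPUs for tensor_parallel_size=4, it can be run with something like:

    pytest tests/v1/e2e/test_spec_decode.py::test_eagle_correctness -k llama4_eagle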

tests/v1/e2e/test_spec_decode.py

Lines changed: 29 additions & 17 deletions
@@ -6,8 +6,10 @@
 from typing import Any
 
 import pytest
+import torch
 
 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
 
 
 @pytest.fixture
@@ -53,14 +55,6 @@ def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
 
 
-def eagle_model_name():
-    return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
-
-
-def eagle3_model_name():
-    return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
-
-
 def test_ngram_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],
@@ -77,6 +71,8 @@ def test_ngram_correctness(
         ref_llm = LLM(model=model_name, max_model_len=1024)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
 
         spec_llm = LLM(
             model=model_name,
@@ -103,34 +99,48 @@ def test_ngram_correctness(
         # Upon failure, inspect the outputs to check for inaccuracy.
         assert matches > int(0.7 * len(ref_outputs))
         del spec_llm
-
-
-@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize("model_setup", [
+    ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+     "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
+    ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
+    ("eagle", "/home/zhiweiz/local/models/scout_base_HF_20250605_201140",
+     "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+],
+                         ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
-    model_name: str,
-    use_eagle3: bool,
+    model_setup: tuple[str, str, str, int],
 ):
     '''
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
+    model_setup: (method, model_name, eagle_model_name, tp_size)
     '''
     with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
+        method, model_name, spec_model_name, tp_size = model_setup
 
-        ref_llm = LLM(model=model_name, max_model_len=2048)
+        ref_llm = LLM(model=model_name,
+                      max_model_len=2048,
+                      tensor_parallel_size=tp_size)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
 
-        spec_model_name = eagle3_model_name(
-        ) if use_eagle3 else eagle_model_name()
         spec_llm = LLM(
             model=model_name,
             trust_remote_code=True,
+            tensor_parallel_size=tp_size,
             speculative_config={
-                "method": "eagle3" if use_eagle3 else "eagle",
+                "method": method,
                 "model": spec_model_name,
                 "num_speculative_tokens": 3,
                 "max_model_len": 2048,
@@ -152,3 +162,5 @@ def test_eagle_correctness(
         # Upon failure, inspect the outputs to check for inaccuracy.
         assert matches > int(0.66 * len(ref_outputs))
         del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
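
Two things to note in this diff: each test now frees accelerator state between the reference and speculative runs (del llm, then torch.cuda.empty_cache() and cleanup_dist_env_and_memory()), and test_eagle_correctness is driven by a (method, model_name, eagle_model_name, tp_size) tuple so the Llama 4 Scout case can request tensor_parallel_size=4. Below is a minimal standalone sketch of the same speculative setup, mirroring the llama3_eagle case; the prompt and generation settings are illustrative, not taken from the test:

    # Minimal sketch (assumes the two Llama 3.1 checkpoints are available
    # and one GPU is free); mirrors the llama3_eagle case above.
    from vllm import LLM, SamplingParams

    method, model_name, spec_model_name, tp_size = (
        "eagle",
        "meta-llama/Llama-3.1-8B-Instruct",
        "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        1,
    )

    llm = LLM(
        model=model_name,
        max_model_len=2048,
        tensor_parallel_size=tp_size,
        speculative_config={
            "method": method,             # "eagle" or "eagle3"
            "model": spec_model_name,     # EAGLE draft-head weights
            "num_speculative_tokens": 3,  # draft tokens proposed per step
            "max_model_len": 2048,
        },
    )
    # Illustrative prompt; the real test instead compares chat outputs
    # against a non-speculative reference and requires >66% of them to match.
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0, max_tokens=16))
    print(outputs[0].outputs[0].text)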

vllm/model_executor/models/llama4_eagle.py

Lines changed: 3 additions & 4 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
 from typing import Optional
@@ -177,10 +178,8 @@ def forward(
     ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.model(input_ids, positions, hidden_states)
 
-    def load_weights(
-        self,
-        weights: Iterable[tuple[str, torch.Tensor]]
-    ) -> None:
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> None:
         loader = AutoWeightsLoader(
             self,
             # lm_head is tied with target model (Llama4ForCausalLM)

0 commit comments