
Commit 8df05d0

eagle mm support, primarily llama4
Signed-off-by: morgendave <morgendave@gmail.com>
1 parent 4238b3a commit 8df05d0

8 files changed, +187 -34 lines changed


examples/offline_inference/spec_decode.py

Lines changed: 53 additions & 7 deletions
@@ -13,6 +13,38 @@
 from argparse import ArgumentParser as FlexibleArgumentParser
 
 
+QUESTION = "What is the content of each image?"
+IMAGE_URLS = [
+    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
+]
+
+
+def get_custom_mm_prompts(num_prompts):
+    prompts = []
+    for url in IMAGE_URLS:
+        prompts.append(
+            [
+                {"type": "image_url", "image_url": {"url": url}},
+                {"type": "text", "text": QUESTION},
+            ]
+        )
+    if num_prompts > len(IMAGE_URLS):
+        prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)
+
+    return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]
+
+
 def parse_args():
     parser = FlexibleArgumentParser()
     add_dataset_parser(parser)
@@ -35,6 +67,7 @@ def parse_args():
     parser.add_argument("--output-len", type=int, default=256)
     parser.add_argument("--model-dir", type=str, default=None)
    parser.add_argument("--eagle-dir", type=str, default=None)
+    parser.add_argument("--custom-mm-prompts", action="store_true")
     return parser.parse_args()
 
 
@@ -46,12 +79,18 @@ def main():
     if args.model_dir is None:
         model_dir = "meta-llama/Llama-3.1-8B-Instruct"
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
-
-    prompts = get_samples(args, tokenizer)
-    # add_special_tokens is False to avoid adding bos twice when using chat templates
-    prompt_ids = [
-        tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
-    ]
+    args.custom_skip_chat_template = True
+
+    if not args.custom_mm_prompts:
+        prompts = get_samples(args, tokenizer)
+        # add_special_tokens is False to avoid adding bos twice
+        # when using chat templates
+        prompt_ids = [
+            tokenizer.encode(prompt.prompt, add_special_tokens=False)
+            for prompt in prompts
+        ]
+    else:
+        prompts = get_custom_mm_prompts(args.num_prompts)
 
     if args.method == "eagle" or args.method == "eagle3":
         eagle_dir = args.eagle_dir
@@ -85,10 +124,17 @@ def main():
         speculative_config=speculative_config,
         disable_log_stats=False,
         max_model_len=16384,
+        limit_mm_per_prompt={"image": 5},
+        disable_chunked_mm_input=True,
     )
 
     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
-    outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
+    if not args.custom_mm_prompts:
+        outputs = llm.generate(
+            prompt_token_ids=prompt_ids, sampling_params=sampling_params
+        )
+    else:
+        outputs = llm.chat(prompts, sampling_params=sampling_params)
 
     # print the generated text
     if args.print_output:
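
Usage note (not part of the diff): the new --custom-mm-prompts flag switches the example from llm.generate over pre-tokenized prompts to llm.chat over image_url chat messages. A minimal sketch of that path, assuming the Llama-4 Scout target and EAGLE draft checkpoints referenced elsewhere in this commit and the usual speculative_config keys; the exact arguments used by the script may differ:

from vllm import LLM, SamplingParams

# Target model plus its EAGLE draft head; images per prompt are capped and
# chunked multimodal input is disabled, mirroring the example above.
llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
        "num_speculative_tokens": 3,
    },
    max_model_len=16384,
    limit_mm_per_prompt={"image": 5},
    disable_chunked_mm_input=True,
)

# get_custom_mm_prompts (added above) builds chat messages with image_url parts.
prompts = get_custom_mm_prompts(4)
outputs = llm.chat(prompts, SamplingParams(temperature=0, max_tokens=256))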

tests/v1/e2e/test_spec_decode.py

Lines changed: 38 additions & 13 deletions
@@ -9,23 +9,28 @@
 import torch
 
 from vllm import LLM, SamplingParams
+from vllm.assets.base import VLLM_S3_BUCKET_URL
+from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory
 
 
-@pytest.fixture
-def test_prompts():
+def get_test_prompts(mm_enabled: bool):
     prompt_types = ["repeat", "sentence"]
-    num_prompts = 100
+    if mm_enabled:
+        prompt_types.append("mm")
+    num_prompts = 10
     prompts = []
 
     random.seed(0)
     random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+    print(f"Prompt types: {random_prompt_type_choices}")
 
     # Generate a mixed batch of prompts, some of which can be easily
     # predicted by n-gram matching and some which likely cannot.
     for kind in random_prompt_type_choices:
         word_choices = ["test", "temp", "hello", "where"]
         word = random.choice(word_choices)
+        prompt: str | list[dict[str, Any]] = ""
         if kind == "repeat":
             prompt = f"""
             please repeat the word '{word}' 10 times.
@@ -38,6 +43,21 @@ def test_prompts():
             uses the word {word} at least once.
             give no other output than that simple sentence without quotes.
             """
+        elif kind == "mm":
+            placeholders = [{
+                "type": "image_url",
+                "image_url": {
+                    "url":
+                    f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
+                },
+            }]
+            prompt = [
+                *placeholders,
+                {
+                    "type": "text",
+                    "text": "The meaning of the image is"
+                },
+            ]
         else:
             raise ValueError(f"Unknown prompt type: {kind}")
         prompts.append([{"role": "user", "content": prompt}])
@@ -103,21 +123,26 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
 
 
-@pytest.mark.parametrize("model_setup", [
-    ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-     "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
-    ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
-    ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-     "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-],
-                         ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
+@pytest.mark.parametrize(
+    "model_setup,mm_enabled", [
+        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle", "/home/zhiweiz/local/models/scout_base_HF_20250605_201140",
+          "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), False),
+        (("eagle", "/home/zhiweiz/local/models/scout_base_HF_20250605_201140",
+          "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True),
+    ],
+    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
-    test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
+    mm_enabled: bool,
 ):
+    # Generate test prompts inside the function instead of using fixture
+    test_prompts = get_test_prompts(mm_enabled)
     '''
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
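
Usage note (not part of the diff): with the parametrization above, the new multimodal case can be run on its own by selecting its test id, assuming the repository's standard pytest invocation:

pytest tests/v1/e2e/test_spec_decode.py -k llama4_eagle_mm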

vllm/model_executor/models/llama4.py

Lines changed: 1 addition & 0 deletions
@@ -255,6 +255,7 @@ def __init__(
         super().__init__()
 
         self.layer_idx = extract_layer_index(prefix)
+        self.global_layer = config.no_rope_layers[self.layer_idx] == 0
         self.hidden_size = config.hidden_size
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling

vllm/model_executor/models/llama4_eagle.py

Lines changed: 31 additions & 4 deletions
@@ -37,8 +37,9 @@
 from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
                                                Llama4ForCausalLM)
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.multimodal.inputs import NestedTensors
 
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
 
 logger = init_logger(__name__)
 
@@ -78,15 +79,23 @@ def __init__(
         self.norm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
 
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        input_embeds = self.embed_tokens(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings(input_ids)
         hidden_states = self.fc(
-            torch.cat((input_embeds, hidden_states), dim=-1))
+            torch.cat((inputs_embeds, hidden_states), dim=-1))
         residual = None
         for layer in self.layers:
             hidden_states, residual = layer(
@@ -190,8 +199,9 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.model(input_ids, positions, hidden_states)
+        return self.model(input_ids, positions, hidden_states, inputs_embeds)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> None:
@@ -212,3 +222,20 @@ def load_weights(self, weights: Iterable[tuple[str,
             model_weights[name] = loaded_weight
 
         loader.load_weights(model_weights.items())
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.model.get_input_embeddings(input_ids)
+
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                self.config.image_token_index,
+            )
+
+        return inputs_embeds
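
For context on the new get_input_embeddings hook: the draft model embeds the draft token ids as usual, and the rows whose token id equals config.image_token_index are then overwritten with the vision embeddings supplied by the caller, so the EAGLE draft sees image features at the placeholder positions just as the target model does. The following is a self-contained toy illustration of that merge, not the vLLM implementation of merge_multimodal_embeddings, and the image token id used here is made up:

import torch


def merge_mm_embeddings_toy(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            mm_embeds: torch.Tensor,
                            image_token_index: int) -> torch.Tensor:
    # Replace the embedding rows at image-placeholder positions, in order.
    merged = inputs_embeds.clone()
    merged[input_ids == image_token_index] = mm_embeds.to(merged.dtype)
    return merged


IMAGE_TOKEN = 200090  # hypothetical placeholder id, not taken from the config
input_ids = torch.tensor([1, IMAGE_TOKEN, IMAGE_TOKEN, 7])
inputs_embeds = torch.zeros(4, 8)  # stand-in for embed_tokens(input_ids)
mm_embeds = torch.ones(2, 8)       # stand-in for vision-encoder outputs
merged = merge_mm_embeddings_toy(input_ids, inputs_embeds, mm_embeds, IMAGE_TOKEN)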

vllm/model_executor/models/llama_eagle.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -148,6 +149,7 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.model(input_ids, positions, hidden_states)

vllm/model_executor/models/llama_eagle3.py

Lines changed: 1 addition & 0 deletions
@@ -202,6 +202,7 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.model(input_ids, positions, hidden_states)
