[Meta] Official Eagle mm support, first enablement on llama4 #20788

Status: Open. Wants to merge 1 commit into base: main.
60 changes: 53 additions & 7 deletions examples/offline_inference/spec_decode.py
@@ -13,6 +13,38 @@
from argparse import ArgumentParser as FlexibleArgumentParser


QUESTION = "What is the content of each image?"
IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
"https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
"https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
"https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
"https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
"https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
"https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
"https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
]


def get_custom_mm_prompts(num_prompts):
prompts = []
for url in IMAGE_URLS:
prompts.append(
[
{"type": "image_url", "image_url": {"url": url}},
{"type": "text", "text": QUESTION},
]
)
if num_prompts > len(IMAGE_URLS):
prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)

return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]


def parse_args():
parser = FlexibleArgumentParser()
add_dataset_parser(parser)
@@ -35,6 +67,7 @@ def parse_args():
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
return parser.parse_args()


@@ -46,12 +79,18 @@ def main():
if args.model_dir is None:
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
[Review comment from a Member] The model should be Llama4 for multimodal mode.

tokenizer = AutoTokenizer.from_pretrained(model_dir)

prompts = get_samples(args, tokenizer)
# add_special_tokens is False to avoid adding bos twice when using chat templates
prompt_ids = [
tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
]
args.custom_skip_chat_template = True

if not args.custom_mm_prompts:
prompts = get_samples(args, tokenizer)
# add_special_tokens is False to avoid adding bos twice
# when using chat templates
prompt_ids = [
tokenizer.encode(prompt.prompt, add_special_tokens=False)
for prompt in prompts
]
else:
prompts = get_custom_mm_prompts(args.num_prompts)

if args.method == "eagle" or args.method == "eagle3":
eagle_dir = args.eagle_dir
@@ -85,10 +124,17 @@ def main():
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=16384,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
if not args.custom_mm_prompts:
outputs = llm.generate(
prompt_token_ids=prompt_ids, sampling_params=sampling_params
)
else:
outputs = llm.chat(prompts, sampling_params=sampling_params)

# print the generated text
if args.print_output:
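As context for the new --custom-mm-prompts branch above, here is a minimal, self-contained sketch of the same flow outside the example script. It assumes the speculative_config dictionary layout that spec_decode.py builds ("method", "model", "num_speculative_tokens") and reuses one of the image URLs and engine arguments from this diff; the draft checkpoint and token budget are illustrative, not prescriptive.

# Hedged sketch: drive vLLM's LLM.chat with an OpenAI-style multimodal message
# and an EAGLE draft model, mirroring the --custom-mm-prompts branch above.
# The speculative_config keys and the draft checkpoint are assumptions taken
# from this PR's example script and tests, not a separately documented contract.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    tensor_parallel_size=4,
    speculative_config={
        "method": "eagle",
        "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
        "num_speculative_tokens": 3,
    },
    max_model_len=16384,
    limit_mm_per_prompt={"image": 5},
    disable_chunked_mm_input=True,
)

messages = [[{
    "role": "user",
    "content": [
        {"type": "image_url",
         "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg"}},
        {"type": "text", "text": "What is the content of each image?"},
    ],
}]]

outputs = llm.chat(messages, SamplingParams(temperature=0, max_tokens=256))
print(outputs[0].outputs[0].text)

In the example script itself, the equivalent path is reached by passing --custom-mm-prompts together with --method eagle (and, optionally, --eagle-dir pointing at the draft checkpoint).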
57 changes: 42 additions & 15 deletions tests/v1/e2e/test_spec_decode.py
@@ -9,23 +9,28 @@
import torch

from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory


@pytest.fixture
def test_prompts():
def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"]
num_prompts = 100
if mm_enabled:
prompt_types.append("mm")
num_prompts = 10
prompts = []

random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
print(f"Prompt types: {random_prompt_type_choices}")

# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
prompt: str | list[dict[str, Any]] = ""
[Review comment from a Member] Suggested change, to not break Python 3.9 users:
- prompt: str | list[dict[str, Any]] = ""
+ prompt: Union[str, list[dict[str, Any]]] = ""

if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
@@ -38,6 +43,21 @@ def test_prompts():
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
elif kind == "mm":
placeholders = [{
"type": "image_url",
"image_url": {
"url":
f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
},
}]
prompt = [
*placeholders,
{
"type": "text",
"text": "The meaning of the image is"
},
]
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append([{"role": "user", "content": prompt}])
@@ -103,23 +123,30 @@ def test_ngram_correctness(
cleanup_dist_env_and_memory()


@pytest.mark.parametrize("model_setup", [
("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
@pytest.mark.parametrize(
"model_setup,mm_enabled", [
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
pytest.param(
(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), False),
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
pytest.param(
(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True),
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
):
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
'''
Compare the outputs of an original LLM and a speculative LLM;
they should be the same when using eagle speculative decoding.
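To run just the new multimodal case locally, one option is the snippet below (a sketch: pytest's -k matches substrings of the generated test ids, and the llama4 cases are currently marked as skipped for CI OOM, so the skip mark would need to be lifted and enough GPUs made available).

# Equivalent to: pytest tests/v1/e2e/test_spec_decode.py -k llama4_eagle_mm
import pytest

pytest.main(["tests/v1/e2e/test_spec_decode.py", "-k", "llama4_eagle_mm"])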
1 change: 1 addition & 0 deletions vllm/model_executor/models/llama4.py
@@ -256,6 +256,7 @@ def __init__(
super().__init__()

self.layer_idx = extract_layer_index(prefix)
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
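A quick illustration of what the new flag captures (a sketch that assumes the Llama 4 convention where a 0 entry in config.no_rope_layers marks a NoPE layer, i.e. one without rotary embeddings that attends globally, while nonzero entries use RoPE with local/chunked attention):

# Hypothetical config values for illustration only; real checkpoints define
# no_rope_layers in their HF config. Per the change above, entries equal to 0
# are flagged as global-attention layers.
no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]  # assumed pattern, not a real checkpoint

global_layers = [idx for idx, flag in enumerate(no_rope_layers) if flag == 0]
print(global_layers)  # -> [3, 7]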
35 changes: 31 additions & 4 deletions vllm/model_executor/models/llama4_eagle.py
@@ -37,8 +37,9 @@
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
Llama4ForCausalLM)
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.inputs import NestedTensors

from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings

logger = init_logger(__name__)

@@ -78,15 +79,23 @@ def __init__(
self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
input_embeds = self.embed_tokens(input_ids)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
torch.cat((inputs_embeds, hidden_states), dim=-1))
residual = None
for layer in self.layers:
hidden_states, residual = layer(
@@ -190,8 +199,9 @@ def forward(
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states)
return self.model(input_ids, positions, hidden_states, inputs_embeds)

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> None:
@@ -212,3 +222,20 @@ def load_weights(self, weights: Iterable[tuple[str,
model_weights[name] = loaded_weight

loader.load_weights(model_weights.items())

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)

if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)

return inputs_embeds
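
For intuition, a standalone sketch of what the merge step above does: sequence positions whose token id equals the image placeholder id are overwritten with the projected image embeddings. The shapes and placeholder id below are toy values; the real helper is merge_multimodal_embeddings from vllm/model_executor/models/utils.py and the real id comes from config.image_token_index.

import torch

IMAGE_TOKEN_ID = 99  # made-up placeholder id, for illustration only

input_ids = torch.tensor([1, 15, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 42])
inputs_embeds = torch.zeros(5, 8)   # toy text embeddings, hidden size 8
image_embeds = torch.ones(2, 8)     # one row per image placeholder position

mask = input_ids == IMAGE_TOKEN_ID
inputs_embeds[mask] = image_embeds  # scatter image features into the sequence
assert inputs_embeds[2:4].eq(1).all()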
6 changes: 6 additions & 0 deletions vllm/model_executor/models/llama_eagle.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn
@@ -148,7 +149,12 @@ def forward(
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if inputs_embeds is not None:
raise NotImplementedError(
f"{type(self).__name__} does not support multimodal inputs yet."
)
return self.model(input_ids, positions, hidden_states)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
5 changes: 5 additions & 0 deletions vllm/model_executor/models/llama_eagle3.py
@@ -202,7 +202,12 @@ def forward(
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if inputs_embeds is not None:
raise NotImplementedError(
f"{type(self).__name__} does not support multimodal inputs yet."
)
return self.model(input_ids, positions, hidden_states)

def compute_logits(