[Meta] Official Eagle mm support, first enablement on llama4 #20788

Open · wants to merge 7 commits into base: main
61 changes: 54 additions & 7 deletions examples/offline_inference/spec_decode.py
@@ -13,6 +13,38 @@
from argparse import ArgumentParser as FlexibleArgumentParser


QUESTION = "What is the content of each image?"
IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
"https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
"https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
"https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
"https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
"https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
"https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
"https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
]


def get_custom_mm_prompts(num_prompts):
prompts = []
for url in IMAGE_URLS:
prompts.append(
[
{"type": "image_url", "image_url": {"url": url}},
{"type": "text", "text": QUESTION},
]
)
if num_prompts > len(IMAGE_URLS):
prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)

return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]


def parse_args():
parser = FlexibleArgumentParser()
add_dataset_parser(parser)
@@ -35,6 +67,7 @@ def parse_args():
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
return parser.parse_args()


@@ -46,12 +79,18 @@ def main():
if args.model_dir is None:
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_dir)

prompts = get_samples(args, tokenizer)
# add_special_tokens is False to avoid adding bos twice when using chat templates
prompt_ids = [
tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
]
args.custom_skip_chat_template = True

if not args.custom_mm_prompts:
prompts = get_samples(args, tokenizer)
# add_special_tokens is False to avoid adding bos twice
# when using chat templates
prompt_ids = [
tokenizer.encode(prompt.prompt, add_special_tokens=False)
for prompt in prompts
]
else:
prompts = get_custom_mm_prompts(args.num_prompts)

if args.method == "eagle" or args.method == "eagle3":
eagle_dir = args.eagle_dir
@@ -84,10 +123,18 @@ def main():
gpu_memory_utilization=0.8,
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=16384,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
if not args.custom_mm_prompts:
outputs = llm.generate(
prompt_token_ids=prompt_ids, sampling_params=sampling_params
)
else:
outputs = llm.chat(prompts, sampling_params=sampling_params)

# print the generated text
if args.print_output:
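For reference, a minimal offline sketch of the multimodal EAGLE path this example enables. The target and draft checkpoint names come from the registry and test changes in this PR; the image URL is one of the entries in IMAGE_URLS above, and the remaining values (sampling settings, GPU count) are illustrative assumptions rather than a vetted configuration:

```python
from vllm import LLM, SamplingParams

# One chat conversation in the same OpenAI-style format produced by
# get_custom_mm_prompts(): an image_url part followed by a text question.
messages = [[{
    "role": "user",
    "content": [
        {"type": "image_url",
         "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg"}},
        {"type": "text", "text": "What is the content of each image?"},
    ],
}]]

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    tensor_parallel_size=4,  # matches the tp_size used for llama4 in the e2e test
    speculative_config={
        "method": "eagle",
        "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
        "num_speculative_tokens": 3,
    },
    max_model_len=16384,
    limit_mm_per_prompt={"image": 5},
    disable_chunked_mm_input=True,
)

sampling_params = SamplingParams(temperature=0, max_tokens=256)
outputs = llm.chat(messages, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
```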
7 changes: 6 additions & 1 deletion tests/models/registry.py
@@ -451,6 +451,11 @@ def check_available_online(
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
"EagleLlama4ForCausalLM": _HfExamplesInfo(
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
trust_remote_code=True,
speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
"EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
trust_remote_code=True,
is_available_online=False,
@@ -500,4 +505,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
raise ValueError(f"No example model defined for {model_id}")


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
5 changes: 5 additions & 0 deletions tests/models/test_initialization.py
@@ -26,6 +26,11 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
"KimiVLForConditionalGeneration"):
pytest.skip("Avoid OOM")

if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"):
from vllm.model_executor.models.llama4 import Llama4ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)

# Avoid OOM and reduce initialization time by only using 1 layer
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
hf_config.update(model_info.hf_overrides)
83 changes: 62 additions & 21 deletions tests/v1/e2e/test_spec_decode.py
@@ -6,24 +6,31 @@
from typing import Any

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory


@pytest.fixture
def test_prompts():
def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"]
num_prompts = 100
if mm_enabled:
prompt_types.append("mm")
num_prompts = 10
prompts = []

random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
print(f"Prompt types: {random_prompt_type_choices}")

# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
prompt: str | list[dict[str, Any]] = ""
if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
@@ -36,6 +43,21 @@ def test_prompts():
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
elif kind == "mm":
placeholders = [{
"type": "image_url",
"image_url": {
"url":
f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
},
}]
prompt = [
*placeholders,
{
"type": "text",
"text": "The meaning of the image is"
},
]
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append([{"role": "user", "content": prompt}])
@@ -53,14 +75,6 @@ def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"


def eagle_model_name():
return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"


def eagle3_model_name():
return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
@@ -77,6 +91,8 @@ def test_ngram_correctness(
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

spec_llm = LLM(
model=model_name,
@@ -103,34 +119,57 @@ def test_ngram_correctness(
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
del spec_llm


@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@pytest.mark.parametrize(
"model_setup,mm_enabled", [
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
        pytest.param(
            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
            False,
            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
        pytest.param(
            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
            True,
            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
use_eagle3: bool,
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
):
    # Generate test prompts inside the function instead of using a fixture
test_prompts = get_test_prompts(mm_enabled)
'''
    Compare the outputs of the original LLM and the speculative LLM;
    they should be the same when using EAGLE speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size)
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
method, model_name, spec_model_name, tp_size = model_setup

ref_llm = LLM(model=model_name, max_model_len=2048)
ref_llm = LLM(model=model_name,
max_model_len=2048,
tensor_parallel_size=tp_size)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

spec_model_name = eagle3_model_name(
) if use_eagle3 else eagle_model_name()
spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
@@ -152,3 +191,5 @@ def test_eagle_correctness(
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
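The new cases can be selected by their test IDs when running locally; one possible invocation that exercises only the Llama-3.1 cases (the llama4 entries stay skipped due to CI OOM and need a 4-GPU host when enabled):

```python
import pytest

# Run only the Llama-3.1 EAGLE/EAGLE3 parametrizations of the e2e test;
# the id substring "llama3" matches llama3_eagle and llama3_eagle3 but not llama4_*.
pytest.main([
    "tests/v1/e2e/test_spec_decode.py::test_eagle_correctness",
    "-k", "llama3",
    "-v",
])
```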
1 change: 1 addition & 0 deletions vllm/model_executor/models/llama4.py
@@ -255,6 +255,7 @@ def __init__(
super().__init__()

self.layer_idx = extract_layer_index(prefix)
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
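The new `global_layer` flag keys off `config.no_rope_layers`; a small illustration of the assumed convention (a 0 at index i marks a NoPE layer that attends globally without RoPE), using a made-up pattern rather than a real checkpoint config:

```python
# Assumed semantics of no_rope_layers: 0 => NoPE / global-attention layer,
# non-zero => local RoPE layer. The pattern below is illustrative only.
no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]

global_layers = [i for i, flag in enumerate(no_rope_layers) if flag == 0]
print(global_layers)  # [3, 7] -> these layers would set self.global_layer = True
```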