[Not for merge] Unshift eagle prefill #21008

Draft
wants to merge 9 commits into base: main
88 changes: 81 additions & 7 deletions examples/offline_inference/spec_decode.py
@@ -13,6 +13,38 @@
from argparse import ArgumentParser as FlexibleArgumentParser


QUESTION = "What is the content of each image?"
IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
"https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
"https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
"https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
"https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
"https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
"https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
"https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
]


def get_custom_mm_prompts(num_prompts):
prompts = []
for url in IMAGE_URLS:
prompts.append(
[
{"type": "image_url", "image_url": {"url": url}},
{"type": "text", "text": QUESTION},
]
)
if num_prompts > len(IMAGE_URLS):
prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)

return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]


def parse_args():
parser = FlexibleArgumentParser()
add_dataset_parser(parser)
@@ -35,6 +67,20 @@ def parse_args():
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument(
"--no-prefill-token-shift",
dest="prefill_token_shift",
action="store_false",
help="Disable prefill token shift (default: enabled)",
)
parser.add_argument("--target_kv_layer_copy_from", type=int, default=-1)
parser.add_argument(
"--draft_kv_layer_copy_to",
type=str,
default="",
help="comma separated list of layer indices to copy to",
)
return parser.parse_args()


@@ -46,12 +92,18 @@ def main():
if args.model_dir is None:
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_dir)

prompts = get_samples(args, tokenizer)
# add_special_tokens is False to avoid adding bos twice when using chat templates
prompt_ids = [
tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
]
args.custom_skip_chat_template = True

if not args.custom_mm_prompts:
prompts = get_samples(args, tokenizer)
# add_special_tokens is False to avoid adding bos twice
# when using chat templates
prompt_ids = [
tokenizer.encode(prompt.prompt, add_special_tokens=False)
for prompt in prompts
]
else:
prompts = get_custom_mm_prompts(args.num_prompts)

if args.method == "eagle" or args.method == "eagle3":
eagle_dir = args.eagle_dir
@@ -60,10 +112,24 @@

elif args.method == "eagle3" and eagle_dir is None:
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
target_kv_layer_copy_from = args.target_kv_layer_copy_from
draft_kv_layers_copy_to = (
[int(layer) for layer in args.draft_kv_layer_copy_to.split(",")]
if args.draft_kv_layer_copy_to
else None
)
kv_sharing_mapping = None
if args.target_kv_layer_copy_from >= 0 and draft_kv_layers_copy_to:
kv_sharing_mapping = {
f"{layer}": f"{target_kv_layer_copy_from}"
for layer in draft_kv_layers_copy_to
}
speculative_config = {
"method": args.method,
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"prefill_token_shift": args.prefill_token_shift,
"kv_sharing_mapping": kv_sharing_mapping,
}
elif args.method == "ngram":
speculative_config = {
@@ -84,10 +150,18 @@
gpu_memory_utilization=0.8,
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=16384,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
if not args.custom_mm_prompts:
outputs = llm.generate(
prompt_token_ids=prompt_ids, sampling_params=sampling_params
)
else:
outputs = llm.chat(prompts, sampling_params=sampling_params)

# print the generated text
if args.print_output:
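For clarity, a minimal sketch of the speculative_config these new flags assemble; the draft model and layer indices below are illustrative values, not taken from this PR.

# Illustrative only: with --no-prefill-token-shift,
# --target_kv_layer_copy_from 0 and --draft_kv_layer_copy_to 0,1
# (eagle method), the script builds roughly the following dict and
# passes it to LLM(speculative_config=...).
speculative_config = {
    "method": "eagle",
    "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
    "num_speculative_tokens": 3,  # mirrors args.num_spec_tokens
    "prefill_token_shift": False,  # --no-prefill-token-shift stores False
    # keys are draft layer indices, values are the base layer they copy from
    "kv_sharing_mapping": {"0": "0", "1": "0"},
}
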
7 changes: 6 additions & 1 deletion tests/models/registry.py
@@ -451,6 +451,11 @@ def check_available_online(
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
"EagleLlama4ForCausalLM": _HfExamplesInfo(
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
trust_remote_code=True,
speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
"EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
trust_remote_code=True,
is_available_online=False,
@@ -500,4 +505,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
raise ValueError(f"No example model defined for {model_id}")


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
5 changes: 5 additions & 0 deletions tests/models/test_initialization.py
@@ -26,6 +26,11 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
"KimiVLForConditionalGeneration"):
pytest.skip("Avoid OOM")

if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"):
from vllm.model_executor.models.llama4 import Llama4ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)

# Avoid OOM and reduce initialization time by only using 1 layer
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
hf_config.update(model_info.hf_overrides)
101 changes: 78 additions & 23 deletions tests/v1/e2e/test_spec_decode.py
@@ -6,24 +6,31 @@
from typing import Any

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory


@pytest.fixture
def test_prompts():
def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"]
num_prompts = 100
if mm_enabled:
prompt_types.append("mm")
num_prompts = 10
prompts = []

random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
print(f"Prompt types: {random_prompt_type_choices}")

# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
prompt: str | list[dict[str, Any]] = ""
if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
@@ -36,6 +43,21 @@ def test_prompts():
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
elif kind == "mm":
placeholders = [{
"type": "image_url",
"image_url": {
"url":
f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
},
}]
prompt = [
*placeholders,
{
"type": "text",
"text": "The meaning of the image is"
},
]
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append([{"role": "user", "content": prompt}])
@@ -53,14 +75,6 @@ def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"


def eagle_model_name():
return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"


def eagle3_model_name():
return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
@@ -77,6 +91,8 @@ def test_ngram_correctness(
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

spec_llm = LLM(
model=model_name,
@@ -103,39 +119,76 @@
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
del spec_llm


@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@pytest.mark.parametrize("model_setup,mm_enabled,prefill_shift", [
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False, True),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False, True),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), False, True,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True, True,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), False, False,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True, False,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=[
"llama3_eagle", "llama3_eagle3", "llama4_eagle",
"llama4_eagle_mm", "llama4_eagle_no_shift",
"llama4_eagle_mm_no_shift"
])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
use_eagle3: bool,
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
prefill_shift: bool,
):
'''
Compare the outputs of the original LLM and the speculative LLM.
They should be the same when using EAGLE speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size)
'''
# Generate test prompts inside the function instead of using a fixture.
test_prompts = get_test_prompts(mm_enabled)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
method, model_name, spec_model_name, tp_size = model_setup

ref_llm = LLM(model=model_name, max_model_len=2048)
max_model_len = 2048 if not mm_enabled else 4096
ref_llm = LLM(model=model_name,
max_model_len=max_model_len,
tensor_parallel_size=tp_size)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

spec_model_name = eagle3_model_name(
) if use_eagle3 else eagle_model_name()
spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
"max_model_len": max_model_len,
"prefill_token_shift": prefill_shift,
},
max_model_len=2048,
max_model_len=max_model_len,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
@@ -152,3 +205,5 @@ def test_eagle_correctness(
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
14 changes: 14 additions & 0 deletions vllm/config.py
@@ -2551,6 +2551,15 @@ class SpeculativeConfig:
ParallelConfig] = None # type: ignore
"""The parallel configuration for the draft model initialized internal."""

# Shift prefill tokens for the draft model; only used with EAGLE.
prefill_token_shift: bool = True
"""Whether to shift tokens during draft prefill."""

# Config for KV sharing: maps each draft layer to the base model layer
# it copies from. Key is the draft layer, value is the base layer.
kv_sharing_mapping: SkipValidation[dict[str, str]] = None # type: ignore
"""KV copy mapping for the prefill stage, from base to draft layers."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -2937,6 +2946,11 @@ def num_lookahead_slots(self) -> int:
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp")

def eagle_shift_prefill_token(self) -> bool:
if self.use_eagle():
return self.prefill_token_shift
return False

def __repr__(self) -> str:
method = self.method
model = None if method == "ngram" else self.draft_model_config.model
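As a standalone restatement of how the two new fields are meant to resolve (the helper name below is made up for illustration; the real logic lives in eagle_shift_prefill_token above):

# Illustrative sketch: prefill token shifting only applies to EAGLE-style methods.
def resolve_prefill_shift(method: str, prefill_token_shift: bool) -> bool:
    use_eagle = method in ("eagle", "eagle3", "deepseek_mtp")
    return prefill_token_shift if use_eagle else False

assert resolve_prefill_shift("eagle", False) is False  # disabled via the new flag
assert resolve_prefill_shift("eagle", True) is True    # default behavior
assert resolve_prefill_shift("ngram", True) is False   # not an EAGLE method
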
5 changes: 4 additions & 1 deletion vllm/envs.py
@@ -138,6 +138,7 @@
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
VLLM_DECODE_ONLY_ATTN: bool = False


def get_default_cache_root():
@@ -953,7 +954,9 @@ def get_vllm_port() -> Optional[int]:
# generations on machines < 100 for compressed-tensors
# models
"VLLM_USE_NVFP4_CT_EMULATIONS":
lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")))
lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))),
"VLLM_DECODE_ONLY_ATTN":
lambda: os.environ.get("VLLM_DECODE_ONLY_ATTN", "0") == "1"
}

# --8<-- [end:env-vars-definition]
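A minimal sketch of toggling the new environment flag from Python; vllm.envs evaluates variables lazily from os.environ, so exporting VLLM_DECODE_ONLY_ATTN=1 in the shell works as well.

import os

import vllm.envs as envs

# Illustrative only: enable the new flag (default is "0", i.e. disabled).
os.environ["VLLM_DECODE_ONLY_ATTN"] = "1"
print(envs.VLLM_DECODE_ONLY_ATTN)  # True once the variable is set to "1"
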
1 change: 1 addition & 0 deletions vllm/model_executor/models/llama4.py
@@ -255,6 +255,7 @@ def __init__(
super().__init__()

self.layer_idx = extract_layer_index(prefix)
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling