[Meta] Llama4 EAGLE Support #20591


Merged · 6 commits · Jul 16, 2025
1 change: 1 addition & 0 deletions examples/offline_inference/spec_decode.py
@@ -84,6 +84,7 @@ def main():
gpu_memory_utilization=0.8,
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=16384,
Collaborator:

Is this needed due to an OOM issue?

Collaborator Author (@morgendave, Jul 9, 2025):

I think this might not be needed; it is just from our internal test, though I think the default is smaller than this?
Yes, this is for OOM: the original length with BF16 would be too big.
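For intuition, here is a rough, back-of-the-envelope KV-cache estimate (not from this PR; the layer/head counts are assumptions for a Llama-3.1-8B-class target in BF16) showing why capping max_model_len matters under gpu_memory_utilization=0.8:

# Illustrative KV-cache sizing sketch; 32 layers, 8 KV heads, head_dim 128,
# and 2 bytes/element (BF16) are assumed values, not taken from the PR.
def kv_cache_bytes(seq_len: int, layers: int = 32, kv_heads: int = 8,
                   head_dim: int = 128, dtype_bytes: int = 2) -> int:
    # factor of 2 covers keys and values
    return 2 * layers * kv_heads * head_dim * dtype_bytes * seq_len

for seq_len in (16_384, 131_072):  # the capped length vs. a full 128k context
    gib = kv_cache_bytes(seq_len) / 2**30
    print(f"{seq_len:>7} tokens -> {gib:.1f} GiB of KV cache per sequence")

At these assumed shapes the cap keeps a single sequence's KV cache around 2 GiB instead of roughly 16 GiB, which is consistent with the OOM concern above.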

)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
7 changes: 6 additions & 1 deletion tests/models/registry.py
@@ -451,6 +451,11 @@ def check_available_online(
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
"EagleLlama4ForCausalLM": _HfExamplesInfo(
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
trust_remote_code=True,
speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
"EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
trust_remote_code=True,
is_available_online=False,
@@ -500,4 +505,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
raise ValueError(f"No example model defined for {model_id}")


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
5 changes: 5 additions & 0 deletions tests/models/test_initialization.py
@@ -26,6 +26,11 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
"KimiVLForConditionalGeneration"):
pytest.skip("Avoid OOM")

if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"):
from vllm.model_executor.models.llama4 import Llama4ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)

# Avoid OOM and reduce initialization time by only using 1 layer
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
hf_config.update(model_info.hf_overrides)
48 changes: 31 additions & 17 deletions tests/v1/e2e/test_spec_decode.py
@@ -6,8 +6,10 @@
from typing import Any

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory


@pytest.fixture
@@ -53,14 +55,6 @@ def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"


def eagle_model_name():
return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"


def eagle3_model_name():
return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
@@ -77,6 +71,8 @@ def test_ngram_correctness(
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

spec_llm = LLM(
model=model_name,
@@ -103,34 +99,50 @@
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
del spec_llm


@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@pytest.mark.parametrize("model_setup", [
("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
use_eagle3: bool,
model_setup: tuple[str, str, str, int],
):
'''
Compare the outputs of the original LLM and the speculative LLM;
they should be the same when using EAGLE speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size)
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
method, model_name, spec_model_name, tp_size = model_setup

ref_llm = LLM(model=model_name, max_model_len=2048)
ref_llm = LLM(model=model_name,
max_model_len=2048,
tensor_parallel_size=tp_size)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

spec_model_name = eagle3_model_name(
) if use_eagle3 else eagle_model_name()
spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
@@ -152,3 +164,5 @@ def test_eagle_correctness(
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
214 changes: 214 additions & 0 deletions vllm/model_executor/models/llama4_eagle.py
@@ -0,0 +1,214 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
Collaborator:

Can you also add the copyright from the Meta side?

Collaborator Author:

Absolutely, thanks for the suggestion

Collaborator:

# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.torchao import TorchAOConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
Llama4ForCausalLM)
from vllm.model_executor.models.utils import extract_layer_index

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


@support_torch_compile
class LlamaModel(nn.Module):

def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
start_layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = (
vllm_config.speculative_config.draft_model_config.hf_config)
self.validate_and_update_config(start_layer_id, quant_config)
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)

self.layers = nn.ModuleList([
Llama4DecoderLayer(
self.config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
) for i in range(self.config.num_hidden_layers)
])
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
self.config.hidden_size,
bias=False)
self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)

def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
Comment on lines +81 to +86
Collaborator:

Will MM input support be added later?

Collaborator Author:

Yep that's in #20591

input_embeds = self.embed_tokens(input_ids)
hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
residual = None
for layer in self.layers:
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states, hidden_states

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
name = name.removeprefix("model.")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# if PP disabled then draft will share embed with target
if get_pp_group().world_size == 1 and \
"embed_tokens." in name:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
for name in params_dict:
# if PP disabled then draft will share embed with target
if get_pp_group().world_size == 1 and \
"embed_tokens." in name:
continue
assert name in loaded_params, f"{name} is not loaded!"
return loaded_params

def validate_and_update_config(
self,
start_layer_id: int,
quant_config: Optional[QuantizationConfig] = None) -> None:
# yoco and moe are not supported by the draft model yet
assert self.config.yoco_global_kv_layer is None
assert self.config.yoco_local_kv_layer is None
assert len(self.config.moe_layers) == 0
# draft model layer index is increased by start_layer_id,
# so we need to pad relevant configs accordingly
self.config.no_rope_layers = [
0
] * start_layer_id + self.config.no_rope_layers
# currently only TorchAO quantization is supported
if isinstance(quant_config, TorchAOConfig):

def pad_layer_name(layer: str) -> str:
layer_index = extract_layer_index(layer)
return layer.replace(str(layer_index),
str(layer_index + start_layer_id))

quant_config.torchao_config.module_fqn_to_config = {
pad_layer_name(layer): quantization
for layer, quantization in
quant_config.torchao_config.module_fqn_to_config.items()
}


class EagleLlama4ForCausalLM(Llama4ForCausalLM):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
Contributor:

Severity: medium

The __init__ method of EagleLlama4ForCausalLM calls nn.Module.__init__(self) directly, bypassing the initializer of its base class, Llama4ForCausalLM. This is a fragile design because if Llama4ForCausalLM's __init__ sets up important state that is used by inherited methods (like permute_qk_weight_for_rotary), this implementation could break.

To improve maintainability, consider either using composition over inheritance or ensuring a proper call to super().__init__(...).
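
For illustration, here is a minimal, framework-agnostic sketch of the two options; Target, DraftA, DraftB, and _rope_state are hypothetical names, not vLLM APIs:

import torch.nn as nn


class Target(nn.Module):
    def __init__(self):
        super().__init__()
        # state that inherited helpers rely on; missing if __init__ is bypassed
        self._rope_state = {"permuted": True}

    def permute_qk_weight_for_rotary(self, name, weight):
        assert self._rope_state["permuted"], "base __init__ was bypassed"
        return name, weight


# Option A: keep inheritance, but run the base initializer so shared state exists.
class DraftA(Target):
    def __init__(self):
        super().__init__()  # instead of nn.Module.__init__(self)


# Option B: composition -- hold a helper instance rather than subclassing it.
class DraftB(nn.Module):
    def __init__(self):
        super().__init__()
        self._helper = Target()

    def permute_qk_weight_for_rotary(self, name, weight):
        return self._helper.permute_qk_weight_for_rotary(name, weight)

Either route keeps permute_qk_weight_for_rotary's assumptions satisfied without depending on which attributes nn.Module.__init__ happens to set.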

self.config = (
vllm_config.speculative_config.draft_model_config.hf_config)
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)
# draft model quantization config may differ from target model
quant_config = VllmConfig.get_quantization_config(
vllm_config.speculative_config.draft_model_config,
vllm_config.load_config)
self.model = LlamaModel(vllm_config=vllm_config,
prefix="model",
start_layer_id=target_layer_num,
quant_config=quant_config)
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states)

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> None:
loader = AutoWeightsLoader(
self,
# lm_head is tied with target model (Llama4ForCausalLM)
skip_prefixes=(["lm_head."]),
)

model_weights = {}
weights = [
self.permute_qk_weight_for_rotary(name, loaded_weight)
for name, loaded_weight in weights
]
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight

loader.load_weights(model_weights.items())
Comment on lines +204 to +214
Contributor:

Severity: medium

This block of code for processing weights is not memory-efficient. It first creates a list of all permuted weights, and then a dictionary of these weights, both of which can consume a large amount of memory for large models.

A more memory-efficient approach is to use a generator to process the weights one by one. This avoids loading all weights into memory at once.

        def _processed_weights():
            for name, loaded_weight in weights:
                name, loaded_weight = self.permute_qk_weight_for_rotary(
                    name, loaded_weight)
                if "lm_head" not in name:
                    name = "model." + name
                yield name, loaded_weight

        loader.load_weights(_processed_weights())

1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -239,6 +239,7 @@
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
"EAGLEModel": ("eagle", "EAGLE"),
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),