
Commit c11013d

morgendave and zixi-qi authored
[Meta] Llama4 EAGLE Support (#20591)
Signed-off-by: qizixi <qizixi@meta.com>
Co-authored-by: qizixi <qizixi@meta.com>
1 parent 1eb2b9c commit c11013d

File tree

6 files changed: +257 −17 lines

examples/offline_inference/spec_decode.py
tests/models/registry.py
tests/models/test_initialization.py
tests/v1/e2e/test_spec_decode.py
vllm/model_executor/models/llama4_eagle.py
vllm/model_executor/models/registry.py

examples/offline_inference/spec_decode.py

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ def main():
         gpu_memory_utilization=0.8,
         speculative_config=speculative_config,
         disable_log_stats=False,
+        max_model_len=16384,
     )

     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
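
The example now pins max_model_len so the context fits alongside gpu_memory_utilization=0.8. For reference, a minimal offline run of the Llama 4 EAGLE pair added by this commit could look like the sketch below; the checkpoint names and speculative_config values mirror tests/v1/e2e/test_spec_decode.py, while the prompt and tensor_parallel_size=4 are illustrative assumptions rather than part of the example script.

# Minimal sketch (not part of spec_decode.py): offline EAGLE speculative
# decoding for Llama 4 Scout, mirroring the configuration in the e2e test.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    trust_remote_code=True,
    tensor_parallel_size=4,          # the e2e test runs this pair at TP=4
    max_model_len=2048,
    gpu_memory_utilization=0.8,
    speculative_config={
        "method": "eagle",
        "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
        "num_speculative_tokens": 3,
        "max_model_len": 2048,
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=64))
print(outputs[0].outputs[0].text)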

tests/models/registry.py

Lines changed: 5 additions & 0 deletions
@@ -465,6 +465,11 @@ def check_available_online(
         trust_remote_code=True,
         speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
         tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
+    "EagleLlama4ForCausalLM": _HfExamplesInfo(
+        "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
+        trust_remote_code=True,
+        speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
+        tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"),  # noqa: E501
     "EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
                                                trust_remote_code=True,
                                                is_available_online=False,

tests/models/test_initialization.py

Lines changed: 5 additions & 0 deletions
@@ -36,6 +36,11 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
                       "KimiVLForConditionalGeneration"):
         pytest.skip("Avoid OOM")

+    if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"):
+        from vllm.model_executor.models.llama4 import Llama4ForCausalLM
+        from vllm.model_executor.models.registry import ModelRegistry
+        ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)
+
     # Avoid OOM and reduce initialization time by only using 1 layer
     def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
         hf_config.update(model_info.hf_overrides)
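
Llama4ForCausalLM is the text-only Llama 4 backbone and is not normally exposed as a standalone architecture in the model registry, so the initialization test registers it explicitly before building EagleLlama4ForCausalLM, whose draft layers subclass the Llama 4 decoder. The snippet below isolates that registration step; the assertion on get_supported_archs() is an added illustration, not part of the test.

# Sketch of the registration the test performs: map an architecture name
# (as it would appear in a checkpoint's "architectures" field) to a class.
from vllm.model_executor.models.llama4 import Llama4ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry

ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)
assert "Llama4ForCausalLM" in ModelRegistry.get_supported_archs()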

tests/v1/e2e/test_spec_decode.py

Lines changed: 31 additions & 17 deletions
@@ -6,8 +6,10 @@
 from typing import Any

 import pytest
+import torch

 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory


 @pytest.fixture
@@ -53,14 +55,6 @@ def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"


-def eagle_model_name():
-    return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
-
-
-def eagle3_model_name():
-    return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
-
-
 def test_ngram_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],
@@ -77,6 +71,8 @@ def test_ngram_correctness(
         ref_llm = LLM(model=model_name, max_model_len=1024)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()

         spec_llm = LLM(
             model=model_name,
@@ -103,34 +99,50 @@ def test_ngram_correctness(
         # Upon failure, inspect the outputs to check for inaccuracy.
         assert matches > int(0.7 * len(ref_outputs))
         del spec_llm
-
-
-@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize("model_setup", [
+    ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+     "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
+    ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
+    pytest.param(
+        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+],
+                         ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
-    model_name: str,
-    use_eagle3: bool,
+    model_setup: tuple[str, str, str, int],
 ):
     '''
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
+    model_setup: (method, model_name, eagle_model_name, tp_size)
     '''
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        method, model_name, spec_model_name, tp_size = model_setup

-        ref_llm = LLM(model=model_name, max_model_len=2048)
+        ref_llm = LLM(model=model_name,
+                      max_model_len=2048,
+                      tensor_parallel_size=tp_size)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()

-        spec_model_name = eagle3_model_name(
-        ) if use_eagle3 else eagle_model_name()
         spec_llm = LLM(
             model=model_name,
             trust_remote_code=True,
+            tensor_parallel_size=tp_size,
             speculative_config={
-                "method": "eagle3" if use_eagle3 else "eagle",
+                "method": method,
                 "model": spec_model_name,
                 "num_speculative_tokens": 3,
                 "max_model_len": 2048,
@@ -152,3 +164,5 @@ def test_eagle_correctness(
         # Upon failure, inspect the outputs to check for inaccuracy.
         assert matches > int(0.66 * len(ref_outputs))
         del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
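
Because the parametrized test now builds the reference and speculative engines back to back, at up to tensor_parallel_size=4, each engine is torn down explicitly before the next one starts. The pattern below is taken directly from the diff above and shown in isolation; the model name is only a placeholder.

# Teardown between consecutive LLM instances: drop the engine, release
# cached CUDA memory, then reset the distributed state so the next engine
# can re-initialize its tensor-parallel groups cleanly.
import torch
from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", max_model_len=2048)
# ... run inference ...
del llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()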
vllm/model_executor/models/llama4_eagle.py

Lines changed: 214 additions & 0 deletions (new file)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.torchao import TorchAOConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
                                               Llama4ForCausalLM)
from vllm.model_executor.models.utils import extract_layer_index

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


@support_torch_compile
class LlamaModel(nn.Module):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = (
            vllm_config.speculative_config.draft_model_config.hf_config)
        self.validate_and_update_config(start_layer_id, quant_config)
        self.vocab_size = self.config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList([
            Llama4DecoderLayer(
                self.config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
            ) for i in range(self.config.num_hidden_layers)
        ])
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                  self.config.hidden_size,
                                  bias=False)
        self.norm = RMSNorm(self.config.hidden_size,
                            eps=self.config.rms_norm_eps)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            name = name.removeprefix("model.")
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # if PP disabled then draft will share embed with target
                if get_pp_group().world_size == 1 and \
                    "embed_tokens." in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        for name in params_dict:
            # if PP disabled then draft will share embed with target
            if get_pp_group().world_size == 1 and \
                "embed_tokens." in name:
                continue
            assert name in loaded_params, f"{name} is not loaded!"
        return loaded_params

    def validate_and_update_config(
            self,
            start_layer_id: int,
            quant_config: Optional[QuantizationConfig] = None) -> None:
        # yoco and moe is not supported by draft model yet
        assert self.config.yoco_global_kv_layer is None
        assert self.config.yoco_local_kv_layer is None
        assert len(self.config.moe_layers) == 0
        # draft model layer index is increased by start_layer_id,
        # so we need to pad relevant configs accordingly
        self.config.no_rope_layers = [
            0
        ] * start_layer_id + self.config.no_rope_layers
        # currently only TorchAO quantization is supported
        if isinstance(quant_config, TorchAOConfig):

            def pad_layer_name(layer: str) -> str:
                layer_index = extract_layer_index(layer)
                return layer.replace(str(layer_index),
                                     str(layer_index + start_layer_id))

            quant_config.torchao_config.module_fqn_to_config = {
                pad_layer_name(layer): quantization
                for layer, quantization in
                quant_config.torchao_config.module_fqn_to_config.items()
            }


class EagleLlama4ForCausalLM(Llama4ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = (
            vllm_config.speculative_config.draft_model_config.hf_config)
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        # draft model quantization config may differ from target model
        quant_config = VllmConfig.get_quantization_config(
            vllm_config.speculative_config.draft_model_config,
            vllm_config.load_config)
        self.model = LlamaModel(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num,
                                quant_config=quant_config)
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> None:
        loader = AutoWeightsLoader(
            self,
            # lm_head is tied with target model (Llama4ForCausalLM)
            skip_prefixes=(["lm_head."]),
        )

        model_weights = {}
        weights = [
            self.permute_qk_weight_for_rotary(name, loaded_weight)
            for name, loaded_weight in weights
        ]
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight

        loader.load_weights(model_weights.items())
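
A note on the start_layer_id handling above: EagleLlama4ForCausalLM creates its draft decoder layers with prefix layers.{i + start_layer_id}, where start_layer_id equals the number of target-model layers, so draft layer names get engine-global indices that never collide with the target's. validate_and_update_config pads per-layer configs (no_rope_layers and a TorchAO module_fqn_to_config mapping) the same way. The snippet below works through that renaming with an assumed 48-layer target (Llama 4 Scout's text backbone) and a placeholder quantization value; it is an illustration, not code from this file.

# Worked example of the layer-index padding in validate_and_update_config.
# Assumes a 48-layer target model; "int8wo" is a placeholder config value.
from vllm.model_executor.models.utils import extract_layer_index

start_layer_id = 48  # number of decoder layers in the target model

def pad_layer_name(layer: str) -> str:
    layer_index = extract_layer_index(layer)
    return layer.replace(str(layer_index), str(layer_index + start_layer_id))

# A draft checkpoint's per-module config is keyed on draft-local names ...
module_fqn_to_config = {"model.layers.0.self_attn.qkv_proj": "int8wo"}
# ... and gets re-keyed onto the engine-global layer index:
print({pad_layer_name(k): v for k, v in module_fqn_to_config.items()})
# -> {'model.layers.48.self_attn.qkv_proj': 'int8wo'}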

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -244,6 +244,7 @@
     "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
     "EAGLEModel": ("eagle", "EAGLE"),
     "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
+    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
