
Commit f1c861d

zixi-qi authored and morgendave committed
[Meta] Llama4 EAGLE Support

Co-authored-by: Zixi Qi <qizixi@meta.com>
Signed-off-by: qizixi <qizixi@meta.com>
1 parent 042d131 commit f1c861d

File tree: 3 files changed, +200 −1 lines

examples/offline_inference/spec_decode.py

Lines changed: 2 additions & 1 deletion
@@ -81,9 +81,10 @@ def main():
         tensor_parallel_size=args.tp,
         enable_chunked_prefill=args.enable_chunked_prefill,
         enforce_eager=args.enforce_eager,
-        gpu_memory_utilization=0.8,
+        gpu_memory_utilization=0.7,
         speculative_config=speculative_config,
         disable_log_stats=False,
+        max_model_len=16384,
     )
 
     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
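
The example now runs with gpu_memory_utilization=0.7 and a bounded max_model_len=16384, presumably to leave headroom for the extra draft-model weights and KV cache. For orientation, a minimal sketch of the kind of speculative_config dict the script builds and passes to LLM(); the draft checkpoint path and num_speculative_tokens below are placeholders, not values from this commit:

# Illustrative only: the checkpoint path and token count are placeholders.
speculative_config = {
    "method": "eagle",
    "model": "<path-to-llama4-eagle-draft>",
    "num_speculative_tokens": 3,
}
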
vllm/model_executor/models/llama4_eagle.py

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.torchao import TorchAOConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
                                               Llama4ForCausalLM)
from vllm.model_executor.models.utils import extract_layer_index

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


@support_torch_compile
class LlamaModel(nn.Module):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        self.validate_and_update_config(start_layer_id, quant_config)
        self.vocab_size = self.config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList([
            Llama4DecoderLayer(
                self.config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
            ) for i in range(self.config.num_hidden_layers)
        ])
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                  self.config.hidden_size,
                                  bias=False)
        self.norm = RMSNorm(self.config.hidden_size,
                            eps=self.config.rms_norm_eps)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            name = name.removeprefix("model.")
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # if PP disabled then draft will share embed with target
                if get_pp_group().world_size == 1 and \
                        "embed_tokens." in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        for name in params_dict:
            # if PP disabled then draft will share embed with target
            if get_pp_group().world_size == 1 and \
                    "embed_tokens." in name:
                continue
            assert name in loaded_params, f"{name} is not loaded!"
        return loaded_params

    def validate_and_update_config(
            self,
            start_layer_id: int,
            quant_config: Optional[QuantizationConfig] = None) -> None:
        # yoco and moe is not supported by draft model yet
        assert self.config.yoco_global_kv_layer is None
        assert self.config.yoco_local_kv_layer is None
        assert len(self.config.moe_layers) == 0
        # draft model layer index is increased by start_layer_id,
        # so we need to pad relevant configs accordingly
        self.config.no_rope_layers = [
            0
        ] * start_layer_id + self.config.no_rope_layers
        # currently only TorchAO quantization is supported
        if isinstance(quant_config, TorchAOConfig):

            def pad_layer_name(layer: str) -> str:
                layer_index = extract_layer_index(layer)
                return layer.replace(str(layer_index),
                                     str(layer_index + start_layer_id))

            quant_config.torchao_config.module_fqn_to_config = {
                pad_layer_name(layer): quantization
                for layer, quantization in
                quant_config.torchao_config.module_fqn_to_config.items()
            }


class EagleLlama4ForCausalLM(Llama4ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        # draft model quantization config may differ from target model
        quant_config = VllmConfig.get_quantization_config(
            vllm_config.speculative_config.draft_model_config,
            vllm_config.load_config)
        self.model = LlamaModel(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num,
                                quant_config=quant_config)
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            # lm_head is tied with target model (Llama4ForCausalLM)
            skip_prefixes=(["lm_head."]),
        )

        model_weights = {}
        weights = [
            self.permute_qk_weight_for_rotary(name, loaded_weight)
            for name, loaded_weight in weights
        ]
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight

        loader.load_weights(model_weights.items())
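
The core of the draft model above is the fc projection: for each position, EAGLE concatenates the token embedding with the hidden state handed over by the target model and projects the pair back down to hidden_size before running the extra Llama4DecoderLayer stack. A tiny self-contained sketch of that fusion step with made-up sizes (illustrative only, not part of the commit):

import torch

hidden_size, num_tokens = 8, 4  # toy stand-ins for config.hidden_size and the token count

input_embeds = torch.randn(num_tokens, hidden_size)   # what embed_tokens(input_ids) yields
target_hidden = torch.randn(num_tokens, hidden_size)  # hidden states from the target model
fc = torch.nn.Linear(hidden_size * 2, hidden_size, bias=False)

fused = fc(torch.cat((input_embeds, target_hidden), dim=-1))
print(fused.shape)  # torch.Size([4, 8]) -- this is what the draft decoder layers consume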

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -239,6 +239,7 @@
     "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
     "EAGLEModel": ("eagle", "EAGLE"),
     "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
+    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
