
Commit 3d2f0f0

feat: Add support for Eagle models in speculators format
- Add weight name mapping for speculators format compatibility
- Support HASS variant with additional layernorms
- Handle both Eagle-1 and Eagle-3 configurations
- Maintain backward compatibility with existing Eagle models

This enables using Eagle draft models packaged with the speculators library directly in vLLM for speculative decoding.
1 parent 875b786 commit 3d2f0f0
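For orientation, here is a small standalone sketch (not part of the commit) of the name translation this change introduces: speculators checkpoint names are mapped to vLLM's Eagle module names, and transformer-prefixed weights are left to the normal Llama loading path. The mapping mirrors the one added in llama_eagle.py below; the example weight names in the demo block are hypothetical.

# Standalone sketch of the remapping rule added in llama_eagle.py (illustrative only).
from typing import Optional

SPECULATORS_WEIGHT_MAP = {
    "fusion_fc.weight": "fc.weight",
    "fusion_fc.bias": "fc.bias",
    "pre_lm_head_layernorm.weight": "hidden_states_layernorm.weight",
}


def remap_speculators_weight_name(name: str) -> Optional[str]:
    """Return the vLLM name, or None for transformer weights handled by the normal path."""
    if name in SPECULATORS_WEIGHT_MAP:
        return SPECULATORS_WEIGHT_MAP[name]
    if name.startswith("transformer."):
        return None
    return name


if __name__ == "__main__":
    # Example weight names are hypothetical; only the mapping keys come from the diff.
    print(remap_speculators_weight_name("fusion_fc.weight"))        # fc.weight
    print(remap_speculators_weight_name("transformer.h.0.attn.w"))  # None (skipped here)
    print(remap_speculators_weight_name("lm_head.weight"))          # unchanged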

File tree

3 files changed (+63 -220 lines)


.gitignore

Lines changed: 2 additions & 0 deletions
@@ -202,3 +202,5 @@ shellcheck*/
 
 # Ignore moe/marlin_moe gen code
 csrc/moe/marlin_moe_wna16/kernel_*
+local/
+*.patch

vllm/model_executor/models/llama_eagle.py

Lines changed: 21 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -23,20 +24,23 @@
 
 logger = init_logger(__name__)
 
-# Weight name mapping for speculators format compatibility
+# Map speculators weight names to vLLM names
 SPECULATORS_WEIGHT_MAP = {
     "fusion_fc.weight": "fc.weight",
     "fusion_fc.bias": "fc.bias",
-    "embedding_layernorm.weight": "embedding_layernorm.weight",
     "pre_lm_head_layernorm.weight": "hidden_states_layernorm.weight",
 }
 
 
-def remap_speculators_weight_name(name: str) -> str | None:
-    """Remap speculators format weight names to vLLM names."""
+def remap_speculators_weight_name(name: str) -> Optional[str]:
+    """Remap speculators format weight names to vLLM names.
+
+    Returns None for transformer weights that should be skipped.
+    """
     if name in SPECULATORS_WEIGHT_MAP:
         return SPECULATORS_WEIGHT_MAP[name]
     elif name.startswith("transformer."):
+        # Skip transformer weights - they're handled separately
         return None
     return name
 
@@ -60,18 +64,6 @@ def __init__(
 
 @support_torch_compile
 class LlamaModel(nn.Module):
-    """
-    Eagle draft model based on Llama architecture with projection layer.
-
-    This model extends the standard Llama architecture for Eagle speculative decoding
-    by adding a projection layer that combines input embeddings with hidden states
-    from the target model. It also supports HASS (Hierarchical Aggregation for
-    Sequence Sketching) variants that include additional layernorm layers.
-
-    The projection layer takes concatenated input embeddings and hidden states
-    (2 * hidden_size) and projects them back to hidden_size for processing
-    through the transformer layers.
-    """
 
     def __init__(
         self,
@@ -81,7 +73,8 @@ def __init__(
         start_layer_id: int = 0,
     ) -> None:
         super().__init__()
-        self.config = vllm_config.speculative_config.draft_model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
@@ -97,55 +90,33 @@ def __init__(
                 prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
             ) for i in range(self.config.num_hidden_layers)
         ])
-
-        # Projection layer: combines input embeddings with target hidden states
         self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                   self.config.hidden_size,
                                   bias=False)
 
-        # Support for additional layernorms (HASS variant)
-        # HASS adds layernorms to input embeddings and hidden states for better
-        # representation alignment between draft and target models
-        self.has_embedding_layernorms = False
-        if hasattr(self.config, "add_para_norm") and self.config.add_para_norm:
+        # HASS variant support
+        self.has_embedding_layernorms = getattr(self.config, "add_para_norm", False)
+        if self.has_embedding_layernorms:
             self.embedding_layernorm = RMSNorm(self.config.hidden_size,
-                                               eps=self.config.rms_norm_eps)
+                                               eps=self.config.rms_norm_eps)
             self.hidden_states_layernorm = RMSNorm(self.config.hidden_size,
-                                                   eps=self.config.rms_norm_eps)
-            self.has_embedding_layernorms = True
+                                                   eps=self.config.rms_norm_eps)
 
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Forward pass through the Eagle draft model.
-
-        Args:
-            input_ids: Input token IDs for the draft model
-            positions: Position indices for the tokens
-            hidden_states: Hidden states from the target model at the same positions
-
-        Returns:
-            Tuple of (output_hidden_states, output_hidden_states) for compatibility
-        """
         input_embeds = self.embed_tokens(input_ids)
 
-        # Apply layernorms if enabled (HASS variant)
-        # HASS normalizes both input embeddings and target hidden states
-        # before combining them to improve alignment
+        # Apply HASS normalization if enabled
         if self.has_embedding_layernorms:
             input_embeds = self.embedding_layernorm(input_embeds)
             hidden_states = self.hidden_states_layernorm(hidden_states)
 
-        # Project concatenated embeddings and hidden states
-        # This combines information from both the input tokens and target model
         hidden_states = self.fc(
             torch.cat((input_embeds, hidden_states), dim=-1))
-
-        # Process through transformer layers
         residual = None
         for layer in self.layers:
             hidden_states, residual = layer(
@@ -158,19 +129,6 @@ def forward(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        """
-        Load model weights with support for speculators format.
-
-        This method handles weight name mapping between speculators format
-        and vLLM's expected naming convention, ensuring compatibility
-        with both standard Eagle models and speculators-packaged models.
-
-        Args:
-            weights: Iterable of (weight_name, weight_tensor) pairs
-
-        Returns:
-            Set of parameter names that were successfully loaded
-        """
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -181,14 +139,12 @@ def load_weights(self, weights: Iterable[tuple[str,
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
-
         for name, loaded_weight in weights:
             remapped_name = remap_speculators_weight_name(name)
             if remapped_name is None:
                 continue
             name = remapped_name
 
-            # Handle stacked parameters (attention and MLP projections)
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -198,8 +154,8 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip embedding weights if pipeline parallelism is disabled
-                # In this case, draft model shares embeddings with target model
+
+                # if PP disabled then draft will share embed with target
                 if get_pp_group().world_size == 1 and \
                     "embed_tokens." in name:
                     continue
@@ -217,32 +173,11 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 
 class EagleLlamaForCausalLM(LlamaForCausalLM):
-    """
-    Eagle draft model for causal language modeling.
-
-    This class implements the Eagle draft model architecture for speculative
-    decoding with Llama-based models. It consists of:
-    1. A subset of transformer layers (starting after the target model layers)
-    2. A projection layer that combines input embeddings with target hidden states
-    3. Optional layernorms for HASS variant
-    4. Logits processing for token generation
-
-    The model generates draft tokens by processing the combination of input
-    embeddings and hidden states from the target model, enabling faster
-    speculative decoding.
-    """
-
-    # Weight name mapping for speculators format compatibility
-    SPECULATORS_WEIGHT_MAP = {
-        "fusion_fc.weight": "projection_layer.weight",
-        "fusion_fc.bias": "projection_layer.bias",
-        "embedding_layernorm.weight": "embedding_layernorm.weight",
-        "pre_lm_head_layernorm.weight": "hidden_states_layernorm.weight",
-    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
-        self.config = vllm_config.speculative_config.draft_model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
@@ -259,29 +194,9 @@ def forward(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Forward pass through the Eagle draft model.
-
-        Args:
-            input_ids: Input token IDs for the draft model
-            positions: Position indices for the tokens
-            hidden_states: Hidden states from the target model
-
-        Returns:
-            Tuple of (output_hidden_states, output_hidden_states) for compatibility
-        """
         return self.model(input_ids, positions, hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        """
-        Load model weights with support for speculators format.
-
-        This method handles weight name mapping between speculators format
-        and vLLM's expected naming convention.
-
-        Args:
-            weights: Iterable of (weight_name, weight_tensor) pairs
-        """
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,
@@ -293,8 +208,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             if remapped_name is None:
                 continue
             name = remapped_name
-
-            # Add model prefix for non-lm_head weights
+
             if "lm_head" not in name:
                 name = "model." + name
             model_weights[name] = loaded_weight
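For reviewers who want to see the data flow in isolation, the following is a minimal PyTorch sketch of the draft-model input path this diff touches: optional HASS layernorms on the input embeddings and the target hidden states, followed by the 2*hidden_size -> hidden_size projection ("fusion_fc" in speculators, "fc" in vLLM). The module name, vocabulary size, and tensor shapes are made up for illustration; only the add_para_norm flag and the concat-then-project structure come from the code above.

import torch
import torch.nn as nn


class EagleInputPath(nn.Module):
    """Toy stand-in for the embed/normalize/project steps in LlamaModel.forward (illustrative only)."""

    def __init__(self, hidden_size: int, add_para_norm: bool = False, vocab_size: int = 32000):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # Projection layer: concatenated (embeddings, target hidden states) -> hidden_size
        self.fc = nn.Linear(hidden_size * 2, hidden_size, bias=False)
        self.has_embedding_layernorms = add_para_norm
        if add_para_norm:
            # HASS variant: normalize both streams before combining (nn.RMSNorm needs PyTorch >= 2.4)
            self.embedding_layernorm = nn.RMSNorm(hidden_size)
            self.hidden_states_layernorm = nn.RMSNorm(hidden_size)

    def forward(self, input_ids: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        input_embeds = self.embed_tokens(input_ids)
        if self.has_embedding_layernorms:
            input_embeds = self.embedding_layernorm(input_embeds)
            hidden_states = self.hidden_states_layernorm(hidden_states)
        # Concatenate draft embeddings with target hidden states, then project back down
        return self.fc(torch.cat((input_embeds, hidden_states), dim=-1))


if __name__ == "__main__":
    model = EagleInputPath(hidden_size=64, add_para_norm=True)
    out = model(torch.randint(0, 32000, (1, 4)), torch.randn(1, 4, 64))
    print(out.shape)  # torch.Size([1, 4, 64])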
