@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -23,20 +24,23 @@
 
 logger = init_logger(__name__)
 
-# Weight name mapping for speculators format compatibility
+# Map speculators weight names to vLLM names
 SPECULATORS_WEIGHT_MAP = {
     "fusion_fc.weight": "fc.weight",
     "fusion_fc.bias": "fc.bias",
-    "embedding_layernorm.weight": "embedding_layernorm.weight",
     "pre_lm_head_layernorm.weight": "hidden_states_layernorm.weight",
 }
 
 
-def remap_speculators_weight_name(name: str) -> str | None:
-    """Remap speculators format weight names to vLLM names."""
+def remap_speculators_weight_name(name: str) -> Optional[str]:
+    """Remap speculators format weight names to vLLM names.
+
+    Returns None for transformer weights that should be skipped.
+    """
     if name in SPECULATORS_WEIGHT_MAP:
         return SPECULATORS_WEIGHT_MAP[name]
     elif name.startswith("transformer."):
+        # Skip transformer weights - they're handled separately
         return None
     return name
 
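# Illustrative sketch, not part of the diff: expected behavior of the
# remap_speculators_weight_name helper added above, based only on the mapping
# shown in SPECULATORS_WEIGHT_MAP. The "transformer.*" name below is a
# hypothetical example chosen solely to exercise the skip branch.
assert remap_speculators_weight_name("fusion_fc.weight") == "fc.weight"
assert remap_speculators_weight_name(
    "pre_lm_head_layernorm.weight") == "hidden_states_layernorm.weight"
assert remap_speculators_weight_name("transformer.h.0.attn.weight") is None
assert remap_speculators_weight_name("lm_head.weight") == "lm_head.weight"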
@@ -60,18 +64,6 @@ def __init__(
 
 @support_torch_compile
 class LlamaModel(nn.Module):
-    """
-    Eagle draft model based on Llama architecture with projection layer.
-
-    This model extends the standard Llama architecture for Eagle speculative decoding
-    by adding a projection layer that combines input embeddings with hidden states
-    from the target model. It also supports HASS (Hierarchical Aggregation for
-    Sequence Sketching) variants that include additional layernorm layers.
-
-    The projection layer takes concatenated input embeddings and hidden states
-    (2 * hidden_size) and projects them back to hidden_size for processing
-    through the transformer layers.
-    """
 
     def __init__(
         self,
@@ -81,7 +73,8 @@ def __init__(
         start_layer_id: int = 0,
     ) -> None:
         super().__init__()
-        self.config = vllm_config.speculative_config.draft_model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
@@ -97,55 +90,33 @@ def __init__(
                 prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
             ) for i in range(self.config.num_hidden_layers)
         ])
-
-        # Projection layer: combines input embeddings with target hidden states
         self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                   self.config.hidden_size,
                                   bias=False)
 
-        # Support for additional layernorms (HASS variant)
-        # HASS adds layernorms to input embeddings and hidden states for better
-        # representation alignment between draft and target models
-        self.has_embedding_layernorms = False
-        if hasattr(self.config, "add_para_norm") and self.config.add_para_norm:
+        # HASS variant support
+        self.has_embedding_layernorms = getattr(self.config, "add_para_norm", False)
+        if self.has_embedding_layernorms:
             self.embedding_layernorm = RMSNorm(self.config.hidden_size,
-                                               eps=self.config.rms_norm_eps)
+                                               eps=self.config.rms_norm_eps)
             self.hidden_states_layernorm = RMSNorm(self.config.hidden_size,
-                                                   eps=self.config.rms_norm_eps)
-            self.has_embedding_layernorms = True
+                                                   eps=self.config.rms_norm_eps)
 
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Forward pass through the Eagle draft model.
-
-        Args:
-            input_ids: Input token IDs for the draft model
-            positions: Position indices for the tokens
-            hidden_states: Hidden states from the target model at the same positions
-
-        Returns:
-            Tuple of (output_hidden_states, output_hidden_states) for compatibility
-        """
         input_embeds = self.embed_tokens(input_ids)
 
-        # Apply layernorms if enabled (HASS variant)
-        # HASS normalizes both input embeddings and target hidden states
-        # before combining them to improve alignment
+        # Apply HASS normalization if enabled
         if self.has_embedding_layernorms:
             input_embeds = self.embedding_layernorm(input_embeds)
             hidden_states = self.hidden_states_layernorm(hidden_states)
 
-        # Project concatenated embeddings and hidden states
-        # This combines information from both the input tokens and target model
         hidden_states = self.fc(
             torch.cat((input_embeds, hidden_states), dim=-1))
-
-        # Process through transformer layers
         residual = None
         for layer in self.layers:
             hidden_states, residual = layer(
@@ -158,19 +129,6 @@ def forward(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        """
-        Load model weights with support for speculators format.
-
-        This method handles weight name mapping between speculators format
-        and vLLM's expected naming convention, ensuring compatibility
-        with both standard Eagle models and speculators-packaged models.
-
-        Args:
-            weights: Iterable of (weight_name, weight_tensor) pairs
-
-        Returns:
-            Set of parameter names that were successfully loaded
-        """
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -181,14 +139,12 @@ def load_weights(self, weights: Iterable[tuple[str,
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
-
         for name, loaded_weight in weights:
             remapped_name = remap_speculators_weight_name(name)
             if remapped_name is None:
                 continue
             name = remapped_name
 
-            # Handle stacked parameters (attention and MLP projections)
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -198,8 +154,8 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip embedding weights if pipeline parallelism is disabled
-                # In this case, draft model shares embeddings with target model
+
+                # if PP disabled then draft will share embed with target
                 if get_pp_group().world_size == 1 and \
                     "embed_tokens." in name:
                     continue
@@ -217,32 +173,11 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 
 class EagleLlamaForCausalLM(LlamaForCausalLM):
-    """
-    Eagle draft model for causal language modeling.
-
-    This class implements the Eagle draft model architecture for speculative
-    decoding with Llama-based models. It consists of:
-    1. A subset of transformer layers (starting after the target model layers)
-    2. A projection layer that combines input embeddings with target hidden states
-    3. Optional layernorms for HASS variant
-    4. Logits processing for token generation
-
-    The model generates draft tokens by processing the combination of input
-    embeddings and hidden states from the target model, enabling faster
-    speculative decoding.
-    """
-
-    # Weight name mapping for speculators format compatibility
-    SPECULATORS_WEIGHT_MAP = {
-        "fusion_fc.weight": "projection_layer.weight",
-        "fusion_fc.bias": "projection_layer.bias",
-        "embedding_layernorm.weight": "embedding_layernorm.weight",
-        "pre_lm_head_layernorm.weight": "hidden_states_layernorm.weight",
-    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
-        self.config = vllm_config.speculative_config.draft_model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
@@ -259,29 +194,9 @@ def forward(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Forward pass through the Eagle draft model.
-
-        Args:
-            input_ids: Input token IDs for the draft model
-            positions: Position indices for the tokens
-            hidden_states: Hidden states from the target model
-
-        Returns:
-            Tuple of (output_hidden_states, output_hidden_states) for compatibility
-        """
         return self.model(input_ids, positions, hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        """
-        Load model weights with support for speculators format.
-
-        This method handles weight name mapping between speculators format
-        and vLLM's expected naming convention.
-
-        Args:
-            weights: Iterable of (weight_name, weight_tensor) pairs
-        """
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,
@@ -293,8 +208,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             if remapped_name is None:
                 continue
             name = remapped_name
-
-            # Add model prefix for non-lm_head weights
+
             if "lm_head" not in name:
                 name = "model." + name
             model_weights[name] = loaded_weight
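# Illustrative sketch, not part of the diff: the shape contract of the fc
# projection in LlamaModel above. Input embeddings and target-model hidden
# states, each of width hidden_size, are concatenated on the last dim and
# projected back to hidden_size before entering the decoder layers. The
# hidden_size and tensor shapes below are made-up example values.
import torch

hidden_size = 4096
fc = torch.nn.Linear(hidden_size * 2, hidden_size, bias=False)
input_embeds = torch.randn(1, 8, hidden_size)   # draft-token embeddings
hidden_states = torch.randn(1, 8, hidden_size)  # target-model hidden states
out = fc(torch.cat((input_embeds, hidden_states), dim=-1))
assert out.shape == (1, 8, hidden_size)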