@@ -43,6 +43,26 @@ def __init__(


@support_torch_compile
class LlamaModel(nn.Module):
+    """
+    Eagle draft model based on the Llama architecture with a projection layer.
+
+    This model extends the standard Llama architecture for Eagle speculative decoding
+    by adding a projection layer that combines input embeddings with hidden states
+    from the target model. It also supports HASS (Hierarchical Aggregation for
+    Sequence Sketching) variants that include additional layernorm layers.
+
+    The projection layer takes concatenated input embeddings and hidden states
+    (2 * hidden_size) and projects them back to hidden_size for processing
+    through the transformer layers.
+    """
+
+    # Weight name mapping for speculators format compatibility
+    SPECULATORS_WEIGHT_MAP = {
+        "fusion_fc.weight": "projection_layer.weight",
+        "fusion_fc.bias": "projection_layer.bias",
+        "embedding_layernorm.weight": "embedding_layernorm.weight",
+        "pre_lm_head_layernorm.weight": "hidden_states_layernorm.weight",
+    }

    def __init__(
        self,
@@ -69,34 +89,55 @@ def __init__(
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
            ) for i in range(self.config.num_hidden_layers)
        ])
-        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
-                                  self.config.hidden_size,
-                                  bias=False)
+
+        # Projection layer: combines input embeddings with target hidden states
+        self.projection_layer = torch.nn.Linear(self.config.hidden_size * 2,
+                                                self.config.hidden_size,
+                                                bias=False)

        # Support for additional layernorms (HASS variant)
-        self.add_para_norm = False
+        # HASS adds layernorms to input embeddings and hidden states for better
+        # representation alignment between draft and target models
+        self.has_embedding_layernorms = False
        if hasattr(self.config, "add_para_norm") and self.config.add_para_norm:
-            self.enorm = RMSNorm(self.config.hidden_size,
-                                 eps=self.config.rms_norm_eps)
-            self.hnorm = RMSNorm(self.config.hidden_size,
-                                 eps=self.config.rms_norm_eps)
-            self.add_para_norm = True
+            self.embedding_layernorm = RMSNorm(self.config.hidden_size,
+                                               eps=self.config.rms_norm_eps)
+            self.hidden_states_layernorm = RMSNorm(self.config.hidden_size,
+                                                   eps=self.config.rms_norm_eps)
+            self.has_embedding_layernorms = True

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass through the Eagle draft model.
+
+        Args:
+            input_ids: Input token IDs for the draft model
+            positions: Position indices for the tokens
+            hidden_states: Hidden states from the target model at the same positions
+
+        Returns:
+            Tuple of (output_hidden_states, output_hidden_states) for compatibility
+        """
        input_embeds = self.embed_tokens(input_ids)

        # Apply layernorms if enabled (HASS variant)
-        if self.add_para_norm:
-            input_embeds = self.enorm(input_embeds)
-            hidden_states = self.hnorm(hidden_states)
+        # HASS normalizes both input embeddings and target hidden states
+        # before combining them to improve alignment
+        if self.has_embedding_layernorms:
+            input_embeds = self.embedding_layernorm(input_embeds)
+            hidden_states = self.hidden_states_layernorm(hidden_states)

-        hidden_states = self.fc(
+        # Project concatenated embeddings and hidden states
+        # This combines information from both the input tokens and target model
+        hidden_states = self.projection_layer(
            torch.cat((input_embeds, hidden_states), dim=-1))
+
+        # Process through transformer layers
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
@@ -107,8 +148,38 @@ def forward(
        hidden_states = hidden_states + residual
        return hidden_states, hidden_states

+    def _remap_weight_name(self, name: str) -> str | None:
+        """
+        Remap speculators format weight names to vLLM names.
+
+        Args:
+            name: Original weight name from the checkpoint
+
+        Returns:
+            Remapped weight name, or None if the weight should be skipped
+        """
+        if name in self.SPECULATORS_WEIGHT_MAP:
+            return self.SPECULATORS_WEIGHT_MAP[name]
+        elif name.startswith("transformer."):
+            # Skip transformer weights - they're loaded separately by the target model
+            return None
+        return name
+
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
+        """
+        Load model weights with support for speculators format.
+
+        This method handles weight name mapping between speculators format
+        and vLLM's expected naming convention, ensuring compatibility
+        with both standard Eagle models and speculators-packaged models.
+
+        Args:
+            weights: Iterable of (weight_name, weight_tensor) pairs
+
+        Returns:
+            Set of parameter names that were successfully loaded
+        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
@@ -120,22 +191,14 @@ def load_weights(self, weights: Iterable[tuple[str,
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

-        # Support for speculators format weights
-        speculators_name_map = {
-            "fusion_fc.weight": "fc.weight",
-            "fusion_fc.bias": "fc.bias",
-            "embedding_layernorm.weight": "enorm.weight",
-            "pre_lm_head_layernorm.weight": "hnorm.weight",
-        }
-
        for name, loaded_weight in weights:
-            # Handle speculators format weight names
-            if name in speculators_name_map:
-                name = speculators_name_map[name]
-            elif name.startswith("transformer."):
-                # Skip transformer weights - they're loaded separately
+            # Remap weight names for speculators compatibility
+            remapped_name = self._remap_weight_name(name)
+            if remapped_name is None:
                continue
+            name = remapped_name

+            # Handle stacked parameters (attention and MLP projections)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
@@ -145,8 +208,8 @@ def load_weights(self, weights: Iterable[tuple[str,
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
-
-                # if PP disabled then draft will share embed with target
+                # Skip embedding weights if pipeline parallelism is disabled
+                # In this case, draft model shares embeddings with target model
                if get_pp_group().world_size == 1 and \
                        "embed_tokens." in name:
                    continue
@@ -164,6 +227,28 @@ def load_weights(self, weights: Iterable[tuple[str,


class EagleLlamaForCausalLM(LlamaForCausalLM):
+    """
+    Eagle draft model for causal language modeling.
+
+    This class implements the Eagle draft model architecture for speculative
+    decoding with Llama-based models. It consists of:
+    1. A subset of transformer layers (starting after the target model layers)
+    2. A projection layer that combines input embeddings with target hidden states
+    3. Optional layernorms for the HASS variant
+    4. Logits processing for token generation
+
+    The model generates draft tokens by processing the combination of input
+    embeddings and hidden states from the target model, enabling faster
+    speculative decoding.
+    """
+
+    # Weight name mapping for speculators format compatibility
+    SPECULATORS_WEIGHT_MAP = {
+        "fusion_fc.weight": "projection_layer.weight",
+        "fusion_fc.bias": "projection_layer.bias",
+        "embedding_layernorm.weight": "embedding_layernorm.weight",
+        "pre_lm_head_layernorm.weight": "hidden_states_layernorm.weight",
+    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
@@ -185,31 +270,60 @@ def forward(
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass through the Eagle draft model.
+
+        Args:
+            input_ids: Input token IDs for the draft model
+            positions: Position indices for the tokens
+            hidden_states: Hidden states from the target model
+
+        Returns:
+            Tuple of (output_hidden_states, output_hidden_states) for compatibility
+        """
        return self.model(input_ids, positions, hidden_states)

+    def _remap_weight_name(self, name: str) -> str | None:
+        """
+        Remap speculators format weight names to vLLM names.
+
+        Args:
+            name: Original weight name from the checkpoint
+
+        Returns:
+            Remapped weight name, or None if the weight should be skipped
+        """
+        if name in self.SPECULATORS_WEIGHT_MAP:
+            return self.SPECULATORS_WEIGHT_MAP[name]
+        elif name.startswith("transformer."):
+            # Skip transformer weights - they're loaded separately by the target model
+            return None
+        return name
+
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        """
+        Load model weights with support for speculators format.
+
+        This method handles weight name mapping between speculators format
+        and vLLM's expected naming convention.
+
+        Args:
+            weights: Iterable of (weight_name, weight_tensor) pairs
+        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )

-        # Support for speculators format weights
-        speculators_name_map = {
-            "fusion_fc.weight": "fc.weight",
-            "fusion_fc.bias": "fc.bias",
-            "embedding_layernorm.weight": "enorm.weight",
-            "pre_lm_head_layernorm.weight": "hnorm.weight",
-        }
-
        model_weights = {}
        for name, loaded_weight in weights:
-            # Handle speculators format weight names
-            if name in speculators_name_map:
-                name = speculators_name_map[name]
-            elif name.startswith("transformer."):
-                # Skip transformer weights - they're loaded separately
+            # Remap weight names for speculators compatibility
+            remapped_name = self._remap_weight_name(name)
+            if remapped_name is None:
                continue
+            name = remapped_name

+            # Add model prefix for non-lm_head weights
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight
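
The remapping performed by the new _remap_weight_name helper can be illustrated with a minimal, self-contained sketch. This is not part of the diff; it simply mirrors the SPECULATORS_WEIGHT_MAP logic added above, and the checkpoint names used in the assertions are illustrative examples only.

# Minimal sketch of the remapping rules: speculators-format names map to vLLM
# names, target-model ("transformer.") weights are skipped, and everything
# else passes through unchanged.
SPECULATORS_WEIGHT_MAP = {
    "fusion_fc.weight": "projection_layer.weight",
    "fusion_fc.bias": "projection_layer.bias",
    "embedding_layernorm.weight": "embedding_layernorm.weight",
    "pre_lm_head_layernorm.weight": "hidden_states_layernorm.weight",
}


def remap_weight_name(name: str) -> str | None:
    if name in SPECULATORS_WEIGHT_MAP:
        return SPECULATORS_WEIGHT_MAP[name]
    if name.startswith("transformer."):
        return None  # loaded separately by the target model
    return name


assert remap_weight_name("fusion_fc.weight") == "projection_layer.weight"
assert remap_weight_name("transformer.layers.0.mlp.weight") is None  # illustrative name, skipped
assert remap_weight_name("layers.0.self_attn.q_proj.weight") == "layers.0.self_attn.q_proj.weight"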