logger = init_logger(__name__)

+# Weight name mapping for speculators format compatibility
+SPECULATORS_WEIGHT_MAP = {
+    "fusion_fc.weight": "fc.weight",
+    "fusion_fc.bias": "fc.bias",
+    "embedding_layernorm.weight": "embedding_layernorm.weight",
+    "pre_lm_head_layernorm.weight": "hidden_states_layernorm.weight",
+}
+
+
+def remap_speculators_weight_name(name: str) -> str | None:
+    """Remap speculators format weight names to vLLM names."""
+    if name in SPECULATORS_WEIGHT_MAP:
+        return SPECULATORS_WEIGHT_MAP[name]
+    elif name.startswith("transformer."):
+        return None
+    return name
+

class LlamaDecoderLayer(LlamaDecoderLayer):
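The helper above centralizes the checkpoint-name translation that was previously duplicated as a private method on both model classes. A minimal sketch of its expected behavior (the import path is an assumption for illustration; adjust to wherever this module actually lives):

```python
# Assumed import path, for illustration only.
from vllm.model_executor.models.llama_eagle import remap_speculators_weight_name

# Known speculators names are translated to the vLLM draft-model names.
assert remap_speculators_weight_name("fusion_fc.weight") == "fc.weight"
assert (remap_speculators_weight_name("pre_lm_head_layernorm.weight")
        == "hidden_states_layernorm.weight")

# Weights under "transformer." belong to the target model and are skipped.
assert remap_speculators_weight_name("transformer.h.0.attn.weight") is None

# Everything else passes through unchanged.
assert (remap_speculators_weight_name("layers.0.self_attn.q_proj.weight")
        == "layers.0.self_attn.q_proj.weight")
```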
@@ -55,14 +72,6 @@ class LlamaModel(nn.Module):
    (2 * hidden_size) and projects them back to hidden_size for processing
    through the transformer layers.
    """
-
-    # Weight name mapping for speculators format compatibility
-    SPECULATORS_WEIGHT_MAP = {
-        "fusion_fc.weight": "projection_layer.weight",
-        "fusion_fc.bias": "projection_layer.bias",
-        "embedding_layernorm.weight": "embedding_layernorm.weight",
-        "pre_lm_head_layernorm.weight": "hidden_states_layernorm.weight",
-    }

    def __init__(
        self,
@@ -72,8 +81,7 @@ def __init__(
        start_layer_id: int = 0,
    ) -> None:
        super().__init__()
-        self.config = vllm_config. \
-            speculative_config.draft_model_config.hf_config
+        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
@@ -91,9 +99,9 @@ def __init__(
        ])

        # Projection layer: combines input embeddings with target hidden states
-        self.projection_layer = torch.nn.Linear(self.config.hidden_size * 2,
-                                                self.config.hidden_size,
-                                                bias=False)
+        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
+                                  self.config.hidden_size,
+                                  bias=False)

        # Support for additional layernorms (HASS variant)
        # HASS adds layernorms to input embeddings and hidden states for better
@@ -134,7 +142,7 @@ def forward(

        # Project concatenated embeddings and hidden states
        # This combines information from both the input tokens and target model
-        hidden_states = self.projection_layer(
+        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))

        # Process through transformer layers
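The renamed `fc` layer is the 2x-to-1x projection described in the class docstring: it fuses the drafter's token embeddings with the target model's hidden states. A standalone shape sketch (sizes are illustrative, not taken from any real config):

```python
import torch

hidden_size = 4096  # illustrative; vLLM reads this from the draft model config
fc = torch.nn.Linear(hidden_size * 2, hidden_size, bias=False)

input_embeds = torch.randn(1, 8, hidden_size)   # [batch, seq, hidden]
target_hidden = torch.randn(1, 8, hidden_size)  # hidden states from the target model

# Concatenate along the feature dim (2 * hidden_size) and project back down,
# mirroring the forward() hunk above.
fused = fc(torch.cat((input_embeds, target_hidden), dim=-1))
assert fused.shape == (1, 8, hidden_size)
```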
@@ -148,23 +156,6 @@ def forward(
        hidden_states = hidden_states + residual
        return hidden_states, hidden_states

-    def _remap_weight_name(self, name: str) -> str | None:
-        """
-        Remap speculators format weight names to vLLM names.
-
-        Args:
-            name: Original weight name from the checkpoint
-
-        Returns:
-            Remapped weight name, or None if the weight should be skipped
-        """
-        if name in self.SPECULATORS_WEIGHT_MAP:
-            return self.SPECULATORS_WEIGHT_MAP[name]
-        elif name.startswith("transformer."):
-            # Skip transformer weights - they're loaded separately by the target model
-            return None
-        return name
-
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """
@@ -192,8 +183,7 @@ def load_weights(self, weights: Iterable[tuple[str,
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
-            # Remap weight names for speculators compatibility
-            remapped_name = self._remap_weight_name(name)
+            remapped_name = remap_speculators_weight_name(name)
            if remapped_name is None:
                continue
            name = remapped_name
@@ -252,8 +242,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
-        self.config = vllm_config. \
-            speculative_config.draft_model_config.hf_config
+        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = LlamaModel(vllm_config=vllm_config,
@@ -283,23 +272,6 @@ def forward(
        """
        return self.model(input_ids, positions, hidden_states)

-    def _remap_weight_name(self, name: str) -> str | None:
-        """
-        Remap speculators format weight names to vLLM names.
-
-        Args:
-            name: Original weight name from the checkpoint
-
-        Returns:
-            Remapped weight name, or None if the weight should be skipped
-        """
-        if name in self.SPECULATORS_WEIGHT_MAP:
-            return self.SPECULATORS_WEIGHT_MAP[name]
-        elif name.startswith("transformer."):
-            # Skip transformer weights - they're loaded separately by the target model
-            return None
-        return name
-
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """
        Load model weights with support for speculators format.
@@ -317,8 +289,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

        model_weights = {}
        for name, loaded_weight in weights:
-            # Remap weight names for speculators compatibility
-            remapped_name = self._remap_weight_name(name)
+            remapped_name = remap_speculators_weight_name(name)
            if remapped_name is None:
                continue
            name = remapped_name
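Net effect of the two load_weights changes: both classes now run the same filter-and-rename pass over incoming checkpoint entries before the usual loading logic. Roughly, assuming `remap_speculators_weight_name` is in scope (names and shapes below are made up for illustration):

```python
import torch

# Hypothetical speculators-format checkpoint entries.
checkpoint = {
    "fusion_fc.weight": torch.zeros(4096, 8192),
    "embedding_layernorm.weight": torch.ones(4096),
    "transformer.h.0.attn.q_proj.weight": torch.zeros(4096, 4096),
}

renamed = {}
for name, tensor in checkpoint.items():
    new_name = remap_speculators_weight_name(name)  # helper added in this PR
    if new_name is None:  # target-model weight, handled elsewhere
        continue
    renamed[new_name] = tensor

# renamed == {"fc.weight": ..., "embedding_layernorm.weight": ...}
```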