Commit 875b786

rahul-tuli and claude committed
refactor: Consolidate Eagle speculators weight mapping
- Move SPECULATORS_WEIGHT_MAP to module level to eliminate duplication
- Replace duplicate _remap_weight_name methods with a single function
- Fix line continuation style to use proper parentheses
- Streamline weight loading logic while preserving functionality
- Remove verbose comments while keeping essential documentation
- Preserve original 'fc' naming convention

This consolidation improves maintainability and follows vLLM code style conventions while preserving all existing functionality for both Eagle-1 and Eagle-3 speculators models.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: rtuli@redhat.com
Co-Authored-By: Claude <noreply@anthropic.com>
1 parent e6027f1 commit 875b786

File tree

1 file changed: +25 -54 lines changed

vllm/model_executor/models/llama_eagle.py

Lines changed: 25 additions & 54 deletions
@@ -23,6 +23,23 @@
 
 logger = init_logger(__name__)
 
+# Weight name mapping for speculators format compatibility
+SPECULATORS_WEIGHT_MAP = {
+    "fusion_fc.weight": "fc.weight",
+    "fusion_fc.bias": "fc.bias",
+    "embedding_layernorm.weight": "embedding_layernorm.weight",
+    "pre_lm_head_layernorm.weight": "hidden_states_layernorm.weight",
+}
+
+
+def remap_speculators_weight_name(name: str) -> str | None:
+    """Remap speculators format weight names to vLLM names."""
+    if name in SPECULATORS_WEIGHT_MAP:
+        return SPECULATORS_WEIGHT_MAP[name]
+    elif name.startswith("transformer."):
+        return None
+    return name
+
 
 class LlamaDecoderLayer(LlamaDecoderLayer):
 
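A minimal usage sketch of the new module-level helper, assuming it is importable from this file after the change; the checkpoint keys below are illustrative examples, not taken from a real checkpoint:

from vllm.model_executor.models.llama_eagle import remap_speculators_weight_name

remap_speculators_weight_name("fusion_fc.weight")             # "fc.weight", via SPECULATORS_WEIGHT_MAP
remap_speculators_weight_name("transformer.h.0.attn.weight")  # None, target-model weights are skipped
remap_speculators_weight_name("embed_tokens.weight")          # unchanged passthrough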
@@ -55,14 +72,6 @@ class LlamaModel(nn.Module):
     (2 * hidden_size) and projects them back to hidden_size for processing
     through the transformer layers.
     """
-
-    # Weight name mapping for speculators format compatibility
-    SPECULATORS_WEIGHT_MAP = {
-        "fusion_fc.weight": "projection_layer.weight",
-        "fusion_fc.bias": "projection_layer.bias",
-        "embedding_layernorm.weight": "embedding_layernorm.weight",
-        "pre_lm_head_layernorm.weight": "hidden_states_layernorm.weight",
-    }
 
     def __init__(
         self,
@@ -72,8 +81,7 @@ def __init__(
         start_layer_id: int = 0,
     ) -> None:
         super().__init__()
-        self.config = vllm_config. \
-            speculative_config.draft_model_config.hf_config
+        self.config = vllm_config.speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
@@ -91,9 +99,9 @@ def __init__(
         ])
 
         # Projection layer: combines input embeddings with target hidden states
-        self.projection_layer = torch.nn.Linear(self.config.hidden_size * 2,
-                                                self.config.hidden_size,
-                                                bias=False)
+        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
+                                  self.config.hidden_size,
+                                  bias=False)
 
         # Support for additional layernorms (HASS variant)
         # HASS adds layernorms to input embeddings and hidden states for better
@@ -134,7 +142,7 @@ def forward(
 
         # Project concatenated embeddings and hidden states
         # This combines information from both the input tokens and target model
-        hidden_states = self.projection_layer(
+        hidden_states = self.fc(
             torch.cat((input_embeds, hidden_states), dim=-1))
 
         # Process through transformer layers
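For context on the rename, a standalone sketch of what the fc projection computes; hidden_size=4096 and the batch/sequence sizes are assumed example values, not taken from this diff:

import torch

hidden_size = 4096  # assumed example value
fc = torch.nn.Linear(hidden_size * 2, hidden_size, bias=False)

input_embeds = torch.randn(1, 8, hidden_size)   # draft-token embeddings
hidden_states = torch.randn(1, 8, hidden_size)  # target-model hidden states
fused = fc(torch.cat((input_embeds, hidden_states), dim=-1))
print(fused.shape)  # torch.Size([1, 8, 4096])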
@@ -148,23 +156,6 @@ def forward(
             hidden_states = hidden_states + residual
         return hidden_states, hidden_states
 
-    def _remap_weight_name(self, name: str) -> str | None:
-        """
-        Remap speculators format weight names to vLLM names.
-
-        Args:
-            name: Original weight name from the checkpoint
-
-        Returns:
-            Remapped weight name, or None if the weight should be skipped
-        """
-        if name in self.SPECULATORS_WEIGHT_MAP:
-            return self.SPECULATORS_WEIGHT_MAP[name]
-        elif name.startswith("transformer."):
-            # Skip transformer weights - they're loaded separately by the target model
-            return None
-        return name
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         """
@@ -192,8 +183,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         loaded_params: set[str] = set()
 
        for name, loaded_weight in weights:
-            # Remap weight names for speculators compatibility
-            remapped_name = self._remap_weight_name(name)
+            remapped_name = remap_speculators_weight_name(name)
             if remapped_name is None:
                 continue
             name = remapped_name
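The remap-then-load pattern in this hunk can be sketched in isolation as follows; the two-entry weights list is a hypothetical stand-in for the real checkpoint stream, and the parameter lookup after the remap is elided:

import torch
from vllm.model_executor.models.llama_eagle import remap_speculators_weight_name

weights = [
    ("fusion_fc.weight", torch.zeros(1)),            # remapped to "fc.weight"
    ("transformer.h.0.mlp.weight", torch.zeros(1)),  # skipped: owned by the target model
]

for name, loaded_weight in weights:
    remapped_name = remap_speculators_weight_name(name)
    if remapped_name is None:
        continue
    name = remapped_name
    # ...the existing parameter lookup and weight_loader call proceed with `name`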
@@ -252,8 +242,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
-        self.config = vllm_config. \
-            speculative_config.draft_model_config.hf_config
+        self.config = vllm_config.speculative_config.draft_model_config.hf_config
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
@@ -283,23 +272,6 @@ def forward(
         """
         return self.model(input_ids, positions, hidden_states)
 
-    def _remap_weight_name(self, name: str) -> str | None:
-        """
-        Remap speculators format weight names to vLLM names.
-
-        Args:
-            name: Original weight name from the checkpoint
-
-        Returns:
-            Remapped weight name, or None if the weight should be skipped
-        """
-        if name in self.SPECULATORS_WEIGHT_MAP:
-            return self.SPECULATORS_WEIGHT_MAP[name]
-        elif name.startswith("transformer."):
-            # Skip transformer weights - they're loaded separately by the target model
-            return None
-        return name
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         """
         Load model weights with support for speculators format.
@@ -317,8 +289,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
 
         model_weights = {}
         for name, loaded_weight in weights:
-            # Remap weight names for speculators compatibility
-            remapped_name = self._remap_weight_name(name)
+            remapped_name = remap_speculators_weight_name(name)
             if remapped_name is None:
                 continue
             name = remapped_name
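The wrapper class follows the same pattern when collecting weights into model_weights; a hedged sketch of that collection step, where the dict handling after the remap is assumed rather than shown in this hunk:

def collect_speculators_weights(weights):
    """Hypothetical helper mirroring the loop above; the real load_weights hands
    the collected dict to the underlying loader (not shown in this hunk)."""
    model_weights = {}
    for name, loaded_weight in weights:
        remapped_name = remap_speculators_weight_name(name)
        if remapped_name is None:
            continue
        model_weights[remapped_name] = loaded_weight  # assumed collection step
    return model_weights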
