
Commit eef118e

rahul-tuli and claude committed
fix: Add HASS Eagle layernorm support for V1 engine
- Add RMSNorm import and support for enorm/hnorm in llama_eagle.py
- Apply layernorms in forward pass when add_para_norm is enabled
- Handle speculators weight remapping in EagleLlamaForCausalLM.load_weights
- Fixes HASS Eagle models (nm-testing/hass-llama3.1-8b-layernorms) in V1 engine

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
1 parent e9bda92 commit eef118e
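
For context, the HASS Eagle checkpoint named in the message acts as the draft model for speculative decoding. A minimal usage sketch for the V1 engine follows; the target model pairing, the speculative_config keys, and the sampling settings are assumptions for illustration, not part of this commit:

# Hypothetical usage sketch: pair a HASS Eagle draft with a Llama 3.1 target.
# Model choices and config keys below are assumptions, not verified by this commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # assumed target model
    speculative_config={
        "method": "eagle",
        "model": "nm-testing/hass-llama3.1-8b-layernorms",  # HASS Eagle draft
        "num_speculative_tokens": 5,
    },
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)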

File tree

1 file changed: +31 −0 lines changed

vllm/model_executor/models/llama_eagle.py

Lines changed: 31 additions & 0 deletions
@@ -11,6 +11,7 @@
 from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -71,6 +72,15 @@ def __init__(
         self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                   self.config.hidden_size,
                                   bias=False)
+
+        # Support for additional layernorms (HASS variant)
+        self.add_para_norm = False
+        if hasattr(self.config, "add_para_norm") and self.config.add_para_norm:
+            self.enorm = RMSNorm(self.config.hidden_size,
+                                 eps=self.config.rms_norm_eps)
+            self.hnorm = RMSNorm(self.config.hidden_size,
+                                 eps=self.config.rms_norm_eps)
+            self.add_para_norm = True

     def forward(
         self,
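
The branch keys off an attribute on the draft model's config; only the three attributes read above matter to it. A hypothetical config stub for illustration (real HF configs carry many more fields, and how add_para_norm reaches the config is an assumption here):

from types import SimpleNamespace

# Hypothetical draft-model config; a real HASS checkpoint would expose
# add_para_norm alongside the usual Llama fields.
config = SimpleNamespace(hidden_size=4096, rms_norm_eps=1e-5, add_para_norm=True)

add_para_norm = bool(getattr(config, "add_para_norm", False))
print(add_para_norm)  # True -> enorm/hnorm get created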
@@ -79,6 +89,12 @@
         hidden_states: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_embeds = self.embed_tokens(input_ids)
+
+        # Apply layernorms if enabled (HASS variant)
+        if self.add_para_norm:
+            input_embeds = self.enorm(input_embeds)
+            hidden_states = self.hnorm(hidden_states)
+
         hidden_states = self.fc(
             torch.cat((input_embeds, hidden_states), dim=-1))
         residual = None
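
To make the data flow concrete, here is a standalone sketch of the fusion step this hunk touches. The tensor sizes and the weight-free rms_norm helper are illustrative stand-ins, not the vLLM RMSNorm layer:

import torch

hidden_size, batch, seq = 64, 2, 4  # illustrative sizes

def rms_norm(x, eps=1e-5):
    # Weight-free RMSNorm stand-in, for illustration only.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

input_embeds = torch.randn(batch, seq, hidden_size)   # draft token embeddings
hidden_states = torch.randn(batch, seq, hidden_size)  # target-model hidden states

# HASS variant: normalize both streams before fusing them.
input_embeds = rms_norm(input_embeds)
hidden_states = rms_norm(hidden_states)

fc = torch.nn.Linear(hidden_size * 2, hidden_size, bias=False)
fused = fc(torch.cat((input_embeds, hidden_states), dim=-1))
print(fused.shape)  # torch.Size([2, 4, 64])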
@@ -177,8 +193,23 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             skip_prefixes=None,
         )

+        # Support for speculators format weights
+        speculators_name_map = {
+            "fusion_fc.weight": "fc.weight",
+            "fusion_fc.bias": "fc.bias",
+            "embedding_layernorm.weight": "enorm.weight",
+            "pre_lm_head_layernorm.weight": "hnorm.weight",
+        }
+
         model_weights = {}
         for name, loaded_weight in weights:
+            # Handle speculators format weight names
+            if name in speculators_name_map:
+                name = speculators_name_map[name]
+            elif name.startswith("transformer."):
+                # Skip transformer weights - they're loaded separately
+                continue
+
             if "lm_head" not in name:
                 name = "model." + name
             model_weights[name] = loaded_weight
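
The remapping can be exercised in isolation; a small sketch with fabricated checkpoint entries (the weight names follow the mapping added above, the shapes are arbitrary):

import torch

speculators_name_map = {
    "fusion_fc.weight": "fc.weight",
    "fusion_fc.bias": "fc.bias",
    "embedding_layernorm.weight": "enorm.weight",
    "pre_lm_head_layernorm.weight": "hnorm.weight",
}

# Fabricated speculators-format entries for illustration.
weights = [
    ("fusion_fc.weight", torch.zeros(64, 128)),
    ("embedding_layernorm.weight", torch.ones(64)),
    ("transformer.layers.0.self_attn.q_proj.weight", torch.zeros(64, 64)),
    ("lm_head.weight", torch.zeros(1000, 64)),
]

model_weights = {}
for name, loaded_weight in weights:
    if name in speculators_name_map:
        name = speculators_name_map[name]
    elif name.startswith("transformer."):
        continue  # mirrors the skip in load_weights: handled elsewhere
    if "lm_head" not in name:
        name = "model." + name
    model_weights[name] = loaded_weight

print(sorted(model_weights))
# ['lm_head.weight', 'model.enorm.weight', 'model.fc.weight']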

0 commit comments
