
Commit 781d056

[Feature] Enhance EAGLE Architecture with Proper RMS Norms (#14990)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
1 parent 5aefd6a commit 781d056

3 files changed: 70 additions and 12 deletions

vllm/config.py

Lines changed: 12 additions & 4 deletions
@@ -800,10 +800,18 @@ def get_hidden_size(self) -> int:
 
     @property
     def is_deepseek_mla(self) -> bool:
-        return (hasattr(self.hf_text_config, "model_type")) \
-            and (self.hf_text_config.model_type in \
-                ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\
-            and (self.hf_text_config.kv_lora_rank is not None)
+        if not hasattr(self.hf_text_config, "model_type"):
+            return False
+        elif self.hf_text_config.model_type in \
+                ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
+            return self.hf_text_config.kv_lora_rank is not None
+        elif self.hf_text_config.model_type == 'eagle':
+            # if the model is an EAGLE module, check for the
+            # underlying architecture
+            return self.hf_text_config.model.model_type in \
+                    ('deepseek_v2', 'deepseek_v3') \
+                and self.hf_text_config.kv_lora_rank is not None
+        return False
 
     def get_head_size(self) -> int:
         # TODO remove hard code
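The new 'eagle' branch matters because an EAGLE draft config nests the target model's config under a model attribute, so the MLA check has to look one level deeper. A minimal sketch of how the branch resolves, using SimpleNamespace as a stand-in for the real HF config objects (not the actual vLLM ModelConfig):

# Minimal sketch, assuming stand-in config objects; values are illustrative.
from types import SimpleNamespace

def is_deepseek_mla(hf_text_config) -> bool:
    if not hasattr(hf_text_config, "model_type"):
        return False
    elif hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3',
                                       'deepseek_mtp'):
        return hf_text_config.kv_lora_rank is not None
    elif hf_text_config.model_type == 'eagle':
        # An EAGLE draft wraps the target architecture, so inspect the
        # nested config for the DeepSeek model types.
        return (hf_text_config.model.model_type in ('deepseek_v2',
                                                    'deepseek_v3')
                and hf_text_config.kv_lora_rank is not None)
    return False

eagle_cfg = SimpleNamespace(model_type='eagle',
                            kv_lora_rank=512,  # illustrative value
                            model=SimpleNamespace(model_type='deepseek_v3'))
print(is_deepseek_mla(eagle_cfg))  # True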

vllm/model_executor/models/eagle.py

Lines changed: 45 additions & 6 deletions
@@ -7,6 +7,7 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -59,7 +60,15 @@ class EAGLE(nn.Module):
     truncated_vocab_size < vocab_size. To use this technique, one has to find
     the top-k most frequent tokens in target dataset and add that as a tensor
     in the draft checkpoint (using key token_map). Also, the draft config
-    needs to have truncated_vocab_size (=k) as an attribute."""
+    needs to have truncated_vocab_size (=k) as an attribute.
+    4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
+    module with regard to the use of additional RMS norms. The original
+    EAGLE architecture 1) skips the pre-attention norm in its first
+    transformer block, and 2) skips the final output norm, both of which we
+    found to be suboptimal. We also add support for separate norms applied
+    to the token embedding and the hidden states before the projection, as
+    in DeepSeek MTP, which we found to improve performance as well.
+    """
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -81,9 +90,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # While weights and biases are generally not needed,
         # they are retained here to support certain unit tests
         # (e.g., spec_decode/e2e/test_eagle_correctness.py).
-        self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
-            weight=self.model.model.layers[0].input_layernorm.weight)
-        self.model.model.norm = DummyOutputNorm()
+        if not hasattr(self.config.model,
+                       "skip_prenorm") or self.config.model.skip_prenorm:
+            self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
+                weight=self.model.model.layers[0].input_layernorm.weight)
+
+        if not hasattr(
+                self.config.model,
+                "skip_output_norm") or self.config.model.skip_output_norm:
+            self.model.model.norm = DummyOutputNorm()
+
+        self.add_para_norm = False
+        if hasattr(self.config.model,
+                   "add_para_norm") and self.config.model.add_para_norm:
+            self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.add_para_norm = True
 
         self.orig_vocab_size = config.vocab_size
         self.truncated_vocab_size = config.truncated_vocab_size
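The three flags checked above are plain attributes on the nested draft model config (self.config.model), i.e. they live inside the "model" section of the draft checkpoint's config. A hypothetical fragment opting into the enhanced architecture could look like the following; only the flag names come from the code above, the surrounding structure is illustrative:

# Hypothetical flags inside the draft config's nested "model" section.
# Absent flags fall back to the original EAGLE behavior: both norms are
# replaced by the dummy pass-through modules and no enorm/hnorm is created.
enhanced_eagle_flags = {
    "skip_prenorm": False,      # keep the first block's input_layernorm
    "skip_output_norm": False,  # keep the final model.norm
    "add_para_norm": True,      # add enorm/hnorm before the fc projection
}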
@@ -128,8 +150,17 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings(input_ids)
 
-        inputs_embeds = self.fc(
-            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+        if self.add_para_norm:
+            inputs_embeds = torch.cat([
+                self.enorm(inputs_embeds),
+                self.hnorm(previous_hidden_states)
+            ],
+                                      dim=-1)
+        else:
+            inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
+                                      dim=-1)
+
+        inputs_embeds = self.fc(inputs_embeds)
 
         inputs_embeds[positions == 0] = 0  # masking inputs at position=0
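A small shape sketch of the new add_para_norm path, using plain torch with a functional RMS norm as a stand-in for vLLM's RMSNorm layer (illustration only, not the actual module):

import torch

def rms_norm(x, weight, eps=1e-6):
    # Functional stand-in for an RMSNorm layer: scale by the reciprocal RMS,
    # then apply the learned elementwise weight.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

hidden_size = 4
enorm_w = torch.ones(hidden_size)
hnorm_w = torch.ones(hidden_size)
fc = torch.nn.Linear(2 * hidden_size, hidden_size, bias=False)

inputs_embeds = torch.randn(3, hidden_size)           # draft token embeddings
previous_hidden_states = torch.randn(3, hidden_size)  # target-model hiddens

# add_para_norm path: each stream is normalized separately, then concatenated
# and projected back to hidden_size by the fc layer.
x = torch.cat([rms_norm(inputs_embeds, enorm_w),
               rms_norm(previous_hidden_states, hnorm_w)], dim=-1)
print(fc(x).shape)  # torch.Size([3, 4])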

@@ -190,6 +221,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 else:
                     logger.warning_once("Found bias in the loaded weights but "
                                         "the model config doesn't have bias.")
+            elif name.startswith("enorm.weight"):
+                weight_loader = getattr(self.enorm.weight, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(self.enorm.weight, loaded_weight)
+            elif name.startswith("hnorm.weight"):
+                weight_loader = getattr(self.hnorm.weight, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(self.hnorm.weight, loaded_weight)
             elif name.startswith("model.lm_head.") or name.startswith(
                     "model.model."):
                 model_weights[name.split("model.", 1)[-1]] = loaded_weight
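The two new branches pick up flat checkpoint keys named exactly "enorm.weight" and "hnorm.weight", while nested keys still go through the existing prefix stripping. A quick illustration of that stripping (the key name is made up):

# The split removes only the first "model." prefix, so the remainder still
# addresses the wrapped model's own "model." namespace.
name = "model.model.layers.0.input_layernorm.weight"  # hypothetical key
print(name.split("model.", 1)[-1])
# model.layers.0.input_layernorm.weight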

vllm/transformers_utils/configs/eagle.py

Lines changed: 13 additions & 2 deletions
@@ -5,6 +5,8 @@
 
 from transformers import AutoConfig, PretrainedConfig
 
+from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
+
 
 class EAGLEConfig(PretrainedConfig):
     model_type = "eagle"
@@ -14,8 +16,17 @@ def __init__(self,
                  truncated_vocab_size: Optional[int] = None,
                  **kwargs):
 
-        model_config = None if model is None else (AutoConfig.for_model(
-            **model) if isinstance(model, dict) else model)
+        model_config: Union[PretrainedConfig, DeepseekV2Config, None]
+        if isinstance(model, dict):
+            archs = model.get("architectures", [])
+            target_archs = ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]
+            if any(target_arch in archs for target_arch in target_archs):
+                # AutoConfig does not support DeepSeek MoE models yet
+                model_config = DeepseekV2Config(**model)
+            else:
+                model_config = AutoConfig.for_model(**model)
+        else:
+            model_config = model
 
         for k, v in kwargs.items():
             if k != "architectures" and k != "model_type" and hasattr(
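A hedged usage sketch of the new constructor path, assuming a vLLM checkout where these imports resolve and that the parsed nested config is exposed as an attribute of the EAGLE config; field values are illustrative, not taken from a real checkpoint:

# Usage sketch only; exact behavior depends on the vLLM version.
from vllm.transformers_utils.configs.eagle import EAGLEConfig

# The nested dict declares a DeepSeek architecture, so it should be parsed
# with DeepseekV2Config rather than AutoConfig.for_model.
cfg = EAGLEConfig(model={
    "architectures": ["DeepseekV3ForCausalLM"],
    "model_type": "deepseek_v3",
    "kv_lora_rank": 512,
})
print(type(cfg.model).__name__)  # expected: DeepseekV2Config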
