 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -59,7 +60,15 @@ class EAGLE(nn.Module):
     truncated_vocab_size < vocab_size. To use this technique, one has to find
     the top-k most frequent tokens in target dataset and add that as a tensor
     in the draft checkpoint (using key token_map). Also, the draft config
-    needs to have truncated_vocab_size (=k) as an attribute."""
+    needs to have truncated_vocab_size (=k) as an attribute.
+    4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
+    module with regard to the use of additional RMS norms. The original
+    EAGLE architecture 1) skips the pre-attention norm in its first
+    transformer block, and 2) skips the final output norm, both of which we
+    found to be suboptimal. We also add support for separate norms applied
+    to the token embedding and hidden states before the fc projection, as
+    in DeepSeek MTP, which we found to improve performance as well.
+    """

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
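
For reference, the three new draft-config attributes consumed in `__init__` below all live on the nested model config of the EAGLE draft checkpoint. A minimal sketch of a config fragment that exercises them; the flag names come from this diff, but the surrounding config.json layout is an assumption, not taken from the PR:

```python
# Hypothetical fragment of an EAGLE draft checkpoint's config.json,
# expressed as a Python dict. Only the three flag names are from the diff.
draft_config = {
    "model": {
        "skip_prenorm": False,      # keep the first block's input_layernorm
        "skip_output_norm": False,  # keep the final output norm
        "add_para_norm": True,      # enable DeepSeek-MTP-style enorm/hnorm
    },
    "truncated_vocab_size": 32000,  # optional, see point 3 in the docstring
}
```

Note the defaults: when a flag is absent, the code below falls back to the original EAGLE behavior of replacing both norms with dummies.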
@@ -81,9 +90,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # While weights and biases are generally not needed,
         # they are retained here to support certain unit tests
         # (e.g., spec_decode/e2e/test_eagle_correctness.py).
-        self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
-            weight=self.model.model.layers[0].input_layernorm.weight)
-        self.model.model.norm = DummyOutputNorm()
+        if not hasattr(self.config.model,
+                       "skip_prenorm") or self.config.model.skip_prenorm:
+            self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
+                weight=self.model.model.layers[0].input_layernorm.weight)
+
+        if not hasattr(
+                self.config.model,
+                "skip_output_norm") or self.config.model.skip_output_norm:
+            self.model.model.norm = DummyOutputNorm()
+
+        self.add_para_norm = False
+        if hasattr(self.config.model,
+                   "add_para_norm") and self.config.model.add_para_norm:
+            self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.add_para_norm = True

         self.orig_vocab_size = config.vocab_size
         self.truncated_vocab_size = config.truncated_vocab_size
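
The `DummyInputLayerNorm` and `DummyOutputNorm` classes referenced above are defined earlier in eagle.py and are not part of this diff. Roughly, they are identity-style stand-ins; a paraphrased sketch (not the verbatim vLLM source) of the behavior the new conditionals toggle:

```python
import torch.nn as nn


class DummyInputLayerNorm(nn.Module):
    """Identity stand-in for the first block's pre-attention norm."""

    def __init__(self, weight=None, bias=None):
        super().__init__()
        # Parameters are kept only so unit tests can still find/load them.
        self.weight = nn.Parameter(weight) if weight is not None else None
        self.bias = nn.Parameter(bias) if bias is not None else None

    def forward(self, x):
        return x


class DummyOutputNorm(nn.Module):
    """Stand-in for the final norm: folds in the residual, no normalization."""

    def forward(self, x, residual=None):
        if residual is None:
            return x
        return x + residual, None
```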
@@ -128,8 +150,17 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings(input_ids)

-        inputs_embeds = self.fc(
-            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+        if self.add_para_norm:
+            inputs_embeds = torch.cat([
+                self.enorm(inputs_embeds),
+                self.hnorm(previous_hidden_states)
+            ],
+                                      dim=-1)
+        else:
+            inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
+                                      dim=-1)
+
+        inputs_embeds = self.fc(inputs_embeds)

         inputs_embeds[positions == 0] = 0  # masking inputs at position=0

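
Standalone, the `add_para_norm` path in `forward` amounts to normalizing the two input streams independently, concatenating them, and projecting back to `hidden_size`. A runnable sketch with made-up shapes, using `torch.nn.RMSNorm` (PyTorch >= 2.4) in place of vLLM's fused `RMSNorm`:

```python
import torch
import torch.nn as nn

hidden_size, num_tokens = 4096, 8
enorm = nn.RMSNorm(hidden_size, eps=1e-6)  # normalizes token embeddings
hnorm = nn.RMSNorm(hidden_size, eps=1e-6)  # normalizes target hidden states
fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)

inputs_embeds = torch.randn(num_tokens, hidden_size)
previous_hidden_states = torch.randn(num_tokens, hidden_size)

# add_para_norm=True: norm each stream separately, then fuse via fc.
fused = fc(torch.cat(
    [enorm(inputs_embeds), hnorm(previous_hidden_states)], dim=-1))
assert fused.shape == (num_tokens, hidden_size)
```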
@@ -190,6 +221,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 else:
                     logger.warning_once("Found bias in the loaded weights but "
                                         "the model config doesn't have bias.")
+            elif name.startswith("enorm.weight"):
+                weight_loader = getattr(self.enorm.weight, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(self.enorm.weight, loaded_weight)
+            elif name.startswith("hnorm.weight"):
+                weight_loader = getattr(self.hnorm.weight, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(self.hnorm.weight, loaded_weight)
             elif name.startswith("model.lm_head.") or name.startswith(
                     "model.model."):
                 model_weights[name.split("model.", 1)[-1]] = loaded_weight
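
With the enorm/hnorm branches above, draft checkpoints that ship those extra norm weights can now be loaded directly. A hypothetical end-to-end invocation, assuming the legacy speculative-decoding arguments of this vLLM generation (model names are placeholders, not recommendations):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # target model (assumed)
    speculative_model="path/to/eagle-draft",      # EAGLE draft checkpoint
    num_speculative_tokens=5,
)
out = llm.generate(["The capital of France is"],
                   SamplingParams(temperature=0.0, max_tokens=32))
print(out[0].outputs[0].text)
```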