@@ -52,11 +52,6 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
@@ -74,8 +69,6 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
     ) -> torch.Tensor:
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
         inputs_embeds[positions == 0] = 0
@@ -112,7 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         })
-
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.logits_processor = LogitsProcessor(config.vocab_size)

     def forward(
@@ -123,6 +119,8 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
         return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
@@ -242,6 +240,12 @@ def load_weights(self, weights: Iterable[tuple[str,
             if name.endswith(".bias") and name not in params_dict:
                 continue

+            # According to the DeepSeek-V3 technical report, MTP modules
+            # share the embedding layer, so only the first copy is loaded.
+            if (spec_layer != self.model.mtp_start_layer_idx
+                    and ".layers" not in name):
+                continue
+
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
@@ -253,17 +257,25 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
         """
         Rewrite the weight name to match the format of the original model.
         Add .mtp_block for modules in transformer layer block for spec layer
+        and rename shared layer weights to be top level.
         """
         spec_layer_weight_names = [
             "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
         ]
+        shared_weight_names = ["embed_tokens"]
         spec_layer_weight = False
+        shared_weight = False
         for weight_name in spec_layer_weight_names:
             if weight_name in name:
                 spec_layer_weight = True
+                if weight_name in shared_weight_names:
+                    shared_weight = True
                 break
         if not spec_layer_weight:
             # treat rest weights as weights for transformer layer block
             name = name.replace(f"model.layers.{spec_layer}.",
                                 f"model.layers.{spec_layer}.mtp_block.")
+        elif shared_weight:
+            # treat shared weights as top level weights
+            name = name.replace(f"model.layers.{spec_layer}.", "model.")
         return name
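
For reference, a minimal standalone sketch (not part of the diff) of the renaming rules that `_rewrite_spec_layer_name` implements after this change. The `rewrite` helper and the layer index `61` (assumed here to be the first MTP layer, as in a DeepSeek-V3 checkpoint with 61 decoder layers) are illustrative assumptions, not the model's actual method:

```python
# Hypothetical standalone version of the renaming logic, for illustration only.
def rewrite(spec_layer: int, name: str) -> str:
    spec_layer_weight_names = [
        "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
    ]
    shared_weight_names = ["embed_tokens"]
    spec_layer_weight = False
    shared_weight = False
    for weight_name in spec_layer_weight_names:
        if weight_name in name:
            spec_layer_weight = True
            if weight_name in shared_weight_names:
                shared_weight = True
            break
    if not spec_layer_weight:
        # transformer-block weights are nested under .mtp_block
        name = name.replace(f"model.layers.{spec_layer}.",
                            f"model.layers.{spec_layer}.mtp_block.")
    elif shared_weight:
        # shared weights (the embedding) are lifted to the top level
        name = name.replace(f"model.layers.{spec_layer}.", "model.")
    return name

# The shared embedding moves to the top-level predictor, matching the
# self.embed_tokens now created there; MTP-specific weights keep their
# prefix; everything else goes under .mtp_block.
assert rewrite(61, "model.layers.61.embed_tokens.weight") == \
    "model.embed_tokens.weight"
assert rewrite(61, "model.layers.61.enorm.weight") == \
    "model.layers.61.enorm.weight"
assert rewrite(61, "model.layers.61.self_attn.q_proj.weight") == \
    "model.layers.61.mtp_block.self_attn.q_proj.weight"
```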