@@ -152,17 +152,23 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
         self.v_head_dim = model_args.v_head_dim
 
         if self.q_lora_rank == 0:
-            self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim)
+            self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim, bias=False)
         else:
-            self.wq_a = nn.Linear(self.dim, self.q_lora_rank)
+            self.wq_a = nn.Linear(self.dim, self.q_lora_rank, bias=False)
             self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=model_args.norm_eps)
-            self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
-        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
+            self.wq_b = nn.Linear(
+                self.q_lora_rank, self.n_heads * self.qk_head_dim, bias=False
+            )
+        self.wkv_a = nn.Linear(
+            self.dim, self.kv_lora_rank + self.qk_rope_head_dim, bias=False
+        )
         self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=model_args.norm_eps)
         self.wkv_b = nn.Linear(
-            self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
+            self.kv_lora_rank,
+            self.n_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
         )
-        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)
+        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim, bias=False)
         self.softmax_scale = self.qk_head_dim**-0.5
 
         if model_args.max_seq_len > model_args.original_seq_len:
@@ -192,8 +198,8 @@ def forward(
         if self.q_lora_rank == 0:
             q = self.wq(x)  # (bsz, seqlen, n_heads * qk_head_dim)
         else:
-            q = self.wq_b(self.q_norm(self.wq_a(x)))
-
+            q = self.wq_a(x)
+            q = self.wq_b(self.q_norm(q))
         # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
         # local heads from sizes of q and kv as TP may have sharded them after
         # the above linear ops.
@@ -235,6 +241,24 @@ def forward(
         output = output.view(bsz, seqlen, -1)  # (bsz, seqlen, n_heads * v_head_dim)
         return self.wo(output)  # (bsz, seqlen, dim)
 
+    def init_weights(self, init_std: float):
+        linear_list = [
+            self.wkv_a,
+            self.wkv_b,
+        ]
+        if self.q_lora_rank > 0:
+            linear_list.extend([self.wq_a, self.wq_b])
+        else:
+            linear_list.append(self.wq)
+
+        for linear in linear_list:
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
+        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
+
+        self.kv_norm.reset_parameters()
+        if self.q_lora_rank > 0:
+            self.q_norm.reset_parameters()
+
 
 class FeedForward(nn.Module):
     """
@@ -266,7 +290,7 @@ def __init__(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
-    def init_weights(self, init_std: float):
+    def init_weights(self, init_std: float = 0.02):
         nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
         for linear in (self.w2, self.w3):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
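
The two `init_weights` methods added above follow one pattern: truncated-normal init with a fixed std of 0.02 for the input-side projections, the depth-scaled `init_std` for the remaining projections (notably the one that writes back into the residual stream), and `reset_parameters()` for the RMSNorm layers. A minimal, self-contained sketch of that pattern, illustrative only (`TinyBlock` is made up for this note and is not part of the PR; `nn.RMSNorm` needs PyTorch 2.4+):

import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.w_in = nn.Linear(dim, 4 * dim, bias=False)   # reads from the residual stream
        self.w_out = nn.Linear(4 * dim, dim, bias=False)  # writes back into it
        self.norm = nn.RMSNorm(dim)

    def init_weights(self, init_std: float = 0.02):
        nn.init.trunc_normal_(self.w_in.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.w_out.weight, mean=0.0, std=init_std)
        self.norm.reset_parameters()

block = TinyBlock()
block.init_weights(init_std=0.02 / (2 * (0 + 1)) ** 0.5)  # the std a block with layer_id == 0 would get
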
@@ -283,13 +307,16 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs):
         self.attention = Attention(model_args)
         self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
-        self.moe_enabled = layer_id < model_args.n_dense_layers
+        self.moe_enabled = layer_id >= model_args.n_dense_layers
 
         if self.moe_enabled:
             self.moe = MoE(model_args)
         else:
             self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim)
 
+        # TODO: Need to revisit the weight initialization for the TransformerBlock
+        self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
+
     def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
         """
         Forward pass for the Transformer block.
@@ -308,6 +335,15 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
             x = x + self.feed_forward(self.ffn_norm(x))
         return x
 
+    def init_weights(self, buffer_device: torch.device):
+        for norm in (self.attention_norm, self.ffn_norm):
+            norm.reset_parameters()
+        self.attention.init_weights(self.weight_init_std)
+        if self.moe_enabled:
+            self.moe.init_weights(self.weight_init_std, buffer_device)
+        else:
+            self.feed_forward.init_weights(self.weight_init_std)
+
 
 class DeepSeekV3Model(nn.Module, ModelProtocol):
     """
@@ -319,7 +355,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
         self.max_seq_len = model_args.max_seq_len
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
         self.register_buffer(
-            "freqs_cis", precompute_freqs_cis(model_args), persistent=False
+            "freqs_cis", precompute_freqs_cis(model_args), persistent=True
         )
 
         self.layers = torch.nn.ModuleDict()
@@ -328,10 +364,36 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
 
         self.norm = nn.RMSNorm(model_args.dim)
         self.output = nn.Linear(
-            model_args.dim, model_args.vocab_size, dtype=torch.get_default_dtype()
+            model_args.dim,
+            model_args.vocab_size,
+            dtype=torch.get_default_dtype(),
+            bias=False,
         )
+        self.model_args = model_args
         self.init_weights()
 
+    def init_weights(self, buffer_device: torch.device | None = None) -> None:
+        buffer_device = buffer_device or self.freqs_cis.device
+        with torch.device(buffer_device):
+            self.freqs_cis = precompute_freqs_cis(self.model_args)
+        if self.tok_embeddings is not None:
+            nn.init.normal_(self.tok_embeddings.weight)
+        for layer in self.layers.values():
+            if layer is not None:
+                layer.init_weights(buffer_device=buffer_device)
+        if self.norm is not None:
+            self.norm.reset_parameters()
+        final_out_std = self.model_args.dim**-0.5
+        cutoff_factor = 3
+        if self.output is not None:
+            nn.init.trunc_normal_(
+                self.output.weight,
+                mean=0.0,
+                std=final_out_std,
+                a=-cutoff_factor * final_out_std,
+                b=cutoff_factor * final_out_std,
+            )
+
     def forward(self, tokens: torch.Tensor):
         """
         Forward pass for the Transformer model.
@@ -347,8 +409,5 @@ def forward(self, tokens: torch.Tensor):
         for layer in self.layers.values():
             h = layer(h, self.freqs_cis)
         h = self.norm(h)
-        output = self.output(h)  # (batch_size, seq_len, dim)
+        output = self.output(h)
         return output
-
-    def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        pass
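
The new `DeepSeekV3Model.init_weights` rebuilds `freqs_cis` under `torch.device(buffer_device)`, re-initializes the embedding and every layer, resets the final norm, and gives the output head a truncated-normal init with std `dim ** -0.5` clipped at three standard deviations, replacing the old `pass` stub; `persistent=True` additionally makes `freqs_cis` part of the state dict. The rebuild on `buffer_device` fits the usual meta-device flow, where the module is first constructed without real storage and only materialized later. A hedged usage sketch, assuming the classes in this file are importable from the path shown, `DeepSeekV3ModelArgs` is default-constructible, and a CUDA device is available (the real training harness and parallelism wrappers are omitted):

import torch

# Hypothetical import path, for illustration only.
from torchtitan.models.deepseek_v3.model import DeepSeekV3Model, DeepSeekV3ModelArgs

model_args = DeepSeekV3ModelArgs()       # assumed default-constructible
with torch.device("meta"):
    model = DeepSeekV3Model(model_args)  # parameters created without real storage
model.to_empty(device="cuda")            # allocate uninitialized storage on the target device
model.init_weights(buffer_device=torch.device("cuda"))  # fill weights, rebuild freqs_cis
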