
Commit 850ddad

fix bug

1 parent 45649de · commit 850ddad

4 files changed (+86 −20 lines)

torchtitan/models/deepseek_v3/model/model.py

Lines changed: 75 additions & 16 deletions
@@ -152,17 +152,23 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
         self.v_head_dim = model_args.v_head_dim
 
         if self.q_lora_rank == 0:
-            self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim)
+            self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim, bias=False)
         else:
-            self.wq_a = nn.Linear(self.dim, self.q_lora_rank)
+            self.wq_a = nn.Linear(self.dim, self.q_lora_rank, bias=False)
             self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=model_args.norm_eps)
-            self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
-        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
+            self.wq_b = nn.Linear(
+                self.q_lora_rank, self.n_heads * self.qk_head_dim, bias=False
+            )
+        self.wkv_a = nn.Linear(
+            self.dim, self.kv_lora_rank + self.qk_rope_head_dim, bias=False
+        )
         self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=model_args.norm_eps)
         self.wkv_b = nn.Linear(
-            self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
+            self.kv_lora_rank,
+            self.n_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
         )
-        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)
+        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim, bias=False)
         self.softmax_scale = self.qk_head_dim**-0.5
 
         if model_args.max_seq_len > model_args.original_seq_len:
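Note: every projection in this attention block now passes bias=False. A minimal standalone sketch (hypothetical sizes, not the model's actual dims) of what that changes: the layer simply registers no bias parameter, so parameter counts and state_dict keys shrink accordingly.

    import torch.nn as nn

    # Hypothetical sizes, only to show the effect of bias=False.
    with_bias = nn.Linear(512, 1024)
    without_bias = nn.Linear(512, 1024, bias=False)

    print(with_bias.bias.shape)   # torch.Size([1024])
    print(without_bias.bias)      # None
    print(sum(p.numel() for p in with_bias.parameters()))     # 525312
    print(sum(p.numel() for p in without_bias.parameters()))  # 524288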
@@ -192,8 +198,8 @@ def forward(
         if self.q_lora_rank == 0:
             q = self.wq(x)  # (bsz, seqlen, n_heads * qk_head_dim)
         else:
-            q = self.wq_b(self.q_norm(self.wq_a(x)))
-
+            q = self.wq_a(x)
+            q = self.wq_b(self.q_norm(q))
         # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
         # local heads from sizes of q and kv as TP may have sharded them after
         # the above linear ops.
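Note: the q-path change here is a pure refactor that splits the nested call into two statements; it computes the same tensor. A minimal standalone check under hypothetical sizes (needs a PyTorch recent enough to have nn.RMSNorm):

    import torch
    import torch.nn as nn

    dim, q_lora_rank, out_dim = 16, 8, 32   # hypothetical sizes
    wq_a = nn.Linear(dim, q_lora_rank, bias=False)
    q_norm = nn.RMSNorm(q_lora_rank)
    wq_b = nn.Linear(q_lora_rank, out_dim, bias=False)
    x = torch.randn(2, 4, dim)

    q_nested = wq_b(q_norm(wq_a(x)))   # old: single expression
    q = wq_a(x)                        # new: intermediate named explicitly
    q = wq_b(q_norm(q))
    print(torch.equal(q, q_nested))    # True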
@@ -235,6 +241,24 @@ def forward(
         output = output.view(bsz, seqlen, -1)  # (bsz, seqlen, n_heads * v_head_dim)
         return self.wo(output)  # (bsz, seqlen, dim)
 
+    def init_weights(self, init_std: float):
+        linear_list = [
+            self.wkv_a,
+            self.wkv_b,
+        ]
+        if self.q_lora_rank > 0:
+            linear_list.extend([self.wq_a, self.wq_b])
+        else:
+            linear_list.append(self.wq)
+
+        for linear in linear_list:
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
+        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
+
+        self.kv_norm.reset_parameters()
+        if self.q_lora_rank > 0:
+            self.q_norm.reset_parameters()
+
 
 class FeedForward(nn.Module):
     """
@@ -266,7 +290,7 @@ def __init__(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
-    def init_weights(self, init_std: float):
+    def init_weights(self, init_std: float = 0.02):
         nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
         for linear in (self.w2, self.w3):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
@@ -283,13 +307,16 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs):
         self.attention = Attention(model_args)
         self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
-        self.moe_enabled = layer_id < model_args.n_dense_layers
+        self.moe_enabled = layer_id >= model_args.n_dense_layers
 
         if self.moe_enabled:
             self.moe = MoE(model_args)
         else:
             self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim)
 
+        # TODO: Need to revisit the weight initialization for the TransformerBlock
+        self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
+
     def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
         """
         Forward pass for the Transformer block.
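Note: the flipped comparison on moe_enabled looks like the bug the commit title refers to. With the old `layer_id < n_dense_layers`, the first n_dense_layers blocks were treated as MoE and every later block as dense, which is inverted relative to the intended layout (leading dense blocks, MoE for the rest). The same hunk also stores a depth-scaled weight_init_std; see the note after the next hunk. A tiny worked example with hypothetical n_layers = 4 and n_dense_layers = 1:

    n_layers, n_dense_layers = 4, 1   # hypothetical values

    old = [layer_id < n_dense_layers for layer_id in range(n_layers)]
    new = [layer_id >= n_dense_layers for layer_id in range(n_layers)]

    print(old)  # [True, False, False, False] -> only block 0 would be MoE
    print(new)  # [False, True, True, True]   -> block 0 dense, blocks 1-3 MoE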
@@ -308,6 +335,15 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
             x = x + self.feed_forward(self.ffn_norm(x))
         return x
 
+    def init_weights(self, buffer_device: torch.device):
+        for norm in (self.attention_norm, self.ffn_norm):
+            norm.reset_parameters()
+        self.attention.init_weights(self.weight_init_std)
+        if self.moe_enabled:
+            self.moe.init_weights(self.weight_init_std, buffer_device)
+        else:
+            self.feed_forward.init_weights(self.weight_init_std)
+
 
 class DeepSeekV3Model(nn.Module, ModelProtocol):
     """
@@ -319,7 +355,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
         self.max_seq_len = model_args.max_seq_len
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
         self.register_buffer(
-            "freqs_cis", precompute_freqs_cis(model_args), persistent=False
+            "freqs_cis", precompute_freqs_cis(model_args), persistent=True
         )
 
         self.layers = torch.nn.ModuleDict()
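Note: switching the freqs_cis buffer to persistent=True means it is now written to state_dict() (and therefore to checkpoints); non-persistent buffers are skipped. A minimal demonstration of that PyTorch behavior:

    import torch
    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("kept", torch.ones(2), persistent=True)
            self.register_buffer("skipped", torch.ones(2), persistent=False)

    print(list(Demo().state_dict().keys()))  # ['kept']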
@@ -328,10 +364,36 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
 
         self.norm = nn.RMSNorm(model_args.dim)
         self.output = nn.Linear(
-            model_args.dim, model_args.vocab_size, dtype=torch.get_default_dtype()
+            model_args.dim,
+            model_args.vocab_size,
+            dtype=torch.get_default_dtype(),
+            bias=False,
         )
+        self.model_args = model_args
         self.init_weights()
 
+    def init_weights(self, buffer_device: torch.device | None = None) -> None:
+        buffer_device = buffer_device or self.freqs_cis.device
+        with torch.device(buffer_device):
+            self.freqs_cis = precompute_freqs_cis(self.model_args)
+        if self.tok_embeddings is not None:
+            nn.init.normal_(self.tok_embeddings.weight)
+        for layer in self.layers.values():
+            if layer is not None:
+                layer.init_weights(buffer_device=buffer_device)
+        if self.norm is not None:
+            self.norm.reset_parameters()
+        final_out_std = self.model_args.dim**-0.5
+        cutoff_factor = 3
+        if self.output is not None:
+            nn.init.trunc_normal_(
+                self.output.weight,
+                mean=0.0,
+                std=final_out_std,
+                a=-cutoff_factor * final_out_std,
+                b=cutoff_factor * final_out_std,
+            )
+
     def forward(self, tokens: torch.Tensor):
         """
         Forward pass for the Transformer model.
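Note: the new DeepSeekV3Model.init_weights recomputes freqs_cis on the requested device, re-initializes the embedding, every block, and the final norm, and draws the output head from a truncated normal with std = dim**-0.5 clipped to +/- 3 standard deviations. The explicit a/b matter because nn.init.trunc_normal_ defaults to absolute bounds of -2.0 and 2.0 rather than multiples of std. The output-head init in isolation, with deliberately small hypothetical sizes:

    import torch
    import torch.nn as nn

    dim, vocab_size = 256, 1024   # hypothetical, kept small for the sketch
    output = nn.Linear(dim, vocab_size, bias=False)

    final_out_std = dim**-0.5
    cutoff_factor = 3
    nn.init.trunc_normal_(
        output.weight,
        mean=0.0,
        std=final_out_std,
        a=-cutoff_factor * final_out_std,
        b=cutoff_factor * final_out_std,
    )
    print(bool(output.weight.abs().max() <= cutoff_factor * final_out_std))  # True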
@@ -347,8 +409,5 @@ def forward(self, tokens: torch.Tensor):
         for layer in self.layers.values():
             h = layer(h, self.freqs_cis)
         h = self.norm(h)
-        output = self.output(h)  # (batch_size, seq_len, dim)
+        output = self.output(h)
         return output
-
-    def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        pass

torchtitan/models/deepseek_v3/model/moe.py

Lines changed: 8 additions & 1 deletion
@@ -89,6 +89,11 @@ def forward(
 
         return out
 
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
+        nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
+        nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)
+
 
 class TokenChoiceTopKRouter(nn.Module):
     """This class implements token-choice routing. In token-choice top-K routing, each token is
@@ -173,6 +178,9 @@ def forward(
 
         return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert
 
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
+
 
 class MoE(nn.Module):
     def __init__(self, model_args: DeepSeekV3ModelArgs):
231239
if self.load_balance_coeff is not None and self.load_balance_coeff > 0:
232240
self.register_full_backward_hook(self._update_expert_bias)
233241

234-
# TODO: double check the bias update logic. It aligns with the paper.
235242
def _update_expert_bias(self, *_):
236243
expert_bias_delta = self.load_balance_coeff * torch.sign(
237244
self.tokens_per_expert.mean() - self.tokens_per_expert

torchtitan/models/deepseek_v3/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
-tensor_parallel_degree = 1
+tensor_parallel_degree = 2
 enable_async_tensor_parallel = false
 
 [checkpoint]
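Note: with tensor_parallel_degree raised to 2, the debug config now needs a world size divisible by 2; single-GPU debug runs would have to set it back to 1. A quick sanity check of the layout this implies, assuming a hypothetical 8-rank run and that data_parallel_shard_degree = -1 means "use whatever ranks remain":

    world_size = 8                       # hypothetical
    tensor_parallel_degree = 2
    assert world_size % tensor_parallel_degree == 0
    dp_shard = world_size // tensor_parallel_degree
    print(dp_shard)                      # 4 ranks left for FSDP sharding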

torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ decay_type = "linear"
 lr_min = 0.0
 
 [training]
-local_batch_size = 32
+local_batch_size = 16
 seq_len = 2048
 max_norm = 1.0 # grad norm clipping
 steps = 10
@@ -49,7 +49,7 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
-tensor_parallel_degree = 1
+tensor_parallel_degree = 2
 enable_async_tensor_parallel = false
 
 [checkpoint]
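Note: the 16B config gets the same TP bump plus a halved local_batch_size. Assuming the global batch is local_batch_size times the data-parallel degree, and that TP=2 halves the ranks available to data parallelism, both changes shrink the effective batch; a rough worked example under a hypothetical 8-GPU run:

    world_size = 8                       # hypothetical
    old_global = 32 * (world_size // 1)  # TP=1, local batch 32 -> 256 samples
    new_global = 16 * (world_size // 2)  # TP=2, local batch 16 -> 64 samples
    print(old_global, new_global)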
