
Commit 722f6e2

clean up
1 parent 2a4bc36 commit 722f6e2

File tree (1 file changed: +10 -25 lines)

  • torchtitan/models/deepseek-v3/model

torchtitan/models/deepseek-v3/model/model.py

Lines changed: 10 additions & 25 deletions
@@ -178,41 +178,32 @@ def forward(
             torch.Tensor: Output tensor with the same shape as the input.
         """
         bsz, seqlen, _ = x.size()
+
+        # Query projection
         if self.q_lora_rank == 0:
             q = self.wq(x)  # (bsz, seqlen, n_heads * qk_head_dim)
         else:
-            q = self.wq_b(
-                self.q_norm(self.wq_a(x))
-            )  # (bsz, seqlen, n_heads * qk_head_dim)
+            q = self.wq_b(self.q_norm(self.wq_a(x)))

-        q = q.view(
-            bsz, seqlen, self.n_heads, self.qk_head_dim
-        )  # (bsz, seqlen, n_heads, qk_head_dim)
+        q = q.view(bsz, seqlen, self.n_heads, self.qk_head_dim)
         q_nope, q_pe = torch.split(
             q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
         )
-        # q_nope: (bsz, seqlen, n_heads, qk_nope_head_dim)
-        # q_pe: (bsz, seqlen, n_heads, qk_rope_head_dim)
         q_pe = apply_rotary_emb(q_pe, freqs_cis)
         q = torch.cat([q_nope, q_pe], dim=-1)  # (bsz, seqlen, n_heads, qk_head_dim)

-        kv = self.wkv_a(x)  # kv: (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
+        # Key-value projection
+        kv = self.wkv_a(x)  # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
         kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        # kv: (bsz, seqlen, kv_lora_rank)
-        # k_pe: (bsz, seqlen, qk_rope_head_dim)
         k_pe = apply_rotary_emb(
             k_pe.unsqueeze(2), freqs_cis
         )  # (bsz, seqlen, 1, qk_rope_head_dim)

         kv = self.wkv_b(
             self.kv_norm(kv)
         )  # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim))
-        kv = kv.view(
-            bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim
-        )  # (bsz, seqlen, n_heads, qk_nope_head_dim + v_head_dim)
+        kv = kv.view(bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim)
         k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        # k_nope: (bsz, seqlen, n_heads, qk_nope_head_dim)
-        # v: (bsz, seqlen, n_heads, v_head_dim)
         k = torch.cat(
             [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1
         )  # (bsz, seqlen, n_heads, qk_head_dim)
@@ -222,10 +213,9 @@ def forward(
         # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L17
         output = self.sdpa(q, k, v)

-        output = output.transpose(
-            1, 2
-        ).contiguous()  # (bs, seqlen, n_heads, v_head_dim)
-        output = output.view(bsz, seqlen, -1)  # (bs, seqlen, n_heads * v_head_dim)
+        # Reshape and project output
+        output = output.transpose(1, 2)  # (bsz, seqlen, n_heads, v_head_dim)
+        output = output.view(bsz, seqlen, -1)  # (bsz, seqlen, n_heads * v_head_dim)
         return self.wo(output)  # (bsz, seqlen, dim)
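For reference, below is a minimal, self-contained sketch of the tensor shapes flowing through the simplified attention forward pass above. The attribute names (wq_a, wq_b, wkv_a, wkv_b, wo) mirror the diff, but the concrete dimensions are invented toy values, bias-free nn.Linear layers stand in for the real projections, the q_norm/kv_norm and rotary-embedding steps are omitted, and .reshape is used in place of .view because the transposed tensor here is not contiguous. A sketch under those assumptions, not the torchtitan implementation.

# Toy-shape walkthrough of the MLA-style attention above (names mirror the diff,
# dimensions are invented; norms and rotary embeddings omitted).
import torch
import torch.nn.functional as F

bsz, seqlen, dim = 2, 8, 64
n_heads = 4
q_lora_rank, kv_lora_rank = 16, 16
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 8, 4, 8
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

x = torch.randn(bsz, seqlen, dim)

# Query projection (low-rank path, q_lora_rank > 0)
wq_a = torch.nn.Linear(dim, q_lora_rank, bias=False)
wq_b = torch.nn.Linear(q_lora_rank, n_heads * qk_head_dim, bias=False)
q = wq_b(wq_a(x))                                   # (bsz, seqlen, n_heads * qk_head_dim)
q = q.view(bsz, seqlen, n_heads, qk_head_dim)
q_nope, q_pe = torch.split(q, [qk_nope_head_dim, qk_rope_head_dim], dim=-1)
q = torch.cat([q_nope, q_pe], dim=-1)               # rotary embedding of q_pe skipped

# Key-value projection
wkv_a = torch.nn.Linear(dim, kv_lora_rank + qk_rope_head_dim, bias=False)
wkv_b = torch.nn.Linear(kv_lora_rank, n_heads * (qk_nope_head_dim + v_head_dim), bias=False)
kv = wkv_a(x)                                       # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
kv, k_pe = torch.split(kv, [kv_lora_rank, qk_rope_head_dim], dim=-1)
kv = wkv_b(kv)                                      # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim))
kv = kv.view(bsz, seqlen, n_heads, qk_nope_head_dim + v_head_dim)
k_nope, v = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
k = torch.cat([k_nope, k_pe.unsqueeze(2).expand(-1, -1, n_heads, -1)], dim=-1)

# Attention; F.scaled_dot_product_attention expects (bsz, n_heads, seqlen, head_dim)
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
)                                                   # (bsz, n_heads, seqlen, v_head_dim)

# Reshape and project output (.reshape since the transposed tensor is non-contiguous)
out = out.transpose(1, 2).reshape(bsz, seqlen, -1)  # (bsz, seqlen, n_heads * v_head_dim)
wo = torch.nn.Linear(n_heads * v_head_dim, dim, bias=False)
print(wo(out).shape)                                # torch.Size([2, 8, 64])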

@@ -327,11 +317,6 @@ def forward(self, tokens: torch.Tensor):
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
         """
         h = self.tok_embeddings(tokens)
-        # # This is casual mask, which is already handled in sdpa
-        # if seqlen > 1:
-        #     mask = torch.full(
-        #         (seqlen, seqlen), float("-inf"), device=tokens.device
-        #     ).triu_(1)
         for layer in self.layers:
             h = layer(h, self.freqs_cis)
         h = self.norm(h)[:, -1]
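The deleted block built an explicit upper-triangular -inf mask, which the removed comment notes is already handled inside sdpa. A small sketch of that equivalence, assuming the sdpa wrapper ultimately calls torch.nn.functional.scaled_dot_product_attention with causal masking; the shapes are toy values.

# The removed (commented-out) mask is equivalent to SDPA's built-in causal
# masking; toy shapes, float32 on CPU.
import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 1, 2, 5, 8
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)

# Explicit upper-triangular -inf mask, as in the deleted code
mask = torch.full((seqlen, seqlen), float("-inf")).triu_(1)
out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Same result from the built-in flag
out_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(out_explicit, out_builtin, atol=1e-5))  # True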
