Commit e9f925f

add attention block
1 parent e7a0af1 commit e9f925f

File tree

1 file changed: +12, -9 lines
  • torchtitan/models/deepseek-v3/model


torchtitan/models/deepseek-v3/model/model.py

Lines changed: 12 additions & 9 deletions
@@ -174,48 +174,51 @@ def forward(
         """
         bsz, seqlen, _ = x.size()
         if self.q_lora_rank == 0:
-            q = self.wq(x)  # q: (bsz, seqlen, n_heads * qk_head_dim)
+            q = self.wq(x)  # (bsz, seqlen, n_heads * qk_head_dim)
         else:
             q = self.wq_b(
                 self.q_norm(self.wq_a(x))
-            )  # q: (bsz, seqlen, n_heads * qk_head_dim)
+            )  # (bsz, seqlen, n_heads * qk_head_dim)
 
         q = q.view(
             bsz, seqlen, self.n_heads, self.qk_head_dim
-        )  # q: (bsz, seqlen, n_heads, qk_head_dim)
+        )  # (bsz, seqlen, n_heads, qk_head_dim)
         q_nope, q_pe = torch.split(
             q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
         )
         # q_nope: (bsz, seqlen, n_heads, qk_nope_head_dim)
         # q_pe: (bsz, seqlen, n_heads, qk_rope_head_dim)
         q_pe = apply_rotary_emb(q_pe, freqs_cis)
-        q = torch.cat([q_nope, q_pe], dim=-1)  # q: (bsz, seqlen, n_heads, qk_head_dim)
+        q = torch.cat([q_nope, q_pe], dim=-1)  # (bsz, seqlen, n_heads, qk_head_dim)
 
         kv = self.wkv_a(x)  # kv: (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
         kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         # kv: (bsz, seqlen, kv_lora_rank)
         # k_pe: (bsz, seqlen, qk_rope_head_dim)
         k_pe = apply_rotary_emb(
             k_pe.unsqueeze(2), freqs_cis
-        )  # k_pe: (bsz, seqlen, 1, qk_rope_head_dim)
+        )  # (bsz, seqlen, 1, qk_rope_head_dim)
 
         kv = self.wkv_b(
             self.kv_norm(kv)
-        )  # kv: (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim))
+        )  # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim))
         kv = kv.view(
             bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim
         )  # (bsz, seqlen, n_heads, qk_nope_head_dim + v_head_dim)
         k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         # k_nope: (bsz, seqlen, n_heads, qk_nope_head_dim)
         # v: (bsz, seqlen, n_heads, v_head_dim)
-        k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)
-        # k: (bsz, seqlen, n_heads, qk_head_dim)
+        k = torch.cat(
+            [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1
+        )  # (bsz, seqlen, n_heads, qk_head_dim)
 
         # TODO: Need to pass softmax_scale to sdpa() interface.
         # For mask, DeepseekV3 uses causal mask, so we can use the default mask in sdpa
         # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L17
         output = self.sdpa(q, k, v)
 
-        output = output.transpose(1, 2).contiguous()
+        output = output.transpose(
+            1, 2
+        ).contiguous()  # (bs, seqlen, n_heads, v_head_dim)
         output = output.view(bsz, seqlen, -1)  # (bs, seqlen, n_heads * v_head_dim)
         return self.wo(output)  # (bsz, seqlen, dim)
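
The hunk above is the forward pass of DeepSeek-V3's Multi-head Latent Attention block: the query is split into a position-free part (q_nope) and a rotary part (q_pe), while the key/value are rebuilt from a low-rank latent (wkv_a → kv_norm → wkv_b) plus a single rotary key (k_pe) that is broadcast across heads before concatenation. As a sanity check on the shape comments in the diff, here is a minimal standalone sketch of the same tensor bookkeeping with dummy weights. The toy hyperparameter values, the omission of apply_rotary_emb (which rotates values but does not change shapes), and the direct call to F.scaled_dot_product_attention in place of self.sdpa are assumptions for illustration, not torchtitan's actual configuration.

import torch
import torch.nn.functional as F

# Assumed toy sizes; not torchtitan's real DeepSeek-V3 hyperparameters.
bsz, seqlen, dim = 2, 16, 256
n_heads = 4
qk_nope_head_dim, qk_rope_head_dim = 32, 16
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
kv_lora_rank, v_head_dim = 64, 32

x = torch.randn(bsz, seqlen, dim)

# Query path (q_lora_rank == 0 branch): project, reshape per head, then split
# into the non-positional and rotary parts. Rotary application is omitted here.
wq = torch.nn.Linear(dim, n_heads * qk_head_dim, bias=False)
q = wq(x).view(bsz, seqlen, n_heads, qk_head_dim)
q_nope, q_pe = torch.split(q, [qk_nope_head_dim, qk_rope_head_dim], dim=-1)
q = torch.cat([q_nope, q_pe], dim=-1)  # (bsz, seqlen, n_heads, qk_head_dim)

# Key/value path: a low-rank latent plus one shared rotary key per token.
wkv_a = torch.nn.Linear(dim, kv_lora_rank + qk_rope_head_dim, bias=False)
wkv_b = torch.nn.Linear(kv_lora_rank, n_heads * (qk_nope_head_dim + v_head_dim), bias=False)
latent = wkv_a(x)
kv, k_pe = torch.split(latent, [kv_lora_rank, qk_rope_head_dim], dim=-1)
k_pe = k_pe.unsqueeze(2)  # (bsz, seqlen, 1, qk_rope_head_dim)
kv = wkv_b(kv).view(bsz, seqlen, n_heads, qk_nope_head_dim + v_head_dim)
k_nope, v = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
k = torch.cat([k_nope, k_pe.expand(-1, -1, n_heads, -1)], dim=-1)

# SDPA wants (bsz, n_heads, seqlen, head_dim); the diff's self.sdpa wrapper is
# assumed to handle this layout itself.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
)
out = out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
print(out.shape)  # torch.Size([2, 16, 128]) == (bsz, seqlen, n_heads * v_head_dim)

Note that k_pe is computed once per token and expanded across heads via expand(-1, -1, n_heads, -1), so a single rotary key projection is reused by every head when the full key is concatenated.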
