@@ -178,41 +178,32 @@ def forward(
            torch.Tensor: Output tensor with the same shape as the input.
        """
        bsz, seqlen, _ = x.size()
+
+       # Query projection
        if self.q_lora_rank == 0:
            q = self.wq(x)  # (bsz, seqlen, n_heads * qk_head_dim)
        else:
-           q = self.wq_b(
-               self.q_norm(self.wq_a(x))
-           )  # (bsz, seqlen, n_heads * qk_head_dim)
+           q = self.wq_b(self.q_norm(self.wq_a(x)))

-       q = q.view(
-           bsz, seqlen, self.n_heads, self.qk_head_dim
-       )  # (bsz, seqlen, n_heads, qk_head_dim)
+       q = q.view(bsz, seqlen, self.n_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )
-       # q_nope: (bsz, seqlen, n_heads, qk_nope_head_dim)
-       # q_pe: (bsz, seqlen, n_heads, qk_rope_head_dim)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        q = torch.cat([q_nope, q_pe], dim=-1)  # (bsz, seqlen, n_heads, qk_head_dim)

-       kv = self.wkv_a(x)  # kv: (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
+       # Key-value projection
+       kv = self.wkv_a(x)  # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-       # kv: (bsz, seqlen, kv_lora_rank)
-       # k_pe: (bsz, seqlen, qk_rope_head_dim)
        k_pe = apply_rotary_emb(
            k_pe.unsqueeze(2), freqs_cis
        )  # (bsz, seqlen, 1, qk_rope_head_dim)

        kv = self.wkv_b(
            self.kv_norm(kv)
        )  # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim))
-       kv = kv.view(
-           bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim
-       )  # (bsz, seqlen, n_heads, qk_nope_head_dim + v_head_dim)
+       kv = kv.view(bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-       # k_nope: (bsz, seqlen, n_heads, qk_nope_head_dim)
-       # v: (bsz, seqlen, n_heads, v_head_dim)
        k = torch.cat(
            [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1
        )  # (bsz, seqlen, n_heads, qk_head_dim)
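For reference, a minimal shape-bookkeeping sketch of the query/key-value split above, with illustrative dimensions. The attribute names follow the diff, but the rotary step and the wq/wkv projections are replaced by random stand-in tensors, so this is a sketch rather than the module itself:

import torch

# Illustrative sizes only; the real values come from the model config.
bsz, seqlen, n_heads = 2, 8, 4
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 32, 16, 32
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
kv_lora_rank = 64

# Query path: project, reshape per head, split into non-rotary / rotary parts.
q = torch.randn(bsz, seqlen, n_heads * qk_head_dim)            # stand-in for the wq / wq_b output
q = q.view(bsz, seqlen, n_heads, qk_head_dim)
q_nope, q_pe = torch.split(q, [qk_nope_head_dim, qk_rope_head_dim], dim=-1)

# Key-value path: split the low-rank latent from the shared rotary key.
kv = torch.randn(bsz, seqlen, kv_lora_rank + qk_rope_head_dim)  # stand-in for the wkv_a output
kv, k_pe = torch.split(kv, [kv_lora_rank, qk_rope_head_dim], dim=-1)

# k_pe carries one rotary key per token, broadcast across heads before the concat.
k_pe = k_pe.unsqueeze(2)                                        # (bsz, seqlen, 1, qk_rope_head_dim)
k_nope = torch.randn(bsz, seqlen, n_heads, qk_nope_head_dim)    # stand-in for the wkv_b split
k = torch.cat([k_nope, k_pe.expand(-1, -1, n_heads, -1)], dim=-1)

print(q_nope.shape, q_pe.shape, k.shape)
# torch.Size([2, 8, 4, 32]) torch.Size([2, 8, 4, 16]) torch.Size([2, 8, 4, 48])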
@@ -222,10 +213,9 @@ def forward(
        # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L17
        output = self.sdpa(q, k, v)

-       output = output.transpose(
-           1, 2
-       ).contiguous()  # (bs, seqlen, n_heads, v_head_dim)
-       output = output.view(bsz, seqlen, -1)  # (bs, seqlen, n_heads * v_head_dim)
+       # Reshape and project output
+       output = output.transpose(1, 2).contiguous()  # (bsz, seqlen, n_heads, v_head_dim)
+       output = output.view(bsz, seqlen, -1)  # (bsz, seqlen, n_heads * v_head_dim)
        return self.wo(output)  # (bsz, seqlen, dim)
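A side note on the reshape: transpose returns a non-contiguous view, and Tensor.view only works on contiguous memory, which is why the .contiguous() call (or an equivalent .reshape) has to stay before flattening the heads. A quick standalone check with illustrative shapes, assuming the sdpa output uses the (bsz, n_heads, seqlen, v_head_dim) layout implied by the comments above:

import torch

bsz, n_heads, seqlen, v_head_dim = 2, 4, 8, 32
output = torch.randn(bsz, n_heads, seqlen, v_head_dim)  # assumed sdpa-style layout

t = output.transpose(1, 2)          # (bsz, seqlen, n_heads, v_head_dim), not contiguous
print(t.is_contiguous())            # False

try:
    t.view(bsz, seqlen, -1)         # fails: view requires contiguous memory
except RuntimeError as err:
    print("view failed:", err)

flat = t.contiguous().view(bsz, seqlen, -1)   # or: t.reshape(bsz, seqlen, -1)
print(flat.shape)                             # torch.Size([2, 8, 64])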
@@ -327,11 +317,6 @@ def forward(self, tokens: torch.Tensor):
            torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
        """
        h = self.tok_embeddings(tokens)
-       # # This is casual mask, which is already handled in sdpa
-       # if seqlen > 1:
-       #     mask = torch.full(
-       #         (seqlen, seqlen), float("-inf"), device=tokens.device
-       #     ).triu_(1)
        for layer in self.layers:
            h = layer(h, self.freqs_cis)
        h = self.norm(h)[:, -1]
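The deleted mask construction is redundant as long as the attention call itself is causal. A minimal sanity check, assuming (as the removed comment and the linked DeepSeek-V3 reference suggest) that self.sdpa ultimately calls torch.nn.functional.scaled_dot_product_attention with is_causal=True:

import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 2, 4, 8, 16
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)

# Explicit additive causal mask, as in the removed code.
mask = torch.full((seqlen, seqlen), float("-inf")).triu_(1)
explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Built-in causal masking.
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(explicit, builtin, atol=1e-6))  # True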