@@ -174,48 +174,51 @@ def forward(
         """
         bsz, seqlen, _ = x.size()
         if self.q_lora_rank == 0:
-            q = self.wq(x)  # q: (bsz, seqlen, n_heads * qk_head_dim)
+            q = self.wq(x)  # (bsz, seqlen, n_heads * qk_head_dim)
         else:
             q = self.wq_b(
                 self.q_norm(self.wq_a(x))
-            )  # q: (bsz, seqlen, n_heads * qk_head_dim)
+            )  # (bsz, seqlen, n_heads * qk_head_dim)
 
         q = q.view(
             bsz, seqlen, self.n_heads, self.qk_head_dim
-        )  # q: (bsz, seqlen, n_heads, qk_head_dim)
+        )  # (bsz, seqlen, n_heads, qk_head_dim)
         q_nope, q_pe = torch.split(
             q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
         )
         # q_nope: (bsz, seqlen, n_heads, qk_nope_head_dim)
         # q_pe: (bsz, seqlen, n_heads, qk_rope_head_dim)
         q_pe = apply_rotary_emb(q_pe, freqs_cis)
-        q = torch.cat([q_nope, q_pe], dim=-1)  # q: (bsz, seqlen, n_heads, qk_head_dim)
+        q = torch.cat([q_nope, q_pe], dim=-1)  # (bsz, seqlen, n_heads, qk_head_dim)
 
         kv = self.wkv_a(x)  # kv: (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
         kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         # kv: (bsz, seqlen, kv_lora_rank)
         # k_pe: (bsz, seqlen, qk_rope_head_dim)
         k_pe = apply_rotary_emb(
             k_pe.unsqueeze(2), freqs_cis
-        )  # k_pe: (bsz, seqlen, 1, qk_rope_head_dim)
+        )  # (bsz, seqlen, 1, qk_rope_head_dim)
 
         kv = self.wkv_b(
             self.kv_norm(kv)
-        )  # kv: (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim))
+        )  # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim))
         kv = kv.view(
             bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim
         )  # (bsz, seqlen, n_heads, qk_nope_head_dim + v_head_dim)
         k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         # k_nope: (bsz, seqlen, n_heads, qk_nope_head_dim)
         # v: (bsz, seqlen, n_heads, v_head_dim)
-        k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)
-        # k: (bsz, seqlen, n_heads, qk_head_dim)
+        k = torch.cat(
+            [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1
+        )  # (bsz, seqlen, n_heads, qk_head_dim)
 
         # TODO: Need to pass softmax_scale to the sdpa() interface.
         # For the mask: DeepSeek-V3 uses a causal mask, so we can use the default mask in sdpa.
         # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L17
         output = self.sdpa(q, k, v)
 
-        output = output.transpose(1, 2).contiguous()
+        output = output.transpose(
+            1, 2
+        ).contiguous()  # (bsz, seqlen, n_heads, v_head_dim)
         output = output.view(bsz, seqlen, -1)  # (bsz, seqlen, n_heads * v_head_dim)
         return self.wo(output)  # (bsz, seqlen, dim)
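
Re the softmax_scale TODO in this hunk: `torch.nn.functional.scaled_dot_product_attention` accepts a `scale` argument as of PyTorch 2.1, which is one way to thread the scale through. Below is a minimal sketch, assuming `self.sdpa` wraps that function and takes the (bsz, seqlen, n_heads, head_dim) layout produced above; `sdpa_with_scale` is a hypothetical helper, not this repo's actual `sdpa()`, and the YaRN mscale correction that the linked reference implementation folds into its scale is omitted.

```python
import torch
import torch.nn.functional as F

def sdpa_with_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    qk_head_dim: int) -> torch.Tensor:
    # Hypothetical stand-in for self.sdpa(); not this repo's implementation.
    # q, k: (bsz, seqlen, n_heads, qk_head_dim); v: (bsz, seqlen, n_heads, v_head_dim)
    softmax_scale = qk_head_dim ** -0.5  # base scale; YaRN mscale correction omitted
    # F.scaled_dot_product_attention expects (bsz, n_heads, seqlen, head_dim).
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(
        q, k, v, is_causal=True, scale=softmax_scale  # scale= requires PyTorch >= 2.1
    )  # (bsz, n_heads, seqlen, v_head_dim), matching the transpose(1, 2) that follows
```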
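
And a quick standalone sanity check of the split/expand/concat shape bookkeeping in this hunk; the head dims here (qk_nope_head_dim=128, qk_rope_head_dim=64) follow the published DeepSeek-V3 config but are used purely for illustration:

```python
import torch

bsz, seqlen, n_heads = 2, 16, 4            # small illustrative sizes
qk_nope, qk_rope = 128, 64                 # DeepSeek-V3-style head dims
qk_head_dim = qk_nope + qk_rope            # 192

q = torch.randn(bsz, seqlen, n_heads, qk_head_dim)
q_nope, q_pe = torch.split(q, [qk_nope, qk_rope], dim=-1)
assert q_nope.shape == (bsz, seqlen, n_heads, qk_nope)
assert q_pe.shape == (bsz, seqlen, n_heads, qk_rope)

# k_pe is computed once per position and broadcast across heads via expand:
k_pe = torch.randn(bsz, seqlen, 1, qk_rope)
k_nope = torch.randn(bsz, seqlen, n_heads, qk_nope)
k = torch.cat([k_nope, k_pe.expand(-1, -1, n_heads, -1)], dim=-1)
assert k.shape == (bsz, seqlen, n_heads, qk_head_dim)
```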