@@ -220,6 +220,8 @@ def __init__(self, config: DeepseekV2Config):
             self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
             bias=False,
         )
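+        # Per-head up-projection factors of the MLA kv_b_proj weight, kept as separate
+        # parameters so the weight-absorbed attention path below can use them directly.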
+        self.w_uk = nn.Parameter((self.num_heads, config.kv_lora_rank, self.qk_nope_head_dim))
+        self.w_uv = nn.Parameter((self.num_heads, self.v_head_dim, config.kv_lora_rank))
 
         self.o_proj = nn.Linear(
             self.num_heads * self.v_head_dim,
@@ -241,6 +243,106 @@ def forward(
         paged_kv_cache: PagedKVCache,
         layer_id: int,
         query_positions: Tensor,
+    ):
+        return self.forward_absorb(hidden_states, paged_kv_cache, layer_id, query_positions)
+
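+    # Weight-absorbed MLA attention: scores and values are computed directly in the
+    # compressed kv_lora_rank latent space, so the paged cache only needs to hold a
+    # single latent head instead of per-head key/value tensors.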
+    def forward_absorb(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
+    ):
+        b, s, _ = hidden_states.shape
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(
+                self.q_a_layernorm(self.q_a_proj(hidden_states))
+            )  # (b, s, num_heads * q_head_dim)
+        q = op.reshape(q, (b, s, self.num_heads, self.q_head_dim))  # (b, s, num_heads, q_head_dim)
+        q_nope, q_pe = op.split(
+            q, [self.qk_nope_head_dim], axis=-1
+        )  # (b, s, num_heads, qk_nope_head_dim), (b, s, num_heads, qk_rope_head_dim)
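+        # Absorb w_uk into the query: project q_nope into the latent space so that
+        # q @ k can be evaluated against the compressed cache without up-projecting keys.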
+        q_nope = (
+            op.matmul(
+                q_nope.reshape(b * s, self.num_heads, self.qk_nope_head_dim).permute_dims(1, 0, 2),
+                self.w_uk.permute_dims(0, 2, 1),
+            )
+            .permute_dims(1, 0, 2)
+            .reshape(b, s, self.num_heads, self.kv_lora_rank)
+        )  # (b, s, num_heads, kv_lora_rank)
+
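+        # Compressed KV path: a single shared latent "head" of width kv_lora_rank plus
+        # the decoupled rotary dimensions is all that gets written to the cache.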
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states).reshape(
+            b, s, 1, self.kv_lora_rank + self.qk_rope_head_dim
+        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
+        compressed_kv, k_pe = op.split(
+            compressed_kv, [self.config.kv_lora_rank], axis=-1
+        )  # (b, s, 1, kv_lora_rank), (b, s, 1, qk_rope_head_dim)
+
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
+        k_nope = compressed_kv  # (b, s, 1, kv_lora_rank)
+        value_states = compressed_kv  # (b, s, 1, kv_lora_rank)
+
+        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)
+
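+        # TE helper: concatenates the no-PE latent part with the rotary part along the
+        # last dimension for a tensor with the given number of heads.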
+        def concat_nope_pe(num_heads: int):
+            def f_concat_nope_pe(var_nope: te.Tensor, var_pe: te.Tensor):
+                return te.compute(
+                    (b, s, num_heads, self.kv_lora_rank + self.qk_rope_head_dim),
+                    lambda _b, _s, _h, _d: te.if_then_else(
+                        _d < self.kv_lora_rank,
+                        var_nope[_b, _s, _h, _d],
+                        var_pe[_b, _s, _h, _d - self.kv_lora_rank],
+                    ),
+                )
+
+            return f_concat_nope_pe
+
+        query_states = op.tensor_expr_op(
+            concat_nope_pe(num_heads=self.num_heads), "concat_q", [q_nope, q_pe]
+        )  # (b, s, num_heads, kv_lora_rank + qk_rope_head_dim)
+        key_states = op.tensor_expr_op(
+            concat_nope_pe(num_heads=1), "concat_k", [k_nope, k_pe]
+        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
+        value_states = op.pad(
+            value_states, [0, 0, 0, 0, 0, 0, 0, self.qk_rope_head_dim]
+        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)
+
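+        # Attention runs as MQA over the compressed cache: one KV head of width
+        # kv_lora_rank + qk_rope_head_dim. Only the first kv_lora_rank output channels
+        # are the latent values; the zero-padded rotary channels are dropped by the split.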
+        qkv = op.concat(
+            [query_states, key_states, value_states], dim=2
+        )  # (b, s, num_heads + 2, kv_lora_rank + qk_rope_head_dim)
+        output, _ = op.split(
+            paged_kv_cache.attention_with_fused_qkv(
+                layer_id,
+                qkv,
+                self.num_heads,
+                self.softmax_scale
+                * math.sqrt(
+                    self.kv_lora_rank + self.qk_rope_head_dim
+                ),  # This is to cancel out the 1/sqrt(d) in normal attention
+            ),
+            indices_or_sections=[self.kv_lora_rank],
+            axis=-1,
+        )  # (b, s, num_heads, kv_lora_rank)
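+        # Absorb w_uv on the way out: map the latent attention output back to per-head
+        # v_head_dim before o_proj.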
+        output = (
+            op.matmul(
+                output.reshape(b * s, self.num_heads, self.kv_lora_rank).permute_dims(1, 0, 2),
+                self.w_uv.permute_dims(0, 2, 1),
+            )
+            .permute_dims(1, 0, 2)
+            .reshape(b, s, self.num_heads * self.v_head_dim)
+        )
+
+        return self.o_proj(output)
+
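+    # Non-absorbed path: the previous forward body, kept for reference; forward above
+    # now always dispatches to forward_absorb.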
+    def forward_normal(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
     ):
         b, s, _ = hidden_states.shape
 
@@ -450,6 +552,14 @@ def _set(layer, hint):
         _set(
             self.self_attn.kv_b_proj.weight,
             tp.ShardSingleDim("_shard_kv_b_weight", dim=0),
         )
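+        # Shard the absorbed factors along the head dimension (dim 0), matching kv_b_proj.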
+        _set(
+            self.self_attn.w_uk,
+            tp.ShardSingleDim("_shard_kv_b_weight_w_uk", dim=0),
+        )
+        _set(
+            self.self_attn.w_uv,
+            tp.ShardSingleDim("_shard_kv_b_weight_w_uv", dim=0),
+        )
         _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
 
         if isinstance(self.mlp, DeepseekV2MoE):
@@ -517,7 +627,6 @@ def __init__(self, config: DeepseekV2Config):
 
     def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):
         hidden_states = inputs
-        print(f"inputs.shape = {inputs.shape}")
         query_positions = paged_kv_cache.get_query_positions(inputs.shape[0] * inputs.shape[1])
         for layer_id, layer in enumerate(self.layers):
             hidden_states = layer(hidden_states, paged_kv_cache, layer_id, query_positions)
@@ -535,6 +644,8 @@ def __init__(self, config: DeepseekV2Config):
         self.intermediate_size = config.intermediate_size
         self.num_attention_heads = config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
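+        # Needed below to size the latent KV cache in create_paged_kv_cache.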
+        self.kv_lora_rank = config.kv_lora_rank
+        self.qk_rope_head_dim = config.qk_rope_head_dim
         self.rms_norm_eps = config.rms_norm_eps
         self.rope_theta = config.rope_theta
         self.vocab_size = config.vocab_size
@@ -621,8 +732,8 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
            support_sliding_window=support_sliding_window,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
-            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
-            head_dim=256,
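+            # Cache the shared latent KV: one head of width kv_lora_rank + qk_rope_head_dim.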
+            num_key_value_heads=1,
+            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
            rope_mode=RopeMode.NONE,
            rope_scale=1,
            rope_theta=self.rope_theta,