@@ -317,10 +317,7 @@ def forward(
             ],
             dtype=layernorm_out.dtype)

-        decode_stage = forward_meta.is_decode_batch
-        prefill_stage = not (forward_meta.is_decode_batch)
-
-        if prefill_stage:
+        if forward_meta.max_enc_len_this_time:
             query = self.q_a_proj(layernorm_out)
             query = self.q_a_layernorm(query)
             query = self.q_b_proj(query)
@@ -370,8 +367,7 @@ def forward(
                 fmha_out_prefill.dtype)

             fmha_out = fmha_out + fmha_out_prefill
-
-        if decode_stage:
+        if forward_meta.max_dec_len_this_time:
             query = self.q_a_proj(layernorm_out)
             query = self.q_a_layernorm(query)
             ln_out_or_q_c = query
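These two hunks are one logical change: the mutually exclusive `prefill_stage`/`decode_stage` booleans derived from `forward_meta.is_decode_batch` are replaced with independent checks on `max_enc_len_this_time` and `max_dec_len_this_time`, so a mixed batch containing both prefill and decode requests can execute both attention paths in a single forward call. A toy sketch of the new control flow; the `Meta` dataclass and the free-standing `forward` function here are hypothetical stand-ins for `ForwardMeta` and the attention layer, not the real classes:

```python
from dataclasses import dataclass

@dataclass
class Meta:  # hypothetical stand-in for ForwardMeta
    max_enc_len_this_time: int  # longest prefill chunk in this batch, 0 if none
    max_dec_len_this_time: int  # longest decode step in this batch, 0 if none

def forward(meta: Meta):
    if meta.max_enc_len_this_time:  # any prefill tokens in the batch?
        print("run prefill attention path")
    if meta.max_dec_len_this_time:  # any decode tokens in the batch?
        print("run decode attention path")

forward(Meta(max_enc_len_this_time=5, max_dec_len_this_time=0))  # pure prefill
forward(Meta(max_enc_len_this_time=5, max_dec_len_this_time=1))  # mixed batch
# In the mixed case both branches run, which the old single boolean
# is_decode_batch could not express.
```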
@@ -554,28 +550,6 @@ def __init__(
             prefix="deepseek_v3.norm",
         )

-    def pre_process(self, forward_meta):
-        """
-        """
-        seq_lens_encoder = forward_meta.seq_lens_encoder
-        seq_lens_decoder = forward_meta.seq_lens_decoder
-        seq_lens_this_time = forward_meta.seq_lens_this_time
-        position_ids_shape = paddle.sum(seq_lens_this_time)
-
-        position_ids = paddle.empty(shape=position_ids_shape,
-                                    dtype=seq_lens_encoder.dtype)
-        mask_encoder_batch = paddle.empty(
-            shape=position_ids_shape,
-            dtype=seq_lens_encoder.dtype).unsqueeze(1)
-
-        get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
-                                                seq_lens_decoder,
-                                                seq_lens_this_time,
-                                                position_ids,
-                                                mask_encoder_batch)
-
-        return position_ids, mask_encoder_batch
-
     def load_state_dict(self, state_dict):
         """
         Load model parameters from a given state dictionary.
@@ -590,13 +564,13 @@ def forward(
         self,
         ids_remove_padding: paddle.Tensor,
         forward_meta: ForwardMeta,
+        position_ids: paddle.Tensor,
+        mask_encoder_batch: paddle.Tensor,
     ):
         """
         """
         hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)

-        position_ids, mask_encoder_batch = self.pre_process(forward_meta)
-
         residual = None
         for i in range(self.num_layers):
             hidden_states, residual = self.decoder_layers[i](
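The inner model's `forward` now receives `position_ids` and `mask_encoder_batch` from its caller instead of computing them itself. Both carry one entry per token of the flattened (padding-removed) batch. Judging only from the names and from how `pre_process` allocates the buffers, a plausible reference semantics for the custom op `get_position_ids_and_mask_encoder_batch` is sketched below in pure Python; this is an assumption for illustration, not the op's actual kernel:

```python
def get_position_ids_and_mask_encoder_batch_ref(seq_lens_encoder,
                                                seq_lens_decoder,
                                                seq_lens_this_time):
    """Hypothetical reference semantics of the custom op (an assumption)."""
    position_ids, mask = [], []
    for enc, dec, this in zip(seq_lens_encoder, seq_lens_decoder,
                              seq_lens_this_time):
        # Each request contributes `this` tokens, positioned after the
        # `dec` tokens already in its cache (dec == 0 for a fresh prefill).
        position_ids.extend(range(dec, dec + this))
        # Mark tokens of prefill (encoder) requests with 1, decode with 0.
        mask.extend([1 if enc > 0 else 0] * this)
    return position_ids, [[m] for m in mask]  # mask is unsqueezed to [N, 1]

# Mixed batch: one 5-token prefill plus two single-token decode steps.
print(get_position_ids_and_mask_encoder_batch_ref([5, 0, 0],
                                                  [0, 17, 42],
                                                  [5, 1, 1]))
# -> ([0, 1, 2, 3, 4, 17, 42], [[1], [1], [1], [1], [1], [0], [0]])
```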
@@ -650,14 +624,37 @@ def compute_logits(self, hidden_states: paddle.Tensor):
         logits[:, self.ori_vocab_size:] = -float("inf")
         return logits

+    def pre_process(self, forward_meta):
+        """
+        """
+        seq_lens_encoder = forward_meta.seq_lens_encoder
+        seq_lens_decoder = forward_meta.seq_lens_decoder
+        seq_lens_this_time = forward_meta.seq_lens_this_time
+        position_ids_shape = paddle.sum(seq_lens_this_time)
+        position_ids = paddle.empty(shape=position_ids_shape,
+                                    dtype=seq_lens_encoder.dtype)
+        mask_encoder_batch = paddle.empty(
+            shape=position_ids_shape,
+            dtype=seq_lens_encoder.dtype).unsqueeze(1)
+
+        get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
+                                                seq_lens_decoder,
+                                                seq_lens_this_time,
+                                                position_ids,
+                                                mask_encoder_batch)
+
+        return position_ids, mask_encoder_batch
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
         forward_meta: ForwardMeta,
     ):
         """
         """
-        hidden_states = self.model(ids_remove_padding, forward_meta)
+        position_ids, mask_encoder_batch = self.pre_process(forward_meta)
+        hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta,
+                                   position_ids=position_ids, mask_encoder_batch=mask_encoder_batch)
         return hidden_states

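`pre_process` itself is unchanged; it only moves from the inner `deepseek_v3` model to the outer causal-LM wrapper, so the buffers are built once per step and handed down through keyword arguments. The shape math is simple: both buffers have one slot per token in the flattened batch, and the mask gains a trailing unit axis so it can broadcast over the hidden dimension. A small sketch mirroring the diff's allocation (the diff passes the summed tensor as the shape directly; here it is coerced to a plain int for clarity):

```python
import paddle

seq_lens_this_time = paddle.to_tensor([5, 1, 1], dtype="int32")
total = int(paddle.sum(seq_lens_this_time))  # 7 tokens after padding removal

position_ids = paddle.empty(shape=[total], dtype="int32")
mask_encoder_batch = paddle.empty(shape=[total], dtype="int32").unsqueeze(1)

print(position_ids.shape)        # [7]
print(mask_encoder_batch.shape)  # [7, 1] -- broadcastable over hidden size
```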