
Commit 24f934f

[BugFix] Fix low prediction accuracy of deepseekv3 (#2798)
1 parent 1e2319c commit 24f934f

File tree

2 files changed: +34 -39 lines changed

2 files changed

+34
-39
lines changed

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 6 additions & 8 deletions

@@ -41,7 +41,8 @@
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
-from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
+from fastdeploy.model_executor.layers.attention.utils import \
+    init_rank_and_device_id
 from fastdeploy.worker.forward_meta import ForwardMeta
@@ -185,6 +186,8 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
         # MLA
         metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
         metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
+        forward_meta.max_enc_len_this_time = metadata.set_max_lengths[1]
+        forward_meta.max_dec_len_this_time = metadata.set_max_lengths[2]

         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -375,9 +378,6 @@ def forward_mixed(
         speculate_decoder = self.speculative_method is not None
         speculate_max_tokens = self.speculate_max_draft_token_num

-        decode_stage = forward_meta.is_decode_batch
-        prefill_stage = not (forward_meta.is_decode_batch)
-
         if self.use_pd_disaggregation:
             metadata.kv_signal_data_list[
                 layer.layer_id] = init_signal_layerwise(
@@ -387,8 +387,7 @@
         latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
             forward_meta, 'caches') else None

-        if prefill_stage:
-            # write to cache
+        if k is not None:
             prefill_mla_write_cache(
                 compressed_kv,
                 k_pe,
@@ -419,8 +418,7 @@
             return fmha_out

         # Decode
-        if decode_stage:
-            # MLA write to cache
+        if k is None:
             decode_mla_write_cache(
                 compressed_kv,
                 k_pe,
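Note on this file's change: the backend previously routed the MLA cache write with batch-wide `prefill_stage` / `decode_stage` flags derived from `forward_meta.is_decode_batch`, so a batch containing both prefilling and decoding sequences could take only one path. After this commit the route is decided per call by whether the full key tensor `k` was materialized, which, judging from the model-side changes below, happens only on the prefill path. A minimal sketch of the new dispatch, with stub callables standing in for the real `prefill_mla_write_cache` / `decode_mla_write_cache` kernels (which take more arguments):

from typing import Callable, Optional

def mla_write_cache_dispatch(compressed_kv, k_pe, k: Optional[object],
                             prefill_write: Callable, decode_write: Callable):
    """Route the MLA cache write on the presence of `k` rather than on a
    batch-wide stage flag, so each call is classified independently."""
    if k is not None:
        # Prefill: bulk-write compressed KV and rotary part for new tokens.
        prefill_write(compressed_kv, k_pe)
    else:
        # Decode: append one step per sequence to the latent cache.
        decode_write(compressed_kv, k_pe)

# Stub usage: the same backend now serves both paths within one step.
mla_write_cache_dispatch("ckv", "k_pe", k="full_k",
                         prefill_write=lambda *a: print("prefill path"),
                         decode_write=lambda *a: print("decode path"))
mla_write_cache_dispatch("ckv", "k_pe", k=None,
                         prefill_write=lambda *a: print("prefill path"),
                         decode_write=lambda *a: print("decode path"))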

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 28 additions & 31 deletions

@@ -317,10 +317,7 @@ def forward(
             ],
             dtype=layernorm_out.dtype)

-        decode_stage = forward_meta.is_decode_batch
-        prefill_stage = not (forward_meta.is_decode_batch)
-
-        if prefill_stage:
+        if forward_meta.max_enc_len_this_time:
             query = self.q_a_proj(layernorm_out)
             query = self.q_a_layernorm(query)
             query = self.q_b_proj(query)
@@ -370,8 +367,7 @@ def forward(
                 fmha_out_prefill.dtype)

             fmha_out = fmha_out + fmha_out_prefill
-
-        if decode_stage:
+        if forward_meta.max_dec_len_this_time:
             query = self.q_a_proj(layernorm_out)
             query = self.q_a_layernorm(query)
             ln_out_or_q_c = query
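This is the model-side counterpart of the fix: the prefill and decode attention paths are now gated independently on the per-step maximum encoder/decoder lengths that `init_attention_metadata` mirrors onto `forward_meta` (second hunk of the first file), instead of on a single batch-wide stage flag. Because both conditions are checked, a mixed batch runs both paths in one step, which is presumably what restores the prediction accuracy the commit title refers to. A self-contained sketch of the gating, with made-up tensor values for illustration:

import paddle

# Per-sequence token counts for one step (illustrative values):
# sequence 0 is prefilling 128 tokens; sequences 1 and 2 decode 1 token each.
enc_lens = paddle.to_tensor([128, 0, 0])
dec_lens = paddle.to_tensor([0, 1, 1])

max_enc_len_this_time = int(enc_lens.max())  # 128 -> prefill branch needed
max_dec_len_this_time = int(dec_lens.max())  # 1   -> decode branch needed

if max_enc_len_this_time:
    print("prefill path: attention over the new prompt tokens")
if max_dec_len_this_time:
    print("decode path: single-step attention against the KV cache")
# A mixed batch triggers both prints; the old batch-wide flag picked only one.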
@@ -554,28 +550,6 @@ def __init__(
             prefix="deepseek_v3.norm",
         )

-    def pre_process(self, forward_meta):
-        """
-        """
-        seq_lens_encoder = forward_meta.seq_lens_encoder
-        seq_lens_decoder = forward_meta.seq_lens_decoder
-        seq_lens_this_time = forward_meta.seq_lens_this_time
-        position_ids_shape = paddle.sum(seq_lens_this_time)
-
-        position_ids = paddle.empty(shape=position_ids_shape,
-                                    dtype=seq_lens_encoder.dtype)
-        mask_encoder_batch = paddle.empty(
-            shape=position_ids_shape,
-            dtype=seq_lens_encoder.dtype).unsqueeze(1)
-
-        get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
-                                                seq_lens_decoder,
-                                                seq_lens_this_time,
-                                                position_ids,
-                                                mask_encoder_batch)
-
-        return position_ids, mask_encoder_batch
-
     def load_state_dict(self, state_dict):
         """
         Load model parameters from a given state dictionary.
@@ -590,13 +564,13 @@ def forward(
         self,
         ids_remove_padding: paddle.Tensor,
         forward_meta: ForwardMeta,
+        position_ids: paddle.Tensor,
+        mask_encoder_batch: paddle.Tensor,
     ):
         """
         """
         hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)

-        position_ids, mask_encoder_batch = self.pre_process(forward_meta)
-
         residual = None
         for i in range(self.num_layers):
             hidden_states, residual = self.decoder_layers[i](
@@ -650,14 +624,37 @@ def compute_logits(self, hidden_states: paddle.Tensor):
         logits[:, self.ori_vocab_size:] = -float("inf")
         return logits

+    def pre_process(self, forward_meta):
+        """
+        """
+        seq_lens_encoder = forward_meta.seq_lens_encoder
+        seq_lens_decoder = forward_meta.seq_lens_decoder
+        seq_lens_this_time = forward_meta.seq_lens_this_time
+        position_ids_shape = paddle.sum(seq_lens_this_time)
+        position_ids = paddle.empty(shape=position_ids_shape,
+                                    dtype=seq_lens_encoder.dtype)
+        mask_encoder_batch = paddle.empty(
+            shape=position_ids_shape,
+            dtype=seq_lens_encoder.dtype).unsqueeze(1)
+
+        get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
+                                                seq_lens_decoder,
+                                                seq_lens_this_time,
+                                                position_ids,
+                                                mask_encoder_batch)
+
+        return position_ids, mask_encoder_batch
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
         forward_meta: ForwardMeta,
     ):
         """
         """
-        hidden_states = self.model(ids_remove_padding, forward_meta)
+        position_ids, mask_encoder_batch = self.pre_process(forward_meta)
+        hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta,
+                                   position_ids=position_ids, mask_encoder_batch=mask_encoder_batch)
         return hidden_states
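For reference, `pre_process` itself is unchanged; it only moved from the inner model to the `ForCausalLM` wrapper so that `position_ids` and `mask_encoder_batch` are computed once per forward and passed in explicitly. The heavy lifting is done by the fused `get_position_ids_and_mask_encoder_batch` op. The sketch below is a pure-Python approximation of what that op plausibly fills in (an assumption for illustration, not the kernel's verified semantics): each sequence's tokens for this step get consecutive positions starting at its already-decoded length, and the mask flags tokens that belong to prefilling (encoder) sequences.

import paddle

def pre_process_reference(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time):
    position_ids, mask = [], []
    for enc, dec, cur in zip(seq_lens_encoder.tolist(),
                             seq_lens_decoder.tolist(),
                             seq_lens_this_time.tolist()):
        start = dec  # tokens already in this sequence's KV cache (assumption)
        position_ids.extend(range(start, start + cur))
        mask.extend([1 if enc > 0 else 0] * cur)
    position_ids = paddle.to_tensor(position_ids, dtype='int32')
    mask_encoder_batch = paddle.to_tensor(mask, dtype='int32').unsqueeze(1)
    return position_ids, mask_encoder_batch

# One prefill sequence of 4 tokens plus one decode sequence at position 7:
pos, mask = pre_process_reference(paddle.to_tensor([4, 0]),
                                  paddle.to_tensor([0, 7]),
                                  paddle.to_tensor([4, 1]))
print(pos.tolist())              # [0, 1, 2, 3, 7]
print(mask.squeeze(1).tolist())  # [1, 1, 1, 1, 0]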
