
Commit 667547b

support chunk_prefill in MTP (#2705)
1 parent b38823b commit 667547b

File tree

3 files changed: +64 additions, -6 deletions

fastdeploy/spec_decode/base.py

Lines changed: 10 additions & 0 deletions
@@ -61,3 +61,13 @@ def _run_impl(self, *args, **kwargs) -> Any:
         Implemention for different method
         """
         raise NotImplementedError
+
+    def is_chunk_prefill_enabled(self) -> bool:
+        """
+        Check whether chunk-based prefill is enabled.
+        Default is False.
+
+        Returns:
+            bool: True if chunk prefill is enabled; False otherwise.
+        """
+        return False
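
For context, a minimal sketch of how this hook is intended to be consumed (class names here are illustrative stand-ins, not FastDeploy's actual hierarchy): the base proposer reports False by default, and a chunk-aware proposer such as the MTP one in mtp.py overrides it to return True, so the model runner can gate chunk-specific updates on it.

class Proposer:
    """Simplified stand-in for the base proposer in spec_decode/base.py."""

    def is_chunk_prefill_enabled(self) -> bool:
        # Default: a proposer is assumed not to handle chunked prefill.
        return False


class MTPProposer(Proposer):
    """Simplified stand-in for the MTP proposer in spec_decode/mtp.py."""

    def is_chunk_prefill_enabled(self) -> bool:
        # MTP now supports chunk-based prefill, so it opts in.
        return True


assert not Proposer().is_chunk_prefill_enabled()
assert MTPProposer().is_chunk_prefill_enabled()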

fastdeploy/spec_decode/mtp.py

Lines changed: 51 additions & 6 deletions
@@ -405,17 +405,21 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
                 1:length]
             self.model_inputs["pre_ids"][idx:idx + 1] = -1
             self.model_inputs["step_idx"][idx:idx + 1] = 0
-            # TODO(liuzichang) finish chunked_prefill
             if self.parallel_config.enable_chunked_prefill:
-                raise NotImplementedError(
-                    "MTP don't support chunked_prefill now")
+                token_chunk_size = request.prefill_chunk_info[0]
+                self.model_inputs["seq_lens_encoder"][idx:idx +
+                                                      1] = token_chunk_size
+                self.model_inputs["seq_lens_this_time"][
+                    idx:idx + 1] = token_chunk_size
             else:
                 self.model_inputs["seq_lens_encoder"][idx:idx + 1] = length
-                self.model_inputs["seq_lens_decoder"][idx:idx + 1] = (
-                    request.get("seq_lens_decoder", 0))
                 self.model_inputs["seq_lens_this_time"][idx:idx +
                                                         1] = length
 
+            self.model_inputs["seq_lens_decoder"][idx:idx +
+                                                  1] = (request.get(
+                                                      "seq_lens_decoder",
+                                                      0))
             self.model_inputs["stop_flags"][idx:idx + 1] = False
             self.model_inputs["batch_drop"][idx:idx + 1] = False
 
@@ -578,7 +582,6 @@ def _propose(self, target_hidden_states):
             self.model_inputs["output_padding_offset"],
             self.parallel_config.max_model_len,
         )
-        paddle.device.synchronize()
 
         # 4. Compute logits, Sample
         logits = self.model.compute_logits(hiddden_states)
@@ -595,6 +598,43 @@ def _propose(self, target_hidden_states):
 
         self._post_process(sampled_token_ids)
 
+    def update_task_chunk_prefill(self, task):
+        """
+        Update single task's chunk_prefill info
+        """
+        idx = task.idx
+        start_idx = sum(task.prefill_chunk_info[:task.chunk_idx])
+
+        if task.chunk_idx == len(task.prefill_chunk_info):
+            self.model_inputs['seq_lens_encoder'][idx:idx + 1] = 0
+            self.model_inputs["step_idx"][idx:idx + 1] = 1
+            self.model_inputs["seq_lens_decoder"][idx:idx +
+                                                  1] = start_idx + task.get(
+                                                      "seq_lens_decoder", 0)
+        else:
+            token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
+
+            if task.chunk_idx < len(task.prefill_chunk_info) - 1:
+                self.model_inputs['input_ids'][
+                    idx, :token_chunk_size] = np.array(
+                        task.prompt_token_ids[start_idx + 1:start_idx +
+                                              token_chunk_size + 1])
+            # Last prefill
+            else:
+                self.model_inputs['input_ids'][
+                    idx, :token_chunk_size - 1] = np.array(
+                        task.prompt_token_ids[start_idx + 1:start_idx +
+                                              token_chunk_size])
+
+            self.model_inputs["seq_lens_this_time"][idx:idx +
+                                                    1] = token_chunk_size
+            self.model_inputs['seq_lens_encoder'][idx:idx +
+                                                  1] = token_chunk_size
+            self.model_inputs["step_idx"][idx:idx + 1] = 0
+            self.model_inputs["seq_lens_decoder"][idx:idx +
+                                                  1] = start_idx + task.get(
+                                                      "seq_lens_decoder", 0)
+
     def _update_status(self):
         """
         Update main-model's forward info in next step.
@@ -624,6 +664,11 @@ def _update_status(self):
         )
 
     def _run_impl(self, full_hidden_states):
+        """"""
         target_hidden_states = self._prepare_inputs(full_hidden_states)
         self._propose(target_hidden_states=target_hidden_states)
         self._update_status()
+
+    def is_chunk_prefill_enabled(self):
+        """"""
+        return True
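
As a standalone illustration (plain Python with made-up values, not FastDeploy code) of the chunk bookkeeping that insert_prefill_inputs and update_task_chunk_prefill share: start_idx counts the prompt tokens consumed by earlier chunks, the MTP draft model's input_ids are the prompt shifted left by one token, and the last chunk therefore loads one token fewer. The decode switch that happens when chunk_idx reaches len(prefill_chunk_info) (seq_lens_encoder = 0, step_idx = 1) is omitted here.

# Illustration only; values are invented for the example.
prompt_token_ids = list(range(100, 110))   # 10 prompt tokens
prefill_chunk_info = [4, 4, 2]             # chunk sizes, sum == len(prompt)

for chunk_idx, token_chunk_size in enumerate(prefill_chunk_info):
    # Tokens already consumed by earlier chunks.
    start_idx = sum(prefill_chunk_info[:chunk_idx])
    last_chunk = chunk_idx == len(prefill_chunk_info) - 1

    # Draft model input is the prompt shifted left by one token,
    # so the final chunk has one token fewer to load.
    if not last_chunk:
        chunk_input = prompt_token_ids[start_idx + 1:start_idx + token_chunk_size + 1]
    else:
        chunk_input = prompt_token_ids[start_idx + 1:start_idx + token_chunk_size]

    print(chunk_idx, start_idx, token_chunk_size, chunk_input)

# Expected output for these example values:
# 0 0 4 [101, 102, 103, 104]
# 1 4 4 [105, 106, 107, 108]
# 2 8 2 [109]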

fastdeploy/worker/gpu_model_runner.py

Lines changed: 3 additions & 0 deletions
@@ -898,6 +898,9 @@ def _update_chunked_prefill(self, tasks):
                 self.share_inputs["step_idx"][idx:idx + 1] = 0
                 self.share_inputs["seq_lens_decoder"][
                     idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
+            if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled(
+            ):
+                self.proposer.update_task_chunk_prefill(task)
             task.chunk_idx += 1
 
     def _dummy_sampler_run(self) -> paddle.Tensor:
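
Roughly, the control flow this adds to _update_chunked_prefill looks like the sketch below (everything except speculative_decoding, proposer, and the two methods shown in the diff is a placeholder): after the runner updates its own share_inputs for the current chunk, it forwards the task to the proposer only when speculative decoding is active and the proposer opts in via is_chunk_prefill_enabled.

from types import SimpleNamespace


def update_chunked_prefill_sketch(runner, tasks):
    """Sketch of the gating added to _update_chunked_prefill (not the real method)."""
    for task in tasks:
        # ... the runner's own share_inputs updates for this chunk go here ...
        if runner.speculative_decoding and runner.proposer.is_chunk_prefill_enabled():
            # Keep the MTP draft model's inputs in step with the target model.
            runner.proposer.update_task_chunk_prefill(task)
        task.chunk_idx += 1


class _DummyProposer:
    def is_chunk_prefill_enabled(self):
        return True

    def update_task_chunk_prefill(self, task):
        print("proposer updated chunk", task.chunk_idx)


runner = SimpleNamespace(speculative_decoding=True, proposer=_DummyProposer())
update_chunked_prefill_sketch(runner, [SimpleNamespace(chunk_idx=0)])
# -> proposer updated chunk 0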
