@@ -405,17 +405,21 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
                                                         1:length]
             self.model_inputs["pre_ids"][idx:idx + 1] = -1
             self.model_inputs["step_idx"][idx:idx + 1] = 0
-            # TODO(liuzichang) finish chunked_prefill
             if self.parallel_config.enable_chunked_prefill:
-                raise NotImplementedError(
-                    "MTP don't support chunked_prefill now")
+                token_chunk_size = request.prefill_chunk_info[0]
+                self.model_inputs["seq_lens_encoder"][idx:idx +
+                                                      1] = token_chunk_size
+                self.model_inputs["seq_lens_this_time"][
+                    idx:idx + 1] = token_chunk_size
             else:
                 self.model_inputs["seq_lens_encoder"][idx:idx + 1] = length
-                self.model_inputs["seq_lens_decoder"][idx:idx + 1] = (
-                    request.get("seq_lens_decoder", 0))
                 self.model_inputs["seq_lens_this_time"][idx:idx +
                                                         1] = length
 
+            self.model_inputs["seq_lens_decoder"][idx:idx +
+                                                  1] = (request.get(
+                                                      "seq_lens_decoder",
+                                                      0))
             self.model_inputs["stop_flags"][idx:idx + 1] = False
             self.model_inputs["batch_drop"][idx:idx + 1] = False
 
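This hunk replaces the old hard failure (`NotImplementedError`) with real chunked-prefill setup: only the first chunk's size seeds `seq_lens_encoder` and `seq_lens_this_time`, and `seq_lens_decoder` is now initialized on both branches instead of only in the non-chunked path. A minimal standalone sketch of that bookkeeping, using plain NumPy arrays as stand-ins for the paddle tensors in `self.model_inputs` (the buffer shapes and the `request` dict here are illustrative, not FastDeploy's actual API):

import numpy as np

max_num_seqs = 4
model_inputs = {
    "seq_lens_encoder": np.zeros((max_num_seqs, 1), dtype=np.int64),
    "seq_lens_this_time": np.zeros((max_num_seqs, 1), dtype=np.int64),
    "seq_lens_decoder": np.zeros((max_num_seqs, 1), dtype=np.int64),
}

def insert_prefill(idx, request, enable_chunked_prefill):
    length = len(request["prompt_token_ids"])
    if enable_chunked_prefill:
        # Only the first chunk is scheduled at insert time; later chunks
        # are fed in by update_task_chunk_prefill (added further down).
        token_chunk_size = request["prefill_chunk_info"][0]
        model_inputs["seq_lens_encoder"][idx:idx + 1] = token_chunk_size
        model_inputs["seq_lens_this_time"][idx:idx + 1] = token_chunk_size
    else:
        model_inputs["seq_lens_encoder"][idx:idx + 1] = length
        model_inputs["seq_lens_this_time"][idx:idx + 1] = length
    # seq_lens_decoder is now set on both paths (it was else-only before).
    model_inputs["seq_lens_decoder"][idx:idx + 1] = request.get(
        "seq_lens_decoder", 0)

insert_prefill(0,
               {"prompt_token_ids": list(range(10)),
                "prefill_chunk_info": [4, 4, 2]},
               enable_chunked_prefill=True)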
@@ -578,7 +582,6 @@ def _propose(self, target_hidden_states):
             self.model_inputs["output_padding_offset"],
             self.parallel_config.max_model_len,
         )
-        paddle.device.synchronize()
 
         # 4. Compute logits, Sample
         logits = self.model.compute_logits(hiddden_states)
@@ -595,6 +598,43 @@ def _propose(self, target_hidden_states):
 
         self._post_process(sampled_token_ids)
 
+    def update_task_chunk_prefill(self, task):
+        """
+        Update single task's chunk_prefill info
+        """
+        idx = task.idx
+        start_idx = sum(task.prefill_chunk_info[:task.chunk_idx])
+
+        if task.chunk_idx == len(task.prefill_chunk_info):
+            self.model_inputs['seq_lens_encoder'][idx:idx + 1] = 0
+            self.model_inputs["step_idx"][idx:idx + 1] = 1
+            self.model_inputs["seq_lens_decoder"][idx:idx +
+                                                  1] = start_idx + task.get(
+                                                      "seq_lens_decoder", 0)
+        else:
+            token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
+
+            if task.chunk_idx < len(task.prefill_chunk_info) - 1:
+                self.model_inputs['input_ids'][
+                    idx, :token_chunk_size] = np.array(
+                        task.prompt_token_ids[start_idx + 1:start_idx +
+                                              token_chunk_size + 1])
+            # Last prefill
+            else:
+                self.model_inputs['input_ids'][
+                    idx, :token_chunk_size - 1] = np.array(
+                        task.prompt_token_ids[start_idx + 1:start_idx +
+                                              token_chunk_size])
+
+            self.model_inputs["seq_lens_this_time"][idx:idx +
+                                                    1] = token_chunk_size
+            self.model_inputs['seq_lens_encoder'][idx:idx +
+                                                  1] = token_chunk_size
+            self.model_inputs["step_idx"][idx:idx + 1] = 0
+            self.model_inputs["seq_lens_decoder"][idx:idx +
+                                                  1] = start_idx + task.get(
+                                                      "seq_lens_decoder", 0)
+
     def _update_status(self):
         """
         Update main-model's forward info in next step.
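The new `update_task_chunk_prefill` walks one task through its prompt chunks: while chunks remain it loads the next chunk into `input_ids` and marks it for encoding; once `chunk_idx` reaches the end it zeroes `seq_lens_encoder` and sets `step_idx = 1`, flipping the task into decode mode. `start_idx` counts the prompt tokens already prefilled, and because the draft model's input is the prompt offset by one token (note the `1:length` slice in the first hunk), every chunk slice starts at `start_idx + 1` and the final chunk carries one token fewer than its nominal size. A pure-Python sketch of just the slice arithmetic (`prompt` and `chunks` are made-up example values):

prompt = list(range(100, 110))  # 10 prompt tokens
chunks = [4, 4, 2]              # prefill_chunk_info

for chunk_idx, size in enumerate(chunks):
    start_idx = sum(chunks[:chunk_idx])  # tokens already prefilled
    if chunk_idx < len(chunks) - 1:
        # Intermediate chunk: offset by one, full chunk width.
        ids = prompt[start_idx + 1:start_idx + size + 1]
    else:
        # Last chunk: after the one-token offset, one fewer token remains.
        ids = prompt[start_idx + 1:start_idx + size]
    print(chunk_idx, start_idx, ids)
# -> 0 0 [101, 102, 103, 104]
#    1 4 [105, 106, 107, 108]
#    2 8 [109]

The three slices carry 4, 4 and 1 tokens, 9 in total: the full 10-token prompt minus the leading token absorbed by the offset.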
@@ -624,6 +664,11 @@ def _update_status(self):
         )
 
     def _run_impl(self, full_hidden_states):
+        """"""
         target_hidden_states = self._prepare_inputs(full_hidden_states)
         self._propose(target_hidden_states=target_hidden_states)
         self._update_status()
+
+    def is_chunk_prefill_enabled(self):
+        """"""
+        return True
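Finally, `is_chunk_prefill_enabled` now reports support unconditionally, giving callers a single switch to gate the per-step chunk updates on. A hypothetical driver loop sketching how the two new hooks might be wired together (`proposer`, `tasks`, and the `chunk_idx` bump are illustrative names; the actual scheduler integration lives outside this diff):

def step_chunked_prefill(proposer, tasks):
    """Advance in-flight tasks by one prompt chunk (illustrative only)."""
    if not proposer.is_chunk_prefill_enabled():
        return
    for task in tasks:
        if task.chunk_idx > len(task.prefill_chunk_info):
            continue  # already switched to decode
        # Loads the next chunk, or flips the task into decode mode once
        # chunk_idx has run past the end of prefill_chunk_info.
        proposer.update_task_chunk_prefill(task)
        task.chunk_idx += 1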