@@ -732,7 +732,7 @@ def _calc_spec_decode_metadata(
         # [0, 1, 2, 5, 6, 9]
         target_logits_indices += arange

-        # TODO: Optimize the CPU -> GPU copy.
+        # TODO: Optimize the CPU -> NPU copy.
         cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
             self.device, non_blocking=True)
         logits_indices = torch.from_numpy(logits_indices).to(self.device,
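For context on the TODO above: `non_blocking=True` only overlaps a host-to-device transfer when the source tensor lives in pinned (page-locked) host memory; otherwise the copy degrades to a synchronous one. A minimal sketch of one possible optimization, assuming the NPU backend supports pinned allocations the way CUDA does; the class and buffer names are illustrative and not part of this PR:

```python
import numpy as np
import torch


class H2DStager:
    """Hypothetical reusable staging buffer for CPU -> device copies."""

    def __init__(self, max_len: int, device: torch.device):
        self.device = device
        # Page-locked host memory lets .to(..., non_blocking=True) run async.
        self._pinned_buf = torch.empty(max_len,
                                       dtype=torch.int32,
                                       pin_memory=True)

    def to_device(self, arr: np.ndarray) -> torch.Tensor:
        n = arr.shape[0]
        # Stage into the pinned buffer, then issue a single async copy.
        self._pinned_buf[:n].copy_(torch.from_numpy(arr))
        return self._pinned_buf[:n].to(self.device, non_blocking=True)
```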
@@ -811,6 +811,92 @@ def apply_grammar_bitmask(
         )
         return logits.to(self.device).to(logits_dtype)

+    def get_spec_token_ids(
+        self,
+        valid_sampled_token_ids: list[list[int]],
+        sampling_metadata: SamplingMetadata,
+        scheduler_output: "SchedulerOutput",
+        spec_decode_metadata: SpecDecodeMetadata,
+        positions: torch.Tensor,
+        num_scheduled_tokens: int,
+        hidden_states: torch.Tensor,
+        attn_metadata: SpecDecodeMetadata,
+    ) -> list[list[int]]:
+        if not self.use_spec_decode:
+            # Speculative decoding is not enabled.
+            spec_token_ids = None
+        elif self.speculative_config.method == "ngram":
+            assert isinstance(self.drafter, NgramProposer)
+            spec_token_ids = self.generate_draft_token_ids(
+                valid_sampled_token_ids, sampling_metadata)
+        elif self.speculative_config.method == "eagle":
+            assert isinstance(self.drafter, EagleProposer)
+            # TODO(woosuk): Refactor the loop.
+            next_token_ids: list[int] = []
+            for i, token_ids in enumerate(valid_sampled_token_ids):
+                if token_ids:
+                    # Common case.
+                    next_token_id = token_ids[-1]
+                else:
+                    # Partial prefill (rare case).
+                    # Get the next token id from the request state.
+                    req_id = self.input_batch.req_ids[i]
+                    req_state = self.requests[req_id]
+                    seq_len = (req_state.num_computed_tokens +
+                               scheduler_output.num_scheduled_tokens[req_id])
+                    next_token_id = req_state.get_token_id(seq_len)
+                next_token_ids.append(next_token_id)
+            next_token_ids = torch.tensor(next_token_ids,
+                                          dtype=torch.int32,
+                                          device=self.device)
+
+            if spec_decode_metadata is None:
+                # input_ids can be None for multimodal models.
+                # We need to slice token_ids, positions, and hidden_states
+                # because the eagle head does not use cuda graph and should
+                # not include padding.
+                target_token_ids = self.input_ids[:num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                target_hidden_states = hidden_states[:num_scheduled_tokens]
+                target_slot_mapping = attn_metadata.slot_mapping
+                cu_num_tokens = attn_metadata.query_start_loc
+            else:
+                # TODO(woosuk): Refactor this.
+                num_draft_tokens = spec_decode_metadata.num_draft_tokens
+                num_rejected_tokens = [
+                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
+                    for i, n in enumerate(num_draft_tokens)
+                ]
+                num_rejected_tokens = torch.tensor(
+                    num_rejected_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
+                    attn_metadata.query_start_loc,
+                    num_rejected_tokens,
+                )
+                target_token_ids = self.input_ids[token_indices]
+                target_positions = positions[token_indices]
+                target_hidden_states = hidden_states[token_indices]
+                target_slot_mapping = attn_metadata.slot_mapping[token_indices]
+
+            draft_token_ids, draft_probs = self.drafter.propose(
+                target_token_ids=target_token_ids,
+                target_positions=target_positions,
+                target_hidden_states=target_hidden_states,
+                target_slot_mapping=target_slot_mapping,
+                next_token_ids=next_token_ids,
+                cu_num_tokens=cu_num_tokens,
+                block_table=attn_metadata.block_tables,
+                sampling_metadata=sampling_metadata,
+            )
+            spec_token_ids = draft_token_ids.tolist()
+            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
+            # in the next step.
+            del draft_probs
+        return spec_token_ids
+
     @torch.inference_mode()
     def execute_model(
         self,
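The `num_rejected_tokens` arithmetic inside the eagle branch is easy to misread: a request that was scheduled `n > 0` draft tokens gets `n + 1` positions verified (the drafts plus one bonus token), so anything the rejection sampler did not keep counts as rejected. A small, self-contained illustration with made-up numbers:

```python
# Toy values, not taken from a real run.
num_draft_tokens = [3, 0, 2]          # drafts scheduled per request
valid_sampled_token_ids = [           # tokens the verifier actually kept
    [11, 12],                         # kept 2 of the 3 + 1 verified positions
    [7],                              # no drafts were scheduled
    [21, 22, 23],                     # kept everything (2 drafts + bonus)
]

num_rejected_tokens = [
    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
print(num_rejected_tokens)  # [2, 0, 0]
```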
@@ -895,79 +981,16 @@ def execute_model(
                 self.input_batch.vocab_size,
             )

-        if not self.use_spec_decode:
-            # Speculative decoding is not enabled.
-            spec_token_ids = None
-        elif self.speculative_config.method == "ngram":
-            assert isinstance(self.drafter, NgramProposer)
-            spec_token_ids = self.generate_draft_token_ids(
-                valid_sampled_token_ids, sampling_metadata)
-        elif self.speculative_config.method == "eagle":
-            assert isinstance(self.drafter, EagleProposer)
-            # TODO(woosuk): Refactor the loop.
-            next_token_ids: list[int] = []
-            for i, token_ids in enumerate(valid_sampled_token_ids):
-                if token_ids:
-                    # Common case.
-                    next_token_id = token_ids[-1]
-                else:
-                    # Partial prefill (rare case).
-                    # Get the next token id from the request state.
-                    req_id = self.input_batch.req_ids[i]
-                    req_state = self.requests[req_id]
-                    seq_len = (req_state.num_computed_tokens +
-                               scheduler_output.num_scheduled_tokens[req_id])
-                    next_token_id = req_state.get_token_id(seq_len)
-                next_token_ids.append(next_token_id)
-            next_token_ids = torch.tensor(next_token_ids,
-                                          dtype=torch.int32,
-                                          device=self.device)
-
-            if spec_decode_metadata is None:
-                # input_ids can be None for multimodal models.
-                # We need to slice token_ids, positions, and hidden_states
-                # because the eagle head does not use cuda graph and should
-                # not include padding.
-                target_token_ids = self.input_ids[:num_scheduled_tokens]
-                target_positions = positions[:num_scheduled_tokens]
-                target_hidden_states = hidden_states[:num_scheduled_tokens]
-                target_slot_mapping = attn_metadata.slot_mapping
-                cu_num_tokens = attn_metadata.query_start_loc
-            else:
-                # TODO(woosuk): Refactor this.
-                num_draft_tokens = spec_decode_metadata.num_draft_tokens
-                num_rejected_tokens = [
-                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
-                    for i, n in enumerate(num_draft_tokens)
-                ]
-                num_rejected_tokens = torch.tensor(
-                    num_rejected_tokens,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
-                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                    attn_metadata.query_start_loc,
-                    num_rejected_tokens,
-                )
-                target_token_ids = self.input_ids[token_indices]
-                target_positions = positions[token_indices]
-                target_hidden_states = hidden_states[token_indices]
-                target_slot_mapping = attn_metadata.slot_mapping[token_indices]
-
-            draft_token_ids, draft_probs = self.drafter.propose(
-                target_token_ids=target_token_ids,
-                target_positions=target_positions,
-                target_hidden_states=target_hidden_states,
-                target_slot_mapping=target_slot_mapping,
-                next_token_ids=next_token_ids,
-                cu_num_tokens=cu_num_tokens,
-                block_table=attn_metadata.block_tables,
-                sampling_metadata=sampling_metadata,
-            )
-            spec_token_ids = draft_token_ids.tolist()
-            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
-            # in the next step.
-            del draft_probs
+        spec_token_ids = self.get_spec_token_ids(
+            valid_sampled_token_ids,
+            sampling_metadata,
+            scheduler_output,
+            spec_decode_metadata,
+            positions,
+            num_scheduled_tokens,
+            hidden_states,
+            attn_metadata,
+        )

         model_runner_output = ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
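Because this hunk is a pure extraction, the helper should reproduce the inline behaviour it replaces. One lightweight sanity check that could accompany the change is sketched below; the function name and the use of `None` placeholders are illustrative only (the runner's module path is not shown in this diff, so the check takes an already-constructed runner instance):

```python
def check_no_spec_decode_path(runner) -> None:
    """Hypothetical check: with speculative decoding disabled,
    get_spec_token_ids must return None, matching the old inline code."""
    runner.use_spec_decode = False
    spec_token_ids = runner.get_spec_token_ids(
        valid_sampled_token_ids=[[1, 2, 3]],
        sampling_metadata=None,      # unused on this path
        scheduler_output=None,       # unused on this path
        spec_decode_metadata=None,
        positions=None,
        num_scheduled_tokens=0,
        hidden_states=None,
        attn_metadata=None,
    )
    assert spec_token_ids is None
```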