@@ -869,72 +869,9 @@ def _get_spec_token_ids(
             spec_token_ids = self._generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
         elif self.speculative_config.method == "eagle":
-            raise NotImplementedError("eagle method for spec decode doesn't work on vllm-ascend currently")
-            assert isinstance(self.drafter, EagleProposer)
-            # TODO(woosuk): Refactor the loop.
-            next_token_ids: list[int] = []
-            for i, token_ids in enumerate(valid_sampled_token_ids):
-                if token_ids:
-                    # Common case.
-                    next_token_id = token_ids[-1]
-                else:
-                    # Partial prefill (rare case).
-                    # Get the next token id from the request state.
-                    req_id = self.input_batch.req_ids[i]
-                    req_state = self.requests[req_id]
-                    seq_len = (req_state.num_computed_tokens +
-                               scheduler_output.num_scheduled_tokens[req_id])
-                    next_token_id = req_state.get_token_id(seq_len)
-                next_token_ids.append(next_token_id)
-            next_token_ids = torch.tensor(next_token_ids,
-                                          dtype=torch.int32,
-                                          device=self.device)
-
-            if spec_decode_metadata is None:
-                # input_ids can be None for multimodal models.
-                # We need to slice token_ids, positions, and hidden_states
-                # because the eagle head does not use cuda graph and should
-                # not include padding.
-                target_token_ids = self.input_ids[:num_scheduled_tokens]
-                target_positions = positions[:num_scheduled_tokens]
-                target_hidden_states = hidden_states[:num_scheduled_tokens]
-                target_slot_mapping = attn_metadata.slot_mapping
-                cu_num_tokens = attn_metadata.query_start_loc
-            else:
-                # TODO(woosuk): Refactor this.
-                num_draft_tokens = spec_decode_metadata.num_draft_tokens
-                num_rejected_tokens = [
-                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
-                    for i, n in enumerate(num_draft_tokens)
-                ]
-                num_rejected_tokens = torch.tensor(
-                    num_rejected_tokens,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
-                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                    attn_metadata.query_start_loc,
-                    num_rejected_tokens,
-                )
-                target_token_ids = self.input_ids[token_indices]
-                target_positions = positions[token_indices]
-                target_hidden_states = hidden_states[token_indices]
-                target_slot_mapping = attn_metadata.slot_mapping[token_indices]
-
-            draft_token_ids, draft_probs = self.drafter.propose(
-                target_token_ids=target_token_ids,
-                target_positions=target_positions,
-                target_hidden_states=target_hidden_states,
-                target_slot_mapping=target_slot_mapping,
-                next_token_ids=next_token_ids,
-                cu_num_tokens=cu_num_tokens,
-                block_table=attn_metadata.block_tables,
-                sampling_metadata=sampling_metadata,
+            raise NotImplementedError(
+                "eagle method for spec decode doesn't work on vllm-ascend currently"
             )
-            spec_token_ids = draft_token_ids.tolist()
-            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
-            # in the next step.
-            del draft_probs
         return spec_token_ids
 
     @torch.inference_mode()