@@ -850,7 +850,7 @@ def apply_grammar_bitmask(
         )
         return logits.to(self.device).to(logits_dtype)
 
-    def get_spec_token_ids(
+    def _get_spec_token_ids(
         self,
         valid_sampled_token_ids: list[list[int]],
         sampling_metadata: SamplingMetadata,
@@ -866,9 +866,10 @@ def get_spec_token_ids(
             spec_token_ids = None
         elif self.speculative_config.method == "ngram":
             assert isinstance(self.drafter, NgramProposer)
-            spec_token_ids = self.generate_draft_token_ids(
+            spec_token_ids = self._generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
         elif self.speculative_config.method == "eagle":
+            raise NotImplementedError("eagle method for spec decode doesn't work on vllm-ascend currently")
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
             next_token_ids: list[int] = []
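For orientation, here is a minimal sketch of the dispatch the first two hunks leave behind in the now-private _get_spec_token_ids: ngram drafting is routed through the renamed _generate_draft_token_ids, and the eagle branch now fails fast. The standalone-function framing, the runner parameter, and the enclosing use_spec_decode guard are assumptions made for the sketch; only the branch conditions, the helper name, and the error message come from the diff.

# Minimal sketch (see assumptions above), not the PR's actual method body.
from typing import Optional


def get_spec_token_ids_sketch(
        runner,  # stand-in for the NPU model runner instance (assumption)
        valid_sampled_token_ids: list[list[int]],
        sampling_metadata) -> Optional[list[list[int]]]:
    if not runner.use_spec_decode:
        # Speculative decoding disabled: no draft tokens.
        return None
    if runner.speculative_config.method == "ngram":
        # Ngram drafting stays supported and now goes through the
        # underscore-prefixed private helper.
        return runner._generate_draft_token_ids(valid_sampled_token_ids,
                                                sampling_metadata)
    if runner.speculative_config.method == "eagle":
        # Fail fast: eagle-based speculative decoding is unsupported on vllm-ascend.
        raise NotImplementedError(
            "eagle method for spec decode doesn't work on vllm-ascend currently")
    return None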
@@ -1020,7 +1021,7 @@ def execute_model(
                 self.input_batch.vocab_size,
             )
 
-        spec_token_ids = self.get_spec_token_ids(
+        spec_token_ids = self._get_spec_token_ids(
            valid_sampled_token_ids,
            sampling_metadata,
            scheduler_output,
@@ -1390,7 +1391,7 @@ def capture_model(self) -> None:
         logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                     elapsed_time, npu_graph_size / (1 << 30))
 
-    def generate_draft_token_ids(
+    def _generate_draft_token_ids(
         self,
         sampled_token_ids: list[list[int]],
         sampling_metadata: SamplingMetadata,
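A hypothetical caller-side check exercising the sketch above: with an ngram config the helper returns a per-request list of draft tokens, while switching the config to eagle surfaces the new NotImplementedError immediately, which is what execute_model's call to self._get_spec_token_ids(...) in the third hunk would now hit. The stub runner and its attribute values are assumptions; only the attribute and method names visible in the diff are reused.

# Hypothetical stub runner used to exercise get_spec_token_ids_sketch above.
from types import SimpleNamespace


class _StubRunner:
    use_spec_decode = True
    speculative_config = SimpleNamespace(method="ngram")

    def _generate_draft_token_ids(self, sampled_token_ids, sampling_metadata):
        # Stand-in for the renamed ngram helper: propose no draft tokens.
        return [[] for _ in sampled_token_ids]


runner = _StubRunner()
print(get_spec_token_ids_sketch(runner, [[1, 2, 3]], None))  # -> [[]]

runner.speculative_config = SimpleNamespace(method="eagle")
try:
    get_spec_token_ids_sketch(runner, [[1, 2, 3]], None)
except NotImplementedError as err:
    print(err)  # eagle method for spec decode doesn't work on vllm-ascend currently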