Skip to content

Commit d3cd576

Browse files
authored
Update model_runner_v1.py
1 parent cd5faa9 commit d3cd576

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ def apply_grammar_bitmask(
850850
)
851851
return logits.to(self.device).to(logits_dtype)
852852

853-
def get_spec_token_ids(
853+
def _get_spec_token_ids(
854854
self,
855855
valid_sampled_token_ids: list[list[int]],
856856
sampling_metadata: SamplingMetadata,
@@ -866,9 +866,10 @@ def get_spec_token_ids(
866866
spec_token_ids = None
867867
elif self.speculative_config.method == "ngram":
868868
assert isinstance(self.drafter, NgramProposer)
869-
spec_token_ids = self.generate_draft_token_ids(
869+
spec_token_ids = self._generate_draft_token_ids(
870870
valid_sampled_token_ids, sampling_metadata)
871871
elif self.speculative_config.method == "eagle":
872+
raise NotImplementedError("eagle method for spec decode doesn't work on vllm-ascend currently")
872873
assert isinstance(self.drafter, EagleProposer)
873874
# TODO(woosuk): Refactor the loop.
874875
next_token_ids: list[int] = []
@@ -1020,7 +1021,7 @@ def execute_model(
10201021
self.input_batch.vocab_size,
10211022
)
10221023

1023-
spec_token_ids = self.get_spec_token_ids(
1024+
spec_token_ids = self._get_spec_token_ids(
10241025
valid_sampled_token_ids,
10251026
sampling_metadata,
10261027
scheduler_output,
@@ -1390,7 +1391,7 @@ def capture_model(self) -> None:
13901391
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
13911392
elapsed_time, npu_graph_size / (1 << 30))
13921393

1393-
def generate_draft_token_ids(
1394+
def _generate_draft_token_ids(
13941395
self,
13951396
sampled_token_ids: list[list[int]],
13961397
sampling_metadata: SamplingMetadata,

0 commit comments

Comments
 (0)