
Commit 6beab75

Update model_runner_v1.py
1 parent e53ba2c commit 6beab75

File tree

1 file changed
+2 -65 lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 65 deletions
@@ -869,72 +869,9 @@ def _get_spec_token_ids(
             spec_token_ids = self._generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
         elif self.speculative_config.method == "eagle":
-            raise NotImplementedError("eagle method for spec decode doesn't work on vllm-ascend currently")
-            assert isinstance(self.drafter, EagleProposer)
-            # TODO(woosuk): Refactor the loop.
-            next_token_ids: list[int] = []
-            for i, token_ids in enumerate(valid_sampled_token_ids):
-                if token_ids:
-                    # Common case.
-                    next_token_id = token_ids[-1]
-                else:
-                    # Partial prefill (rare case).
-                    # Get the next token id from the request state.
-                    req_id = self.input_batch.req_ids[i]
-                    req_state = self.requests[req_id]
-                    seq_len = (req_state.num_computed_tokens +
-                               scheduler_output.num_scheduled_tokens[req_id])
-                    next_token_id = req_state.get_token_id(seq_len)
-                next_token_ids.append(next_token_id)
-            next_token_ids = torch.tensor(next_token_ids,
-                                          dtype=torch.int32,
-                                          device=self.device)
-
-            if spec_decode_metadata is None:
-                # input_ids can be None for multimodal models.
-                # We need to slice token_ids, positions, and hidden_states
-                # because the eagle head does not use cuda graph and should
-                # not include padding.
-                target_token_ids = self.input_ids[:num_scheduled_tokens]
-                target_positions = positions[:num_scheduled_tokens]
-                target_hidden_states = hidden_states[:num_scheduled_tokens]
-                target_slot_mapping = attn_metadata.slot_mapping
-                cu_num_tokens = attn_metadata.query_start_loc
-            else:
-                # TODO(woosuk): Refactor this.
-                num_draft_tokens = spec_decode_metadata.num_draft_tokens
-                num_rejected_tokens = [
-                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
-                    for i, n in enumerate(num_draft_tokens)
-                ]
-                num_rejected_tokens = torch.tensor(
-                    num_rejected_tokens,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
-                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                    attn_metadata.query_start_loc,
-                    num_rejected_tokens,
-                )
-                target_token_ids = self.input_ids[token_indices]
-                target_positions = positions[token_indices]
-                target_hidden_states = hidden_states[token_indices]
-                target_slot_mapping = attn_metadata.slot_mapping[token_indices]
-
-            draft_token_ids, draft_probs = self.drafter.propose(
-                target_token_ids=target_token_ids,
-                target_positions=target_positions,
-                target_hidden_states=target_hidden_states,
-                target_slot_mapping=target_slot_mapping,
-                next_token_ids=next_token_ids,
-                cu_num_tokens=cu_num_tokens,
-                block_table=attn_metadata.block_tables,
-                sampling_metadata=sampling_metadata,
+            raise NotImplementedError(
+                "eagle method for spec decode doesn't work on vllm-ascend currently"
             )
-            spec_token_ids = draft_token_ids.tolist()
-            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
-            # in the next step.
-            del draft_probs
         return spec_token_ids
 
     @torch.inference_mode()
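
Note on the deleted block: it is the EagleProposer drafting path, which was already unreachable because the raise NotImplementedError sits at the top of the eagle branch; the commit removes that dead code and keeps only the two-line raise. One detail from the removed code is how it counted rejected draft tokens per request before calling drafter.prepare_inputs. Below is a minimal standalone sketch of that arithmetic; the variable names follow the removed code, while the example values are hypothetical.

# Sketch only (not part of this commit): the rejected-token count used by the
# removed eagle path. A request with n > 0 draft tokens can accept at most
# n drafts plus one bonus sampled token, so n + 1 - len(accepted) were rejected.
num_draft_tokens = [3, 3, 0]          # drafts proposed per request (example values)
valid_sampled_token_ids = [
    [7, 9, 2, 4],   # all 3 drafts accepted + 1 bonus token
    [7, 5],         # 1 draft accepted + 1 bonus token
    [6],            # request ran without speculative decoding
]

num_rejected_tokens = [
    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
print(num_rejected_tokens)  # [0, 2, 0]

Requests that did not run speculative decoding (n == 0) contribute zero; for the rest, any shortfall from n + 1 accepted tokens counts as rejected.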

0 commit comments
