Commit d56abb8

Update model_runner_v1.py
1 parent c794753 commit d56abb8

vllm_ascend/worker/model_runner_v1.py

Lines changed: 97 additions & 74 deletions
@@ -732,7 +732,7 @@ def _calc_spec_decode_metadata(
         # [0, 1, 2, 5, 6, 9]
         target_logits_indices += arange
 
-        # TODO: Optimize the CPU -> GPU copy.
+        # TODO: Optimize the CPU -> NPU copy.
         cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
             self.device, non_blocking=True)
         logits_indices = torch.from_numpy(logits_indices).to(self.device,
@@ -811,6 +811,92 @@ def apply_grammar_bitmask(
         )
         return logits.to(self.device).to(logits_dtype)
 
+    def get_spec_token_ids(
+        self,
+        valid_sampled_token_ids: list[list[int]],
+        sampling_metadata: SamplingMetadata,
+        scheduler_output: "SchedulerOutput",
+        spec_decode_metadata: SpecDecodeMetadata,
+        positions: torch.Tensor,
+        num_scheduled_tokens: int,
+        hidden_states: torch.Tensor,
+        attn_metadata: SpecDecodeMetadata,
+    ) -> list[list[int]]:
+        if not self.use_spec_decode:
+            # Speculative decoding is not enabled.
+            spec_token_ids = None
+        elif self.speculative_config.method == "ngram":
+            assert isinstance(self.drafter, NgramProposer)
+            spec_token_ids = self.generate_draft_token_ids(
+                valid_sampled_token_ids, sampling_metadata)
+        elif self.speculative_config.method == "eagle":
+            assert isinstance(self.drafter, EagleProposer)
+            # TODO(woosuk): Refactor the loop.
+            next_token_ids: list[int] = []
+            for i, token_ids in enumerate(valid_sampled_token_ids):
+                if token_ids:
+                    # Common case.
+                    next_token_id = token_ids[-1]
+                else:
+                    # Partial prefill (rare case).
+                    # Get the next token id from the request state.
+                    req_id = self.input_batch.req_ids[i]
+                    req_state = self.requests[req_id]
+                    seq_len = (req_state.num_computed_tokens +
+                               scheduler_output.num_scheduled_tokens[req_id])
+                    next_token_id = req_state.get_token_id(seq_len)
+                next_token_ids.append(next_token_id)
+            next_token_ids = torch.tensor(next_token_ids,
+                                          dtype=torch.int32,
+                                          device=self.device)
+
+            if spec_decode_metadata is None:
+                # input_ids can be None for multimodal models.
+                # We need to slice token_ids, positions, and hidden_states
+                # because the eagle head does not use cuda graph and should
+                # not include padding.
+                target_token_ids = self.input_ids[:num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                target_hidden_states = hidden_states[:num_scheduled_tokens]
+                target_slot_mapping = attn_metadata.slot_mapping
+                cu_num_tokens = attn_metadata.query_start_loc
+            else:
+                # TODO(woosuk): Refactor this.
+                num_draft_tokens = spec_decode_metadata.num_draft_tokens
+                num_rejected_tokens = [
+                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
+                    for i, n in enumerate(num_draft_tokens)
+                ]
+                num_rejected_tokens = torch.tensor(
+                    num_rejected_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
+                    attn_metadata.query_start_loc,
+                    num_rejected_tokens,
+                )
+                target_token_ids = self.input_ids[token_indices]
+                target_positions = positions[token_indices]
+                target_hidden_states = hidden_states[token_indices]
+                target_slot_mapping = attn_metadata.slot_mapping[token_indices]
+
+            draft_token_ids, draft_probs = self.drafter.propose(
+                target_token_ids=target_token_ids,
+                target_positions=target_positions,
+                target_hidden_states=target_hidden_states,
+                target_slot_mapping=target_slot_mapping,
+                next_token_ids=next_token_ids,
+                cu_num_tokens=cu_num_tokens,
+                block_table=attn_metadata.block_tables,
+                sampling_metadata=sampling_metadata,
+            )
+            spec_token_ids = draft_token_ids.tolist()
+            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
+            # in the next step.
+            del draft_probs
+        return spec_token_ids
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -895,79 +981,16 @@ def execute_model(
             self.input_batch.vocab_size,
         )
 
-        if not self.use_spec_decode:
-            # Speculative decoding is not enabled.
-            spec_token_ids = None
-        elif self.speculative_config.method == "ngram":
-            assert isinstance(self.drafter, NgramProposer)
-            spec_token_ids = self.generate_draft_token_ids(
-                valid_sampled_token_ids, sampling_metadata)
-        elif self.speculative_config.method == "eagle":
-            assert isinstance(self.drafter, EagleProposer)
-            # TODO(woosuk): Refactor the loop.
-            next_token_ids: list[int] = []
-            for i, token_ids in enumerate(valid_sampled_token_ids):
-                if token_ids:
-                    # Common case.
-                    next_token_id = token_ids[-1]
-                else:
-                    # Partial prefill (rare case).
-                    # Get the next token id from the request state.
-                    req_id = self.input_batch.req_ids[i]
-                    req_state = self.requests[req_id]
-                    seq_len = (req_state.num_computed_tokens +
-                               scheduler_output.num_scheduled_tokens[req_id])
-                    next_token_id = req_state.get_token_id(seq_len)
-                next_token_ids.append(next_token_id)
-            next_token_ids = torch.tensor(next_token_ids,
-                                          dtype=torch.int32,
-                                          device=self.device)
-
-            if spec_decode_metadata is None:
-                # input_ids can be None for multimodal models.
-                # We need to slice token_ids, positions, and hidden_states
-                # because the eagle head does not use cuda graph and should
-                # not include padding.
-                target_token_ids = self.input_ids[:num_scheduled_tokens]
-                target_positions = positions[:num_scheduled_tokens]
-                target_hidden_states = hidden_states[:num_scheduled_tokens]
-                target_slot_mapping = attn_metadata.slot_mapping
-                cu_num_tokens = attn_metadata.query_start_loc
-            else:
-                # TODO(woosuk): Refactor this.
-                num_draft_tokens = spec_decode_metadata.num_draft_tokens
-                num_rejected_tokens = [
-                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
-                    for i, n in enumerate(num_draft_tokens)
-                ]
-                num_rejected_tokens = torch.tensor(
-                    num_rejected_tokens,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
-                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                    attn_metadata.query_start_loc,
-                    num_rejected_tokens,
-                )
-                target_token_ids = self.input_ids[token_indices]
-                target_positions = positions[token_indices]
-                target_hidden_states = hidden_states[token_indices]
-                target_slot_mapping = attn_metadata.slot_mapping[token_indices]
-
-            draft_token_ids, draft_probs = self.drafter.propose(
-                target_token_ids=target_token_ids,
-                target_positions=target_positions,
-                target_hidden_states=target_hidden_states,
-                target_slot_mapping=target_slot_mapping,
-                next_token_ids=next_token_ids,
-                cu_num_tokens=cu_num_tokens,
-                block_table=attn_metadata.block_tables,
-                sampling_metadata=sampling_metadata,
-            )
-            spec_token_ids = draft_token_ids.tolist()
-            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
-            # in the next step.
-            del draft_probs
+        spec_token_ids = self.get_spec_token_ids(
+            valid_sampled_token_ids,
+            sampling_metadata,
+            scheduler_output,
+            spec_decode_metadata,
+            positions,
+            num_scheduled_tokens,
+            hidden_states,
+            attn_metadata,
+        )
 
         model_runner_output = ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
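
Note on the num_rejected_tokens comprehension in the eagle branch (n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0): each request that had drafts verified emits its accepted draft tokens plus one bonus or recovered token, so the difference counts exactly the rejected drafts. A minimal standalone sketch with toy values (illustrative only, not part of the commit):

# Toy illustration (hypothetical values) of the rejected-token accounting.
num_draft_tokens = [3, 3, 0]        # drafts scheduled per request
valid_sampled_token_ids = [
    [11, 12, 13, 14],               # all 3 drafts accepted + 1 bonus token
    [21, 22],                       # 1 draft accepted + 1 recovered token
    [31],                           # plain decode step, no drafts
]
num_rejected_tokens = [
    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
print(num_rejected_tokens)          # -> [0, 2, 0]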
