diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 3cf7fde5cd0e..9061a64db57c 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -226,7 +226,7 @@ def rejection_sample(
         is_greedy,
         max_spec_len,
         vocab_size,
-        IS_NGRAM=draft_probs is None,
+        NO_DRAFT_PROBS=draft_probs is None,
         num_warps=1,
     )
     return output_token_ids
@@ -423,7 +423,7 @@ def sample_recovered_tokens(
         q,
         vocab_size,
         triton.next_power_of_2(vocab_size),
-        IS_NGRAM=draft_probs is None,
+        NO_DRAFT_PROBS=draft_probs is None,
     )
     return recovered_token_ids
 
@@ -490,7 +490,7 @@ def rejection_random_sample_kernel(
     is_greedy_ptr,  # [batch_size]
     max_spec_len,
     vocab_size,
-    IS_NGRAM: tl.constexpr,
+    NO_DRAFT_PROBS: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
     is_greedy = tl.load(is_greedy_ptr + req_idx)
@@ -509,7 +509,7 @@ def rejection_random_sample_kernel(
     for pos in range(num_draft_tokens):
         if not rejected:
             draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-            if IS_NGRAM:
+            if NO_DRAFT_PROBS:
                 draft_prob = 1
             else:
                 draft_prob = tl.load(draft_probs_ptr +
@@ -575,7 +575,7 @@ def sample_recovered_tokens_kernel(
     q_ptr,  # [batch_size, vocab_size]
     vocab_size,
     PADDED_VOCAB_SIZE: tl.constexpr,
-    IS_NGRAM: tl.constexpr,
+    NO_DRAFT_PROBS: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
     if req_idx == 0:
@@ -591,7 +591,7 @@ def sample_recovered_tokens_kernel(
         return
 
     vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
-    if IS_NGRAM:
+    if NO_DRAFT_PROBS:
         draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
         orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
                             draft_token_id)
@@ -624,7 +624,7 @@ def sample_recovered_tokens_kernel(
     recovered_id = tl.argmax(prob / q, axis=-1)
     tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
 
-    if IS_NGRAM:
+    if NO_DRAFT_PROBS:
         # Restore the original probability.
         tl.store(
             target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 3efafa8f0b1f..95f0c067d406 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -51,7 +51,7 @@ def propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
@@ -94,17 +94,15 @@ def propose(
         )
         sample_hidden_states = hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
-        draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
-            logits, sampling_metadata)
+        draft_token_ids = logits.argmax(dim=-1)
 
         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1:
-            # [batch_size, 1] and [batch_size, 1, vocab_size]
-            return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
+            # [batch_size, 1]
+            return draft_token_ids.view(-1, 1)
 
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
-        draft_probs_list = [draft_probs]
 
         positions = target_positions[last_token_indices]
         hidden_states = sample_hidden_states
@@ -159,16 +157,12 @@ def propose(
                 positions=clamped_positions,
             )
             logits = self.model.compute_logits(hidden_states, None)
-            draft_token_ids, probs = compute_probs_and_sample_next_token(
-                logits, sampling_metadata)
+            draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)
-            draft_probs_list.append(probs)
 
         # [batch_size, num_speculative_tokens]
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-        # [batch_size, num_speculative_tokens, vocab_size]
-        draft_probs = torch.stack(draft_probs_list, dim=1)
-        return draft_token_ids, draft_probs
+        return draft_token_ids
 
     @staticmethod
     def prepare_inputs(
@@ -238,6 +232,10 @@ def load_model(self, target_model: nn.Module) -> None:
         self.model.lm_head = target_model.lm_head
 
 
+# NOTE(woosuk): Currently, the below code is not used and we always use argmax
+# to sample the draft tokens. We will use this after we find a way to manage
+# the draft prob tensor.
+# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
 # FIXME(woosuk): The logic here is duplicated with the main sampling code.
 # We should refactor this to reuse the same sampling implementation.
 def compute_probs_and_sample_next_token(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4ecf72b56ef6..c12e4fd555d2 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1230,7 +1230,7 @@ def execute_model(
                 target_hidden_states = hidden_states[token_indices]
                 target_slot_mapping = attn_metadata.slot_mapping[token_indices]
 
-            draft_token_ids, draft_probs = self.drafter.propose(
+            draft_token_ids = self.drafter.propose(
                 target_token_ids=target_token_ids,
                 target_positions=target_positions,
                 target_hidden_states=target_hidden_states,
@@ -1241,9 +1241,6 @@ def execute_model(
                 sampling_metadata=sampling_metadata,
            )
             spec_token_ids = draft_token_ids.tolist()
-            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
-            # in the next step.
-            del draft_probs
 
         # Clear KVConnector state after all KVs are generated.
         if has_kv_transfer_group():
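
For context, a minimal PyTorch sketch (not the Triton kernel; the function and variable names are illustrative) of the acceptance rule that the `NO_DRAFT_PROBS` rename describes: when the drafter supplies no probability distribution (the ngram drafter, or the argmax-based EAGLE drafting above), the draft probability is treated as 1, so a drafted token is accepted with probability equal to its probability under the target model.

```python
import torch


def accept_draft_tokens(
    draft_token_ids: torch.Tensor,     # [num_tokens]
    target_probs: torch.Tensor,        # [num_tokens, vocab_size]
    draft_probs: torch.Tensor | None,  # [num_tokens, vocab_size] or None
) -> torch.Tensor:
    """Per-position acceptance mask for standard rejection sampling.

    Illustrative only: the real kernel walks each request's draft tokens
    sequentially and stops at the first rejection.
    """
    idx = torch.arange(draft_token_ids.shape[0])
    p_target = target_probs[idx, draft_token_ids]
    if draft_probs is None:
        # NO_DRAFT_PROBS path: no drafter distribution is available,
        # so the draft probability is taken to be 1.
        p_draft = torch.ones_like(p_target)
    else:
        p_draft = draft_probs[idx, draft_token_ids]
    u = torch.rand_like(p_target)
    # Accept each token with probability min(1, p_target / p_draft).
    return p_target >= u * p_draft
```

This is also why the rename helps: the `draft_probs is None` path is no longer specific to the ngram drafter, so `NO_DRAFT_PROBS` states the actual condition being checked.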