From ddd76204f66f65c1476c61e3569e2d1ee01bdc89 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 14 Feb 2025 15:45:30 -0800
Subject: [PATCH] [V0][Sampler] Use raw logits for greedy argmax

To hopefully avoid some of the reported precision-related nondeterminism.

Signed-off-by: Nick Hill
---
 vllm/model_executor/layers/sampler.py | 42 ++++++---------------------
 1 file changed, 9 insertions(+), 33 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 6af734be5e98..b7f84e99bfd7 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -286,6 +286,7 @@ def forward(
 
         # Sample the next tokens.
         maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
+            logits,
             probs,
             logprobs,
             sampling_metadata,
@@ -697,7 +698,8 @@ def get_pythonized_sample_results(
     ]
 
 
-def _sample_with_torch(
+def _sample(
+    logits: torch.Tensor,
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
@@ -715,6 +717,11 @@ def _sample_with_torch(
     * Perform GPU-side sampling computation
     * Defer Pythonization & preserve GPU-side
       tensors required for Pythonization
+
+    Returns:
+        (next_token_ids, parent_seq_ids) for each seq group in a batch.
+            If sampling is skipped, it returns ([], [])
+        sampled_token_ids_tensor: A tensor of sampled token ids.
     '''
 
     categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
@@ -755,8 +762,7 @@ def _sample_with_torch(
         sample_metadata[sampling_type] = (seq_group_id, seq_groups)
         long_sample_indices = sample_indices.long()
         if sampling_type == SamplingType.GREEDY:
-            greedy_samples = torch.argmax(logprobs[long_sample_indices],
-                                          dim=-1)
+            greedy_samples = torch.argmax(logits[long_sample_indices], dim=-1)
 
             if sampled_token_ids_tensor is not None:
                 # Store sampled tokens in output tensor.
@@ -830,36 +836,6 @@ def _sample_with_torch(
     )
 
 
-def _sample(
-    probs: torch.Tensor,
-    logprobs: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-    sampling_tensors: SamplingTensors,
-    include_gpu_probs_tensor: bool,
-    modify_greedy_probs: bool,
-) -> SampleReturnType:
-    """
-    Args:
-        probs: (num_query_tokens_in_batch, num_vocab)
-        logprobs: (num_query_tokens_in_batch, num_vocab)
-        sampling_metadata: The metadata for a batch for sampling.
-        sampling_tensors: Tensors that include sampling related metadata.
-
-    Returns:
-        (next_token_ids, parent_seq_ids) for each seq group in a batch.
-            If sampling is skipped, it returns ([], [])
-        sampled_token_ids_tensor: A tensor of sampled token ids.
-    """
-    return _sample_with_torch(
-        probs,
-        logprobs,
-        sampling_metadata,
-        sampling_tensors,
-        include_gpu_probs_tensor=include_gpu_probs_tensor,
-        modify_greedy_probs=modify_greedy_probs,
-    )
-
-
 def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
     """
     This function calculates the ranks of the chosen tokens in a logprob tensor.
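
--
Reviewer note (commentary, not part of the patch): a minimal PyTorch sketch
of the motivation. torch.log_softmax is a monotonic transform, so greedy
argmax over logprobs and over raw logits agree mathematically. But the
logprobs are the output of extra floating-point work (max, exp, sum, log)
whose low-order bits depend on reduction order, which GPU kernels do not fix
across batch shapes or runs; near-tied top tokens can then flip between
runs. Taking argmax on the raw logits drops that computation from the greedy
path. Tensor names below are illustrative only, not identifiers from
sampler.py.

    import torch

    torch.manual_seed(0)

    # Summing the same values in two different orders gives different
    # low-order bits in floating point; this is the kind of wobble that
    # can flip a near-tied argmax when it leaks into the logprobs.
    vals = torch.randn(32000)
    s_fwd = vals.sum()
    s_rev = vals.flip(0).sum()
    print(s_fwd.item(), s_rev.item(), (s_fwd - s_rev).item())  # typically nonzero

    # Monotonicity check: within one fixed, deterministic CPU computation
    # the two argmaxes agree, which is why switching greedy sampling to
    # raw logits does not change well-separated results.
    logits = torch.randn(8, 32000)
    same = torch.argmax(logits, dim=-1).eq(
        torch.argmax(torch.log_softmax(logits, dim=-1), dim=-1))
    print(same.all().item())  # expected: True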