diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 6af734be5e98..e9b5bfe3e84a 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -266,10 +266,26 @@ def forward(
                                       sampling_tensors.frequency_penalties,
                                       sampling_tensors.repetition_penalties)
 
-        # Use float32 to apply temperature scaling.
-        # Use in-place division to avoid creating a new tensor.
+        # Apply temperature scaling, special handling for zero-temperature case.
+        # Use float32 to apply temperature scaling in all cases.
         logits = logits.to(torch.float)
-        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
+        temperature = sampling_tensors.temperatures.unsqueeze(dim=1)
+        is_zero = (temperature == 0)
+
+        # Positive temperature path.
+        # Need to adjust denominator to avoid division by zero causing problems.
+        # Any zero temperature entries are multiplied by False (0).
+        # This means denominator adjustment never messes with things.
+        logits_p = (~is_zero) * logits / (temperature + is_zero)
+
+        # Zero temperature path.
+        # Any positive temperature entries are multiplied by False (0).
+        logits_z = is_zero * 1e9 * (logits == logits.max(dim=1,
+                                                         keepdim=True)[0])
+
+        # Final logits is sum of both cases.
+        # Always one of them is zero since mutually exclusive.
+        logits = logits_p + logits_z
 
     if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
         logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index 0a580a4e907d..ac044cbb4ec5 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -417,11 +417,6 @@ def from_sampling_metadata(
             # k should not be greater than the vocab size.
             top_k = min(sampling_params.top_k, vocab_size)
             top_k = vocab_size if top_k == -1 else top_k
-            if temperature < _SAMPLING_EPS:
-                # NOTE: Zero temperature means deterministic sampling
-                # (i.e., greedy sampling or beam search).
-                # Set the temperature to 1 to avoid division by zero.
-                temperature = 1.0
             if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                        or top_k != vocab_size):
                 do_top_p_top_k = True