[Bugfix] [Core] Fix zero temperature case (#5404 and part of #5898) #12802

Open · wants to merge 2 commits into main
22 changes: 19 additions & 3 deletions vllm/model_executor/layers/sampler.py
@@ -266,10 +266,26 @@ def forward(
                                           sampling_tensors.frequency_penalties,
                                           sampling_tensors.repetition_penalties)
 
-        # Use float32 to apply temperature scaling.
-        # Use in-place division to avoid creating a new tensor.
+        # Apply temperature scaling, with special handling for the
+        # zero-temperature case. Use float32 in all cases.
         logits = logits.to(torch.float)
-        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
+        temperature = sampling_tensors.temperatures.unsqueeze(dim=1)
+        is_zero = (temperature == 0)
+
+        # Positive-temperature path.
+        # The denominator is adjusted so it is never zero; rows with
+        # zero temperature are multiplied by False (0), so the
+        # adjustment never affects their result.
+        logits_p = (~is_zero) * logits / (temperature + is_zero)
+
+        # Zero-temperature path.
+        # Rows with positive temperature are multiplied by False (0).
+        logits_z = is_zero * 1e9 * (logits == logits.max(dim=1,
+                                                         keepdim=True)[0])
+
+        # The final logits are the sum of the two paths; the paths are
+        # mutually exclusive, so one term is always zero.
+        logits = logits_p + logits_z
 
         if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
             logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
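As a side note for reviewers: the masking trick above can be exercised in isolation. Below is a minimal sketch mirroring the new logic (the helper name scale_by_temperature and the toy inputs are illustrative, not part of this PR) that checks a zero-temperature row collapses to its argmax after softmax.

import torch

def scale_by_temperature(logits: torch.Tensor,
                         temperatures: torch.Tensor) -> torch.Tensor:
    # Mirrors the new sampler logic: rows with temperature == 0 get a
    # large logit (1e9) at their argmax position(s); all other rows
    # are divided by their temperature as usual.
    logits = logits.to(torch.float)
    temperature = temperatures.unsqueeze(dim=1)  # [num_seqs, 1]
    is_zero = (temperature == 0)
    # `+ is_zero` bumps zero denominators to 1; those rows are zeroed
    # out by `~is_zero`, so the adjustment never changes a result.
    logits_p = (~is_zero) * logits / (temperature + is_zero)
    logits_z = is_zero * 1e9 * (logits == logits.max(dim=1, keepdim=True)[0])
    return logits_p + logits_z

# The zero-temperature row becomes (near-)one-hot at its argmax.
logits = torch.tensor([[1.0, 2.0, 0.5], [3.0, 1.0, 2.0]])
temps = torch.tensor([0.0, 0.5])
probs = torch.softmax(scale_by_temperature(logits, temps), dim=1)
assert probs[0].argmax() == logits[0].argmax()
assert probs[0].max() > 0.999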
5 changes: 0 additions & 5 deletions vllm/model_executor/sampling_metadata.py
@@ -417,11 +417,6 @@ def from_sampling_metadata(
             # k should not be greater than the vocab size.
             top_k = min(sampling_params.top_k, vocab_size)
             top_k = vocab_size if top_k == -1 else top_k
-            if temperature < _SAMPLING_EPS:
-                # NOTE: Zero temperature means deterministic sampling
-                # (i.e., greedy sampling or beam search).
-                # Set the temperature to 1 to avoid division by zero.
-                temperature = 1.0
             if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                        or top_k != vocab_size):
                 do_top_p_top_k = True
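For context on the sampling_metadata.py deletion: the removed block rewrote a zero temperature to 1.0 before the division, so the probabilities computed in the sampler were those of the raw, unscaled logits; with the masking in sampler.py, the zero-temperature row instead collapses to its argmax. A small comparison with toy logits (my own illustration, assuming the probabilities are consumed downstream, e.g. for logprobs):

import torch

logits = torch.tensor([[1.0, 2.0, 0.5]])

# Old behavior: temperature clamped from 0 to 1.0, so the logits pass
# through unscaled and the distribution stays spread out.
old_probs = torch.softmax(logits / 1.0, dim=1)  # ≈ [[0.231, 0.628, 0.140]]

# New behavior (sampler.py hunk above): a large constant at the argmax
# position makes the softmax effectively one-hot.
is_max = (logits == logits.max(dim=1, keepdim=True)[0])
new_probs = torch.softmax(1e9 * is_max, dim=1)  # ≈ [[0., 1., 0.]]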