Revert "Remove intel implementation of top-p/top-k sampling method" #1466

Status: Open. Wants to merge 1 commit into habana_main (base branch).
62 changes: 61 additions & 1 deletion tests/samplers/test_sampler.py
@@ -11,7 +11,7 @@
from transformers import GenerationConfig, GenerationMixin

import vllm.envs as envs
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import ApplyToppTopkScalar, Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
@@ -764,3 +764,63 @@ def test_sampler_include_gpu_probs_tensor(device: str):
assert sampler_output.sampled_token_probs is not None
assert sampler_output.logprobs is not None
assert sampler_output.sampled_token_ids is not None


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_topp_topk_scalar(device: str):
obj1 = ApplyToppTopkScalar(2)
assert ApplyToppTopkScalar._padded_k == 0
x = torch.tensor([[9, 9, 8, 8, 8, 8, 7, 7, 7.0],
[10, 10, 9, 9, 9, 8, 5, 5, 5]])

retval1 = obj1(x, p=0.9, k=5)
ninf = -float("inf")
expected1 = torch.tensor([[9., 9., 8., 8., 8., 8., ninf, ninf, ninf],
[10., 10., 9., 9., 9., ninf, ninf, ninf, ninf]])
assert torch.all(retval1 == expected1).item()
assert ApplyToppTopkScalar._padded_k == 9

obj2 = ApplyToppTopkScalar(2)
assert obj2._padded_k == 9

x = torch.tensor([[2, 2, 9, 9, 2, 2, 1, 1, 1.0],
[10, 9, 9, 5, 9, 9, 5, 9, 10]])
retval2 = obj2(x, p=0.9, k=5)
expected2 = torch.tensor(
[[ninf, ninf, 9., 9., ninf, ninf, ninf, ninf, ninf],
[10., ninf, 9., ninf, 9., 9., ninf, 9., 10.]])
assert torch.all(retval2 == expected2).item()
assert obj2._padded_k == 9

retval3 = obj2(x, p=1.0, k=5)
expected3 = torch.tensor([[2., 2., 9., 9., 2., 2., ninf, ninf, ninf],
[10., 9., 9., ninf, 9., 9., ninf, 9., 10.]])

assert torch.all(retval3 == expected3).item()

    # resetting _padded_k should not be done in general;
    # it is done here only for testing purposes
ApplyToppTopkScalar._padded_k = 0
x = torch.tensor([[1, 1, 1, 9, 8, 1, 1, 1, 1.0],
[2, 1, 2, 2, 1, 1, 1, 1, 1]])
obj3 = ApplyToppTopkScalar(2)
retval4 = obj3(x, p=0.9, k=2)
expected4 = torch.tensor(
[[ninf, ninf, ninf, 9., 8., ninf, ninf, ninf, ninf],
[2., ninf, 2., 2., ninf, ninf, ninf, ninf, ninf]])
assert torch.all(retval4 == expected4).item()
assert obj3._padded_k == 4
y = torch.tensor([[8, 8, 8, 9, 8, 1, 1, 1, 1.0],
[2, 1, 2, 2, 1, 1, 1, 1, 1]])
retval5 = obj3(y, p=0.9, k=2)
assert obj3._padded_k == 8
    expected5 = torch.tensor([[8., 8., 8., 9., 8., ninf, ninf, ninf, ninf],
                              [2., ninf, 2., 2., ninf, ninf, ninf, ninf, ninf]])
assert torch.all(retval5 == expected5).item()
y = torch.tensor([[8, 8, 8, 9, 8, 8, 1, 1, 1.0],
[2, 1, 2, 2, 3, 1, 1, 1, 1]])
retval6 = obj3(y, p=0.9, k=2)
expected6 = torch.tensor([[8., 8., 8., 9., 8., 8., ninf, ninf, ninf],
[2., ninf, 2., 2., 3., ninf, ninf, ninf, ninf]])
assert torch.all(retval6 == expected6).item()
assert obj3._padded_k == 8
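
For context on the expected tensors above: plain torch.topk truncates ties at the kth position, which is exactly what the padded-k logic in ApplyToppTopkScalar (in the file below) compensates for. A minimal sketch, using the example from the class docstring (illustrative, not part of the diff):

import torch

x = torch.tensor([[9, 9, 8, 8, 8, 8, 7, 7, 7.0]])
vals, _ = torch.topk(x, k=5, dim=1, sorted=True)
print(vals)  # tensor([[9., 9., 8., 8., 8.]]) -- the fourth "8" is snipped off
# The sort-based _apply_top_k_top_p keeps every value tied with the kth
# largest, so all four 8s would survive (6 values even though k=5).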
128 changes: 126 additions & 2 deletions vllm/model_executor/layers/sampler.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import math
import os
import warnings
from dataclasses import dataclass
from importlib.util import find_spec
@@ -217,6 +219,10 @@ def _init_sampling_tensors(
self._do_penalties = do_penalties
self._do_top_p_top_k = do_top_p_top_k
self._do_min_p = do_min_p
self._top_k_scalar = top_k_scalar
self._top_p_scalar = top_p_scalar

self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)

def forward(
self,
@@ -276,8 +282,14 @@ def forward(
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
# If we have a scalar p and k, we can use the optimized version.
if self._top_k_scalar and self._top_p_scalar:
logits = self._apply_top_k_top_p_opt(logits,
self._top_p_scalar,
self._top_k_scalar)
else:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)

if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)
@@ -359,6 +371,118 @@ def _get_bin_counts_and_mask(
return bin_counts, mask


class ApplyToppTopkScalar:
"""
The original implementation of _apply_top_k_top_p is more general
as it uses vector topp, topk
However in a lot of cases, topp and topk is same for all batch elements
For such "scalar" topp, topk cases, we can use this class

The main optimizations in this class is:
Use topk instead of sort, which is much faster especially for small k.
However just using topk might not suffice in cases as shown below
Consider a tensor: 9 9 8 8 8 8 7 7 7
Topk, with k=5, on this yields 9 9 8 8 8
The value "8" is on the boundary, hence the last "8" gets snipped off
However the original implementation accepts all the "8"s,
so it should output:
9 9 8 8 8 8 (6 values, even though k=5)
To ensure these semantics, we perform topk with _padded_k elements
If we find more boundary elements left over,
then we keep incrementing _padded_k
and in future calls use the expanded value of __padded_k

The increments to _padded_k should be done
with value > 1 to prevent excessive recompilations
due to dynamic shapes (the output shape of the topk)

The main logic of this is in __call__
This is a class instead of a function, just to keep track of
the monotonic non-decreasing state _padded_k

To enable the duplicates that are outside of kth border,
set VLLM_HANDLE_TOPK_DUPLICATES to 1 or true.
"""
_padded_k = 0
_handle_duplicates = os.getenv('VLLM_HANDLE_TOPK_DUPLICATES',
'0').lower() in ['1', 'true']

def __init__(self, increment: int):
self._increment = increment

def __call__(self, logits: torch.Tensor, p: float, k: int):
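        # Fast path: for k == 1 with duplicate handling disabled, the
        # result is just the per-row max logit; everything else is -inf.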
if k == 1 and not ApplyToppTopkScalar._handle_duplicates:
new_logits = torch.full(logits.shape,
-float("inf"),
device=logits.device)
vals, idx = torch.max(logits, keepdim=True, dim=1)
new_logits.scatter_(1, idx, vals.to(new_logits.dtype))
return new_logits

if k > ApplyToppTopkScalar._padded_k:
ApplyToppTopkScalar._padded_k = min(k + self._increment,
logits.shape[1])

vals, idx = torch.topk(logits,
k=ApplyToppTopkScalar._padded_k,
dim=1,
sorted=True)

# this "if" checks if we have bucketed so much that
# we have padded k upto shape of logits
if self._handle_duplicates and \
ApplyToppTopkScalar._padded_k != logits.shape[1]:
smallest_of_top_k = vals[:, k - 1]
num_duplicates_of_smallest_of_topk = torch.sum(
logits == smallest_of_top_k.unsqueeze(1), 1)
max_num_duplicates_of_smallest_of_topk = torch.max(
num_duplicates_of_smallest_of_topk).item()

            # The border (the smallest value of the top k) is repeated
            # n times in total. We do not know how many of those
            # repeats (1, 2, ..., n-1) lie outside the kth border, so
            # we conservatively grow by n-1 whenever
            # num_duplicates > _padded_k - k.
if max_num_duplicates_of_smallest_of_topk - 1 > (
ApplyToppTopkScalar._padded_k - k):
incr = int(
math.ceil((max_num_duplicates_of_smallest_of_topk - 1) /
self._increment) * self._increment)
                # incr is computed in one step, rounded up to a
                # multiple of self._increment, instead of growing by
                # self._increment and retrying.
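                # e.g. with 8 copies of the border value and
                # self._increment == 5: incr = ceil(7 / 5) * 5 = 10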
ApplyToppTopkScalar._padded_k = min(
ApplyToppTopkScalar._padded_k + incr, logits.shape[1])

# recompute topk with expanded padded_k
vals, idx = torch.topk(logits,
k=ApplyToppTopkScalar._padded_k,
dim=1,
sorted=True)

idx = torch.fliplr(idx)
vals = torch.fliplr(vals)

top_k_smallest_val_idx = vals.size(1) - k
top_k_mask = vals[:, top_k_smallest_val_idx].unsqueeze(1)
top_k_mask = vals < top_k_mask
vals.masked_fill_(top_k_mask, -float("inf"))

probs_sort = vals.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= (1 - p)
top_p_mask[:, -1] = False
vals.masked_fill_(top_p_mask, -float("inf"))
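        # Worked example (illustrative): with ascending probs_sort
        # [0.05, 0.15, 0.30, 0.50] and p = 0.9, probs_sum is
        # [0.05, 0.20, 0.50, 1.00]; the mask drops only the first entry
        # (0.05 <= 0.1), and the surviving suffix has mass 0.95 >= p.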

new_logits = torch.full(logits.shape,
-float("inf"),
device=logits.device)
new_logits.scatter_(1, idx, vals.to(new_logits.dtype))

return new_logits
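
A minimal usage sketch (illustrative, not part of the diff; the shapes and values are assumptions): the functor is constructed once with an increment and then applied to batches of logits that share a scalar p and k, mirroring the unit test above:

import torch

apply_topp_topk = ApplyToppTopkScalar(increment=5)
logits = torch.randn(4, 32000)  # (batch_size, vocab_size), assumed shape
filtered = apply_topp_topk(logits, p=0.9, k=50)
# Entries outside the joint top-p/top-k set are -inf, so sampling can
# proceed with torch.softmax(filtered, dim=-1).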


def _apply_min_tokens_penalty(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,