From 7454c0e0db390bf7330ad19dd6fb225e152c5ca2 Mon Sep 17 00:00:00 2001
From: Artur Fierka
Date: Mon, 23 Jun 2025 10:26:32 +0200
Subject: [PATCH] Revert "Remove intel implementation of top-p/top-k sampling
 method (#1243)"

This reverts commit 13a2b7373fb432a0c9257d1f4a4294fa5bd4183b.
---
 tests/samplers/test_sampler.py        |  62 ++++++++++++-
 vllm/model_executor/layers/sampler.py | 128 +++++++++++++++++++++++++-
 2 files changed, 187 insertions(+), 3 deletions(-)

diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 6924aba1157..844b2140d2c 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -11,7 +11,7 @@
 from transformers import GenerationConfig, GenerationMixin
 
 import vllm.envs as envs
-from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.sampler import ApplyToppTopkScalar, Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
@@ -764,3 +764,63 @@ def test_sampler_include_gpu_probs_tensor(device: str):
     assert sampler_output.sampled_token_probs is not None
     assert sampler_output.logprobs is not None
     assert sampler_output.sampled_token_ids is not None
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_topp_topk_scalar(device: str):
+    obj1 = ApplyToppTopkScalar(2)
+    assert ApplyToppTopkScalar._padded_k == 0
+    x = torch.tensor([[9, 9, 8, 8, 8, 8, 7, 7, 7.0],
+                      [10, 10, 9, 9, 9, 8, 5, 5, 5]])
+
+    retval1 = obj1(x, p=0.9, k=5)
+    ninf = -float("inf")
+    expected1 = torch.tensor([[9., 9., 8., 8., 8., 8., ninf, ninf, ninf],
+                              [10., 10., 9., 9., 9., ninf, ninf, ninf, ninf]])
+    assert torch.all(retval1 == expected1).item()
+    assert ApplyToppTopkScalar._padded_k == 9
+
+    obj2 = ApplyToppTopkScalar(2)
+    assert obj2._padded_k == 9
+
+    x = torch.tensor([[2, 2, 9, 9, 2, 2, 1, 1, 1.0],
+                      [10, 9, 9, 5, 9, 9, 5, 9, 10]])
+    retval2 = obj2(x, p=0.9, k=5)
+    expected2 = torch.tensor(
+        [[ninf, ninf, 9., 9., ninf, ninf, ninf, ninf, ninf],
+         [10., ninf, 9., ninf, 9., 9., ninf, 9., 10.]])
+    assert torch.all(retval2 == expected2).item()
+    assert obj2._padded_k == 9
+
+    retval3 = obj2(x, p=1.0, k=5)
+    expected3 = torch.tensor([[2., 2., 9., 9., 2., 2., ninf, ninf, ninf],
+                              [10., 9., 9., ninf, 9., 9., ninf, 9., 10.]])
+
+    assert torch.all(retval3 == expected3).item()
+
+    # this should not be done in general, doing it here for testing purposes
+    ApplyToppTopkScalar._padded_k = 0
+    x = torch.tensor([[1, 1, 1, 9, 8, 1, 1, 1, 1.0],
+                      [2, 1, 2, 2, 1, 1, 1, 1, 1]])
+    obj3 = ApplyToppTopkScalar(2)
+    retval4 = obj3(x, p=0.9, k=2)
+    expected4 = torch.tensor(
+        [[ninf, ninf, ninf, 9., 8., ninf, ninf, ninf, ninf],
+         [2., ninf, 2., 2., ninf, ninf, ninf, ninf, ninf]])
+    assert torch.all(retval4 == expected4).item()
+    assert obj3._padded_k == 4
+    y = torch.tensor([[8, 8, 8, 9, 8, 1, 1, 1, 1.0],
+                      [2, 1, 2, 2, 1, 1, 1, 1, 1]])
+    retval5 = obj3(y, p=0.9, k=2)
+    assert obj3._padded_k == 8
+    expected5 = torch.tensor([[8., 8., 8., 9., 8., ninf, ninf, ninf, ninf],
+                              [2., ninf, 2., 2., ninf, ninf, ninf, ninf,
+                               ninf]])
+    assert torch.all(retval5 == expected5).item()
+    y = torch.tensor([[8, 8, 8, 9, 8, 8, 1, 1, 1.0],
+                      [2, 1, 2, 2, 3, 1, 1, 1, 1]])
+    retval6 = obj3(y, p=0.9, k=2)
+    expected6 = torch.tensor([[8., 8., 8., 9., 8., 8., ninf, ninf, ninf],
+                              [2., ninf, 2., 2., 3., ninf, ninf, ninf, ninf]])
+    assert torch.all(retval6 == expected6).item()
+    assert obj3._padded_k == 8
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index d289eddf111..7963810a767 100755
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
+import math
+import os
 import warnings
 from dataclasses import dataclass
 from importlib.util import find_spec
@@ -217,6 +219,10 @@ def _init_sampling_tensors(
         self._do_penalties = do_penalties
         self._do_top_p_top_k = do_top_p_top_k
         self._do_min_p = do_min_p
+        self._top_k_scalar = top_k_scalar
+        self._top_p_scalar = top_p_scalar
+
+        self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)
 
     def forward(
         self,
@@ -276,8 +282,14 @@ def forward(
             logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
 
         if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
-            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
-                                        sampling_tensors.top_ks)
+            # If we have a scalar p and k, we can use the optimized version.
+            if self._top_k_scalar and self._top_p_scalar:
+                logits = self._apply_top_k_top_p_opt(logits,
+                                                     self._top_p_scalar,
+                                                     self._top_k_scalar)
+            else:
+                logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
+                                            sampling_tensors.top_ks)
 
         if do_min_p:
             logits = _apply_min_p(logits, sampling_tensors.min_ps)
@@ -359,6 +371,118 @@ def _get_bin_counts_and_mask(
     return bin_counts, mask
 
 
+class ApplyToppTopkScalar:
+    """
+    The original implementation of _apply_top_k_top_p is more general,
+    as it uses per-request (vector) topp and topk.
+    However, in many cases topp and topk are the same for all batch elements.
+    For such "scalar" topp, topk cases, we can use this class.
+
+    The main optimization in this class is:
+    Use topk instead of sort, which is much faster, especially for small k.
+    However, just using topk might not suffice in cases such as the one below.
+    Consider a tensor: 9 9 8 8 8 8 7 7 7
+    Topk, with k=5, on this yields 9 9 8 8 8
+    The value "8" is on the boundary, hence the last "8" gets snipped off.
+    However, the original implementation accepts all the "8"s,
+    so it should output:
+    9 9 8 8 8 8 (6 values, even though k=5)
+    To ensure these semantics, we perform topk with _padded_k elements.
+    If we find more boundary elements left over,
+    then we keep incrementing _padded_k
+    and in future calls use the expanded value of _padded_k.
+
+    The increments to _padded_k should be done
+    with a value > 1 to prevent excessive recompilations
+    due to dynamic shapes (the output shape of the topk).
+
+    The main logic of this is in __call__.
+    This is a class instead of a function, just to keep track of
+    the monotonic non-decreasing state _padded_k.
+
+    To keep duplicates that lie outside the kth border,
+    set VLLM_HANDLE_TOPK_DUPLICATES to 1 or true.
+    """
+    _padded_k = 0
+    _handle_duplicates = os.getenv('VLLM_HANDLE_TOPK_DUPLICATES',
+                                   '0').lower() in ['1', 'true']
+
+    def __init__(self, increment: int):
+        self._increment = increment
+
+    def __call__(self, logits: torch.Tensor, p: float, k: int):
+        if k == 1 and not ApplyToppTopkScalar._handle_duplicates:
+            new_logits = torch.full(logits.shape,
+                                    -float("inf"),
+                                    device=logits.device)
+            vals, idx = torch.max(logits, keepdim=True, dim=1)
+            new_logits.scatter_(1, idx, vals.to(new_logits.dtype))
+            return new_logits
+
+        if k > ApplyToppTopkScalar._padded_k:
+            ApplyToppTopkScalar._padded_k = min(k + self._increment,
+                                                logits.shape[1])
+
+        vals, idx = torch.topk(logits,
+                               k=ApplyToppTopkScalar._padded_k,
+                               dim=1,
+                               sorted=True)
+
+        # this "if" checks whether we have bucketed so much that
+        # we have padded k up to the shape of logits
+        if self._handle_duplicates and \
+                ApplyToppTopkScalar._padded_k != logits.shape[1]:
+            smallest_of_top_k = vals[:, k - 1]
+            num_duplicates_of_smallest_of_topk = torch.sum(
+                logits == smallest_of_top_k.unsqueeze(1), 1)
+            max_num_duplicates_of_smallest_of_topk = torch.max(
+                num_duplicates_of_smallest_of_topk).item()
+
+            # there are n repeats of the border value
+            # (border meaning the smallest value of the top k).
+            # we do not know whether 1, 2, or (n-1)
+            # of them lie outside the kth border,
+            # so we conservatively increase by n-1
+            # when num_duplicates > _padded_k - k
+            if max_num_duplicates_of_smallest_of_topk - 1 > (
+                    ApplyToppTopkScalar._padded_k - k):
+                incr = int(
+                    math.ceil((max_num_duplicates_of_smallest_of_topk - 1) /
+                              self._increment) * self._increment)
+                # this branch is entered at most once per call,
+                # because we don't increment by self._increment and retry;
+                # instead we compute incr in one go
+                ApplyToppTopkScalar._padded_k = min(
+                    ApplyToppTopkScalar._padded_k + incr, logits.shape[1])
+
+                # recompute topk with expanded padded_k
+                vals, idx = torch.topk(logits,
+                                       k=ApplyToppTopkScalar._padded_k,
+                                       dim=1,
+                                       sorted=True)
+
+        idx = torch.fliplr(idx)
+        vals = torch.fliplr(vals)
+
+        top_k_smallest_val_idx = vals.size(1) - k
+        top_k_mask = vals[:, top_k_smallest_val_idx].unsqueeze(1)
+        top_k_mask = vals < top_k_mask
+        vals.masked_fill_(top_k_mask, -float("inf"))
+
+        probs_sort = vals.softmax(dim=-1)
+        probs_sum = probs_sort.cumsum(dim=-1)
+        top_p_mask = probs_sum <= (1 - p)
+        top_p_mask[:, -1] = False
+        vals.masked_fill_(top_p_mask, -float("inf"))
+
+        new_logits = torch.full(logits.shape,
+                                -float("inf"),
+                                device=logits.device)
+        new_logits.scatter_(1, idx, vals.to(new_logits.dtype))
+
+        return new_logits
+
+
 def _apply_min_tokens_penalty(
     logits: torch.Tensor,
     sampling_metadata: SamplingMetadata,
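
For reference, the self-contained sketch below is not part of the patch. It illustrates, in plain PyTorch, the core scalar top-k + top-p masking that ApplyToppTopkScalar.__call__ performs in the simplest case, where the topk width is exactly k and duplicate handling is disabled; the helper name scalar_top_k_top_p and the __main__ demo are illustrative assumptions, not code from the repository.

import torch


def scalar_top_k_top_p(logits: torch.Tensor, p: float, k: int) -> torch.Tensor:
    """Simplified scalar top-k/top-p filter: keep the k largest logits per
    row, then mask the low-probability tail so the kept mass is >= p."""
    # topk instead of a full sort: cheaper when k is much smaller than vocab
    vals, idx = torch.topk(logits, k=k, dim=1, sorted=True)

    # flip to ascending order, mirroring the masking order in __call__
    vals = torch.fliplr(vals)
    idx = torch.fliplr(idx)

    # top-p: drop the prefix whose cumulative probability stays within 1 - p,
    # but always keep the largest value (last column after the flip)
    probs = vals.softmax(dim=-1)
    cumprobs = probs.cumsum(dim=-1)
    top_p_mask = cumprobs <= (1 - p)
    top_p_mask[:, -1] = False
    vals = vals.masked_fill(top_p_mask, -float("inf"))

    # scatter the surviving values back into a full-width tensor of -inf
    out = torch.full(logits.shape, -float("inf"), device=logits.device)
    out.scatter_(1, idx, vals.to(out.dtype))
    return out


if __name__ == "__main__":
    x = torch.tensor([[9.0, 9.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0]])
    print(scalar_top_k_top_p(x, p=0.9, k=5))
    # only five values survive (the two 9s and three of the four 8s), so one
    # duplicate of the boundary value "8" is cut off. ApplyToppTopkScalar in
    # the patch widens its topk to _padded_k so such boundary duplicates can
    # be kept when VLLM_HANDLE_TOPK_DUPLICATES is enabled.

As the class docstring explains, the restored implementation also grows the topk width in steps of the configured increment rather than by 1, so the topk output shape changes rarely and dynamic-shape recompilations stay bounded.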