From eb39b117f61629c9ccabda3c24ac9860425c2ce7 Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Sat, 15 Feb 2025 21:24:01 +0800
Subject: [PATCH 1/6] optimize logit_bias

Signed-off-by: Xu Song
---
 vllm/engine/llm_engine.py                    |  3 ++-
 vllm/entrypoints/openai/logits_processors.py | 12 +++++++++---
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 2e5bc75c6db3..79a48c9d65a7 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2003,7 +2003,8 @@ def _build_logits_processors(
             processors = get_openai_logits_processors(
                 logit_bias=sampling_params.logit_bias,
                 allowed_token_ids=sampling_params.allowed_token_ids,
-                tokenizer=tokenizer)
+                tokenizer=tokenizer,
+                dtype=self.model_config.dtype)
             logits_processors.extend(processors)

             # Unset so these don't get passed down to the model
diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py
index 41e5eef40eaf..893fafedbd29 100644
--- a/vllm/entrypoints/openai/logits_processors.py
+++ b/vllm/entrypoints/openai/logits_processors.py
@@ -43,12 +43,12 @@ def _get_allowed_token_ids_logits_processor(


 def logit_bias_logits_processor(
-    logit_bias: Dict[int, float],
+    logit_bias: Dict[str, torch.Tensor],
     token_ids: List[int],
     logits: torch.Tensor,
 ) -> torch.Tensor:
-    for token_id, bias in logit_bias.items():
-        logits[token_id] += bias
+    logits.index_add_(0, logit_bias["index"].to(logits.device),
+                      logit_bias["value"].to(logits.device))
     return logits


@@ -56,6 +56,7 @@ def get_logits_processors(
     logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
     allowed_token_ids: Optional[List[int]],
     tokenizer: AnyTokenizer,
+    dtype: Union[str, torch.dtype],
 ) -> List[LogitsProcessor]:
     logits_processors: List[LogitsProcessor] = []
     if logit_bias:
@@ -77,6 +78,11 @@ def get_logits_processors(
             raise ValueError(f"token_id {token_id} in logit_bias contains "
                              "out-of-vocab token id")

+        clamped_logit_bias = {
+            "index": torch.tensor(list(clamped_logit_bias.keys())),
+            "value": torch.tensor(list(clamped_logit_bias.values()),
+                                  dtype=dtype)
+        }
         logits_processors.append(
             partial(logit_bias_logits_processor, clamped_logit_bias))

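Note on the core change in the patch above: the per-token Python loop over
logit_bias is replaced by a single vectorized index_add_ call, with the
index/value tensors built once per request instead of on every sampling
step. A minimal standalone sketch of the before/after (the vocab size of
32000 and the example biases are illustrative, not taken from the patch):

    import torch

    def apply_bias_loop(logits: torch.Tensor, bias: dict) -> torch.Tensor:
        # Old behavior: one Python-level update per biased token.
        for token_id, value in bias.items():
            logits[token_id] += value
        return logits

    def apply_bias_vectorized(logits: torch.Tensor, index: torch.Tensor,
                              value: torch.Tensor) -> torch.Tensor:
        # New behavior: all biases are applied by one index_add_ call.
        logits.index_add_(0, index.to(logits.device),
                          value.to(logits.device))
        return logits

    bias = {100: 5.0, 200: -5.0}
    index = torch.tensor(list(bias.keys()))
    value = torch.tensor(list(bias.values()))
    logits = torch.zeros(32000)
    assert torch.equal(apply_bias_loop(logits.clone(), bias),
                       apply_bias_vectorized(logits.clone(), index, value))

Both forms produce identical logits; the vectorized form keeps the per-step
Python overhead independent of how many tokens are biased.
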
From a50b0ead8084692522ce41c65ee7ada4002ed29d Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Sat, 15 Feb 2025 22:43:07 +0800
Subject: [PATCH 2/6] optimize logit_bias

Signed-off-by: Xu Song
---
 vllm/entrypoints/openai/logits_processors.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py
index 893fafedbd29..0afa1b5ff59a 100644
--- a/vllm/entrypoints/openai/logits_processors.py
+++ b/vllm/entrypoints/openai/logits_processors.py
@@ -63,25 +63,24 @@ def get_logits_processors(
         try:
             # Convert token_id to integer
             # Clamp the bias between -100 and 100 per OpenAI API spec
-            clamped_logit_bias: Dict[int, float] = {
-                int(token_id): min(100.0, max(-100.0, bias))
-                for token_id, bias in logit_bias.items()
-            }
+            logit_bias_index = [int(token_id) for token_id in logit_bias]
+            logit_bias_value = [
+                min(100.0, max(-100.0, bias)) for bias in logit_bias.values()
+            ]
         except ValueError as exc:
             raise ValueError(
                 "Found token_id in logit_bias that is not "
                 "an integer or string representing an integer") from exc

         # Check if token_id is within the vocab size
-        for token_id, bias in clamped_logit_bias.items():
+        for token_id in logit_bias_index:
             if token_id < 0 or token_id >= len(tokenizer):
                 raise ValueError(f"token_id {token_id} in logit_bias contains "
                                  "out-of-vocab token id")

-        clamped_logit_bias = {
-            "index": torch.tensor(list(clamped_logit_bias.keys())),
-            "value": torch.tensor(list(clamped_logit_bias.values()),
-                                  dtype=dtype)
+        clamped_logit_bias: Dict[str, torch.Tensor] = {
+            "index": torch.tensor(logit_bias_index),
+            "value": torch.tensor(logit_bias_value, dtype=dtype)
         }
         logits_processors.append(
             partial(logit_bias_logits_processor, clamped_logit_bias))

From 62a74e3d1bf3b9b34d17cd5fb46e36f0be756d89 Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Mon, 17 Feb 2025 10:36:08 +0800
Subject: [PATCH 3/6] avoid duplicated tensor copy

Signed-off-by: Xu Song
---
 vllm/entrypoints/openai/logits_processors.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py
index 0afa1b5ff59a..43a43b7e62aa 100644
--- a/vllm/entrypoints/openai/logits_processors.py
+++ b/vllm/entrypoints/openai/logits_processors.py
@@ -47,8 +47,10 @@ def logit_bias_logits_processor(
     token_ids: List[int],
     logits: torch.Tensor,
 ) -> torch.Tensor:
-    logits.index_add_(0, logit_bias["index"].to(logits.device),
-                      logit_bias["value"].to(logits.device))
+    if logit_bias["value"].device != logits.device:
+        logit_bias["index"] = logit_bias["index"].to(logits.device)
+        logit_bias["value"] = logit_bias["value"].to(logits.device)
+    logits.index_add_(0, logit_bias["index"], logit_bias["value"])
     return logits

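The device check added in the patch above means each request pays the
host-to-device copy at most once: after the first sampling step the bias
tensors already live on the logits device, and later steps skip the
transfer entirely. A standalone sketch of the pattern, using the same
dict-shaped state as the patch (the driver calls at the bottom are
illustrative):

    import torch

    def make_bias_processor(index, value, dtype=torch.float32):
        state = {
            "index": torch.tensor(index),
            "value": torch.tensor(value, dtype=dtype),
        }

        def processor(token_ids, logits):
            # Move the tensors only when the target device changes,
            # i.e. at most once for a typical CPU-to-GPU handoff.
            if state["value"].device != logits.device:
                state["index"] = state["index"].to(logits.device)
                state["value"] = state["value"].to(logits.device)
            return logits.index_add_(0, state["index"], state["value"])

        return processor

    proc = make_bias_processor([3, 7], [2.0, -2.0])
    out = proc([], torch.zeros(10))  # first call may pay a device copy
    out = proc([], out)              # later calls reuse the moved tensors
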
+ """ + + def __init__(self, logit_bias_index: List[int], + logit_bias_value: List[float], dtype: Union[str, + torch.dtype]): + self.logit_bias_index: torch.Tensor = torch.tensor(logit_bias_index) + self.logit_bias_value: torch.Tensor = torch.tensor(logit_bias_value, + dtype=dtype) + + def __call__( + self, + token_ids: List[int], + logits: torch.Tensor, + ) -> torch.Tensor: + if self.logit_bias_value.device != logits.device: + self.logit_bias_index = self.logit_bias_index.to(logits.device) + self.logit_bias_value = self.logit_bias_value.to(logits.device) + logits.index_add_(0, self.logit_bias_index, self.logit_bias_value) + return logits + + +@lru_cache(maxsize=32) +def _get_logit_bias_logits_processor( + logit_bias_index: Optional[Union[Tuple[int], Tuple[str]]], + logit_bias_value: Optional[Tuple[float]], + vocab_size: int, + dtype: Union[str, torch.dtype], +) -> LogitsProcessor: + try: + # Convert token_id to integer + # Clamp the bias between -100 and 100 per OpenAI API spec + logit_bias_index: List[int] = [ + int(token_id) for token_id in logit_bias_index + ] + logit_bias_value: List[float] = [ + min(100.0, max(-100.0, bias)) for bias in logit_bias_value + ] + except ValueError as exc: + raise ValueError( + "Found token_id in logit_bias that is not " + "an integer or string representing an integer") from exc + + # Check if token_id is within the vocab size + for token_id in logit_bias_index: + if token_id < 0 or token_id >= vocab_size: + raise ValueError(f"token_id {token_id} in logit_bias contains " + "out-of-vocab token id") + + return LogitBiasLogitsProcessor(logit_bias_index, + logit_bias_value, + dtype=dtype) def get_logits_processors( @@ -62,30 +107,10 @@ def get_logits_processors( ) -> List[LogitsProcessor]: logits_processors: List[LogitsProcessor] = [] if logit_bias: - try: - # Convert token_id to integer - # Clamp the bias between -100 and 100 per OpenAI API spec - logit_bias_index = [int(token_id) for token_id in logit_bias] - logit_bias_value = [ - min(100.0, max(-100.0, bias)) for bias in logit_bias.values() - ] - except ValueError as exc: - raise ValueError( - "Found token_id in logit_bias that is not " - "an integer or string representing an integer") from exc - - # Check if token_id is within the vocab size - for token_id in logit_bias_index: - if token_id < 0 or token_id >= len(tokenizer): - raise ValueError(f"token_id {token_id} in logit_bias contains " - "out-of-vocab token id") - - clamped_logit_bias: Dict[str, torch.Tensor] = { - "index": torch.tensor(logit_bias_index), - "value": torch.tensor(logit_bias_value, dtype=dtype) - } logits_processors.append( - partial(logit_bias_logits_processor, clamped_logit_bias)) + _get_logit_bias_logits_processor(tuple(logit_bias.keys()), + tuple(logit_bias.values()), + len(tokenizer), dtype)) if allowed_token_ids is not None: logits_processors.append( From fe6e1feeb5c5e1563ccc588429c6904a96b8f641 Mon Sep 17 00:00:00 2001 From: Xu Song Date: Mon, 17 Feb 2025 15:47:47 +0800 Subject: [PATCH 5/6] fix typing warning Signed-off-by: Xu Song --- vllm/entrypoints/openai/logits_processors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index a876dd8f571d..312b971cd3d8 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -69,8 +69,8 @@ def __call__( @lru_cache(maxsize=32) def _get_logit_bias_logits_processor( - logit_bias_index: Optional[Union[Tuple[int], Tuple[str]]], 
From fe6e1feeb5c5e1563ccc588429c6904a96b8f641 Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Mon, 17 Feb 2025 15:47:47 +0800
Subject: [PATCH 5/6] fix typing warning

Signed-off-by: Xu Song
---
 vllm/entrypoints/openai/logits_processors.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py
index a876dd8f571d..312b971cd3d8 100644
--- a/vllm/entrypoints/openai/logits_processors.py
+++ b/vllm/entrypoints/openai/logits_processors.py
@@ -69,8 +69,8 @@ def __call__(

 @lru_cache(maxsize=32)
 def _get_logit_bias_logits_processor(
-    logit_bias_index: Optional[Union[Tuple[int], Tuple[str]]],
-    logit_bias_value: Optional[Tuple[float]],
+    logit_bias_index: Union[Tuple[int], Tuple[str]],
+    logit_bias_value: Tuple[float],
     vocab_size: int,
     dtype: Union[str, torch.dtype],
 ) -> LogitsProcessor:

From cd9f33fa750fbeabfd8ee8526b5a52ae99580f28 Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Mon, 17 Feb 2025 16:02:26 +0800
Subject: [PATCH 6/6] fix typing warning

Signed-off-by: Xu Song
---
 vllm/entrypoints/openai/logits_processors.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py
index 312b971cd3d8..7629c8282c6c 100644
--- a/vllm/entrypoints/openai/logits_processors.py
+++ b/vllm/entrypoints/openai/logits_processors.py
@@ -77,10 +77,10 @@ def _get_logit_bias_logits_processor(
     try:
         # Convert token_id to integer
         # Clamp the bias between -100 and 100 per OpenAI API spec
-        logit_bias_index: List[int] = [
+        clamped_logit_bias_index: List[int] = [
             int(token_id) for token_id in logit_bias_index
         ]
-        logit_bias_value: List[float] = [
+        clamped_logit_bias_value: List[float] = [
             min(100.0, max(-100.0, bias)) for bias in logit_bias_value
         ]
     except ValueError as exc:
@@ -89,13 +89,13 @@ def _get_logit_bias_logits_processor(
             "an integer or string representing an integer") from exc

     # Check if token_id is within the vocab size
-    for token_id in logit_bias_index:
+    for token_id in clamped_logit_bias_index:
         if token_id < 0 or token_id >= vocab_size:
             raise ValueError(f"token_id {token_id} in logit_bias contains "
                              "out-of-vocab token id")

-    return LogitBiasLogitsProcessor(logit_bias_index,
-                                    logit_bias_value,
+    return LogitBiasLogitsProcessor(clamped_logit_bias_index,
+                                    clamped_logit_bias_value,
                                     dtype=dtype)
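For reference, the end state of the series condensed into one runnable
sketch (the vocab size of 32 and the example biases are illustrative, and
the integer-conversion error handling is trimmed to the out-of-vocab
check):

    from functools import lru_cache
    from typing import List, Tuple, Union

    import torch

    class LogitBiasLogitsProcessor:

        def __init__(self, index: List[int], value: List[float],
                     dtype: Union[str, torch.dtype]):
            self.index = torch.tensor(index)
            self.value = torch.tensor(value, dtype=dtype)

        def __call__(self, token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
            # Lazily move the bias tensors, then apply them in one call.
            if self.value.device != logits.device:
                self.index = self.index.to(logits.device)
                self.value = self.value.to(logits.device)
            return logits.index_add_(0, self.index, self.value)

    @lru_cache(maxsize=32)
    def _get_processor(index: Tuple[str, ...], value: Tuple[float, ...],
                       vocab_size: int, dtype: torch.dtype):
        ids = [int(t) for t in index]
        vals = [min(100.0, max(-100.0, v)) for v in value]  # OpenAI clamp
        if any(t < 0 or t >= vocab_size for t in ids):
            raise ValueError("out-of-vocab token id in logit_bias")
        return LogitBiasLogitsProcessor(ids, vals, dtype)

    # Two identical requests share one processor and one pair of tensors.
    p1 = _get_processor(("5", "9"), (150.0, -150.0), 32, torch.float32)
    p2 = _get_processor(("5", "9"), (150.0, -150.0), 32, torch.float32)
    assert p1 is p2
    print(p1([], torch.zeros(32))[5])  # tensor(100.) after clamping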