diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index f842581bf551..cfbcc8a07c09 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2103,7 +2103,8 @@ def _build_logits_processors(
             processors = get_openai_logits_processors(
                 logit_bias=sampling_params.logit_bias,
                 allowed_token_ids=sampling_params.allowed_token_ids,
-                tokenizer=tokenizer)
+                tokenizer=tokenizer,
+                dtype=self.model_config.dtype)
             logits_processors.extend(processors)
 
             # Unset so these don't get passed down to the model
diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py
index 04d5091a9681..d8455ec6c21c 100644
--- a/vllm/entrypoints/openai/logits_processors.py
+++ b/vllm/entrypoints/openai/logits_processors.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections.abc import Iterable
-from functools import lru_cache, partial
+from functools import lru_cache
 from typing import Optional, Union
 
 import torch
@@ -43,43 +43,75 @@ def _get_allowed_token_ids_logits_processor(
     return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
 
 
-def logit_bias_logits_processor(
-    logit_bias: dict[int, float],
-    token_ids: list[int],
-    logits: torch.Tensor,
-) -> torch.Tensor:
-    for token_id, bias in logit_bias.items():
-        logits[token_id] += bias
-    return logits
+class LogitBiasLogitsProcessor:
+    """Logits processor for applying biases to logits.
+    It lets you control whether the model is more or less likely to
+    generate a specific token.
+    """
+
+    def __init__(self, logit_bias_index: list[int],
+                 logit_bias_value: list[float], dtype: Union[str,
+                                                             torch.dtype]):
+        self.logit_bias_index: torch.Tensor = torch.tensor(logit_bias_index)
+        self.logit_bias_value: torch.Tensor = torch.tensor(logit_bias_value,
+                                                           dtype=dtype)
+
+    def __call__(
+        self,
+        token_ids: list[int],
+        logits: torch.Tensor,
+    ) -> torch.Tensor:
+        if self.logit_bias_value.device != logits.device:
+            self.logit_bias_index = self.logit_bias_index.to(logits.device)
+            self.logit_bias_value = self.logit_bias_value.to(logits.device)
+        logits.index_add_(0, self.logit_bias_index, self.logit_bias_value)
+        return logits
+
+
+@lru_cache(maxsize=32)
+def _get_logit_bias_logits_processor(
+    logit_bias_index: Union[tuple[int], tuple[str]],
+    logit_bias_value: tuple[float],
+    vocab_size: int,
+    dtype: Union[str, torch.dtype],
+) -> LogitsProcessor:
+    try:
+        # Convert token_id to integer
+        # Clamp the bias between -100 and 100 per OpenAI API spec
+        clamped_logit_bias_index: list[int] = [
+            int(token_id) for token_id in logit_bias_index
+        ]
+        clamped_logit_bias_value: list[float] = [
+            min(100.0, max(-100.0, bias)) for bias in logit_bias_value
+        ]
+    except ValueError as exc:
+        raise ValueError(
+            "Found token_id in logit_bias that is not "
+            "an integer or string representing an integer") from exc
+
+    # Check if token_id is within the vocab size
+    for token_id in clamped_logit_bias_index:
+        if token_id < 0 or token_id >= vocab_size:
+            raise ValueError(f"token_id {token_id} in logit_bias contains "
+                             "out-of-vocab token id")
+
+    return LogitBiasLogitsProcessor(clamped_logit_bias_index,
+                                    clamped_logit_bias_value,
+                                    dtype=dtype)
 
 
 def get_logits_processors(
     logit_bias: Optional[Union[dict[int, float], dict[str, float]]],
     allowed_token_ids: Optional[list[int]],
     tokenizer: AnyTokenizer,
+    dtype: Union[str, torch.dtype],
 ) -> list[LogitsProcessor]:
     logits_processors: list[LogitsProcessor] = []
     if logit_bias:
-        try:
-            # Convert token_id to integer
-            # Clamp the bias between -100 and 100 per OpenAI API spec
-            clamped_logit_bias: dict[int, float] = {
-                int(token_id): min(100.0, max(-100.0, bias))
-                for token_id, bias in logit_bias.items()
-            }
-        except ValueError as exc:
-            raise ValueError(
-                "Found token_id in logit_bias that is not "
-                "an integer or string representing an integer") from exc
-
-        # Check if token_id is within the vocab size
-        for token_id, bias in clamped_logit_bias.items():
-            if token_id < 0 or token_id >= len(tokenizer):
-                raise ValueError(f"token_id {token_id} in logit_bias contains "
-                                 "out-of-vocab token id")
-
         logits_processors.append(
-            partial(logit_bias_logits_processor, clamped_logit_bias))
+            _get_logit_bias_logits_processor(tuple(logit_bias.keys()),
+                                             tuple(logit_bias.values()),
+                                             len(tokenizer), dtype))
 
     if allowed_token_ids is not None:
         logits_processors.append(