Skip to content

Commit eb39b11

Browse files
committed
optimize logit_bias
Signed-off-by: Xu Song <xusong.vip@gmail.com>
1 parent 9206b3d commit eb39b11

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

vllm/engine/llm_engine.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2003,7 +2003,8 @@ def _build_logits_processors(
20032003
processors = get_openai_logits_processors(
20042004
logit_bias=sampling_params.logit_bias,
20052005
allowed_token_ids=sampling_params.allowed_token_ids,
2006-
tokenizer=tokenizer)
2006+
tokenizer=tokenizer,
2007+
dtype=self.model_config.dtype)
20072008
logits_processors.extend(processors)
20082009

20092010
# Unset so these don't get passed down to the model

vllm/entrypoints/openai/logits_processors.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,19 +43,20 @@ def _get_allowed_token_ids_logits_processor(
4343

4444

4545
def logit_bias_logits_processor(
46-
logit_bias: Dict[int, float],
46+
logit_bias: Dict[str, torch.Tensor],
4747
token_ids: List[int],
4848
logits: torch.Tensor,
4949
) -> torch.Tensor:
50-
for token_id, bias in logit_bias.items():
51-
logits[token_id] += bias
50+
logits.index_add_(0, logit_bias["index"].to(logits.device),
51+
logit_bias["value"].to(logits.device))
5252
return logits
5353

5454

5555
def get_logits_processors(
5656
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
5757
allowed_token_ids: Optional[List[int]],
5858
tokenizer: AnyTokenizer,
59+
dtype: Union[str, torch.dtype],
5960
) -> List[LogitsProcessor]:
6061
logits_processors: List[LogitsProcessor] = []
6162
if logit_bias:
@@ -77,6 +78,11 @@ def get_logits_processors(
7778
raise ValueError(f"token_id {token_id} in logit_bias contains "
7879
"out-of-vocab token id")
7980

81+
clamped_logit_bias = {
82+
"index": torch.tensor(list(clamped_logit_bias.keys())),
83+
"value": torch.tensor(list(clamped_logit_bias.values()),
84+
dtype=dtype)
85+
}
8086
logits_processors.append(
8187
partial(logit_bias_logits_processor, clamped_logit_bias))
8288

0 commit comments

Comments
 (0)