Skip to content

Commit a50b0ea

Browse files
committed
optimize logit_bias
Signed-off-by: Xu Song <xusong.vip@gmail.com>
1 parent eb39b11 commit a50b0ea

File tree

1 file changed

+8
-9
lines changed

(duplicate file-change summary removed — identical to the statistics listed above: 1 file changed, +8 −9 lines)

vllm/entrypoints/openai/logits_processors.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,25 +63,24 @@ def get_logits_processors(
6363
try:
6464
# Convert token_id to integer
6565
# Clamp the bias between -100 and 100 per OpenAI API spec
66-
clamped_logit_bias: Dict[int, float] = {
67-
int(token_id): min(100.0, max(-100.0, bias))
68-
for token_id, bias in logit_bias.items()
69-
}
66+
logit_bias_index = [int(token_id) for token_id in logit_bias]
67+
logit_bias_value = [
68+
min(100.0, max(-100.0, bias)) for bias in logit_bias.values()
69+
]
7070
except ValueError as exc:
7171
raise ValueError(
7272
"Found token_id in logit_bias that is not "
7373
"an integer or string representing an integer") from exc
7474

7575
# Check if token_id is within the vocab size
76-
for token_id, bias in clamped_logit_bias.items():
76+
for token_id in logit_bias_index:
7777
if token_id < 0 or token_id >= len(tokenizer):
7878
raise ValueError(f"token_id {token_id} in logit_bias contains "
7979
"out-of-vocab token id")
8080

81-
clamped_logit_bias = {
82-
"index": torch.tensor(list(clamped_logit_bias.keys())),
83-
"value": torch.tensor(list(clamped_logit_bias.values()),
84-
dtype=dtype)
81+
clamped_logit_bias: Dict[str, torch.Tensor] = {
82+
"index": torch.tensor(logit_bias_index),
83+
"value": torch.tensor(logit_bias_value, dtype=dtype)
8584
}
8685
logits_processors.append(
8786
partial(logit_bias_logits_processor, clamped_logit_bias))

0 commit comments

Comments
 (0)