Skip to content

Commit 8267f99

Browse files
authored
improve logits bias (#19041)
1 parent 7353492 commit 8267f99

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

vllm/v1/sample/sampler.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
import torch.nn as nn
77

8+
from vllm.utils import async_tensor_h2d, is_pin_memory_available
89
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
910
from vllm.v1.sample.metadata import SamplingMetadata
1011
from vllm.v1.sample.ops.bad_words import apply_bad_words
@@ -20,6 +21,7 @@ class Sampler(nn.Module):
2021
def __init__(self):
    """Initialize the sampler module.

    Sets up the top-k/top-p sampling helper and records, once at
    construction time, whether pinned host memory is available so that
    later host-to-device copies (e.g. the logit-bias tensors) can use it.
    """
    super().__init__()
    # Cache pin-memory availability up front; checking per-call would be
    # wasted work on the sampling hot path.
    self.pin_memory = is_pin_memory_available()
    # Helper implementing combined top-k / top-p token sampling.
    self.topk_topp_sampler = TopKTopPSampler()
2325

2426
def forward(
2527
self,
@@ -232,6 +234,10 @@ def apply_logits_bias(
232234
# One idea is implement this as a PyTorch C++ op, and we may
233235
# even optimize the logit_bias layout.
234236

237+
rows: list[int] = []
238+
cols: list[int] = []
239+
vals: list[float] = []
240+
235241
# Get vocabulary size from logits
236242
vocab_size = logits.shape[-1]
237243

@@ -244,7 +250,16 @@ def apply_logits_bias(
244250
f"token_id {token_id} in logit_bias contains "
245251
f"out-of-vocab token id. Vocabulary size: "
246252
f"{vocab_size}")
247-
logits[i, token_id] += bias
253+
rows.append(i)
254+
cols.append(token_id)
255+
vals.append(bias)
256+
257+
if rows:
258+
indices = async_tensor_h2d([rows, cols], torch.int64,
259+
logits.device, self.pin_memory)
260+
values = async_tensor_h2d(vals, torch.float, logits.device,
261+
self.pin_memory)
262+
logits.index_put_(tuple(indices), values=values, accumulate=True)
248263
return logits
249264

250265
def apply_allowed_token_ids(

0 commit comments

Comments (0)