@@ -5,6 +5,7 @@
 import torch
 import torch.nn as nn

+from vllm.utils import async_tensor_h2d, is_pin_memory_available
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.bad_words import apply_bad_words
@@ -20,6 +21,7 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.topk_topp_sampler = TopKTopPSampler()
+        self.pin_memory = is_pin_memory_available()

     def forward(
         self,
@@ -232,6 +234,10 @@ def apply_logits_bias(
         # One idea is implement this as a PyTorch C++ op, and we may
         # even optimize the logit_bias layout.

+        rows: list[int] = []
+        cols: list[int] = []
+        vals: list[float] = []
+
         # Get vocabulary size from logits
         vocab_size = logits.shape[-1]

@@ -244,7 +250,16 @@ def apply_logits_bias(
                             f"token_id {token_id} in logit_bias contains "
                             f"out-of-vocab token id. Vocabulary size: "
                             f"{vocab_size}")
-                    logits[i, token_id] += bias
+                    rows.append(i)
+                    cols.append(token_id)
+                    vals.append(bias)
+
+        if rows:
+            indices = async_tensor_h2d([rows, cols], torch.int64,
+                                       logits.device, self.pin_memory)
+            values = async_tensor_h2d(vals, torch.float, logits.device,
+                                      self.pin_memory)
+            logits.index_put_(tuple(indices), values=values, accumulate=True)
         return logits

     def apply_allowed_token_ids(
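The key change is in the final hunk: rather than applying each bias with `logits[i, token_id] += bias` (one tiny indexing kernel launch per bias entry), the Python loop now only collects (row, column, value) triples on the host, and a single `index_put_` with `accumulate=True` performs the whole scatter-add in one batched op. A minimal, self-contained sketch of that technique; the shapes and bias values here are made up for illustration:

import torch

# Toy stand-ins: a batch of 4 sequences over a 10-token vocabulary.
logits = torch.zeros(4, 10)

# Bias triples gathered on the host, as in the patched loop.
rows = [0, 0, 3]          # which sequence each bias applies to
cols = [2, 7, 2]          # which token id it biases
vals = [1.5, -0.5, 2.0]   # the bias values

# Reference result: the old per-element loop.
expected = logits.clone()
for r, c, v in zip(rows, cols, vals):
    expected[r, c] += v

# Batched equivalent: one scatter-add over all (row, col) pairs.
# accumulate=True makes repeated indices add instead of overwrite.
indices = (torch.tensor(rows), torch.tensor(cols))
logits.index_put_(indices, torch.tensor(vals), accumulate=True)

assert torch.allclose(logits, expected)

With `accumulate=False` (the default), a duplicate (row, col) pair would keep only one of the writes, so the flag is what preserves the `+=` semantics of the original loop.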
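The new imports cover the other half of the change. `is_pin_memory_available()` is cached once in `__init__` because pinned (page-locked) host memory is not supported on every platform, and `async_tensor_h2d` moves the gathered lists to the GPU without blocking the CPU. Roughly, such a helper stages the data in a (possibly pinned) CPU tensor and issues a non-blocking copy; the following is a hedged sketch of that pattern, not necessarily the verbatim `vllm.utils` implementation:

import torch

def async_tensor_h2d_sketch(data: list, dtype: torch.dtype,
                            target_device: torch.device,
                            pin_memory: bool) -> torch.Tensor:
    """Create a host tensor and copy it to the device asynchronously.

    Pinned memory lets CUDA DMA the buffer directly, and non_blocking=True
    returns before the copy completes. The copy is still ordered ahead of
    later kernels on the same stream, so the index_put_ that consumes the
    result sees the finished transfer.
    """
    staging = torch.tensor(data, dtype=dtype,
                           pin_memory=pin_memory, device="cpu")
    return staging.to(device=target_device, non_blocking=True)

Note that `non_blocking=True` only overlaps the host-to-device transfer with subsequent CPU work when the source tensor is pinned; with an unpinned buffer the copy is effectively synchronous, which is why the availability check is worth caching.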