 import torch
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
+                                              init_builtin_logitsprocs)
+from vllm.v1.sample.metadata import SamplingMetadata
 
-from vllm_spyre.v1.sample.metadata import SamplingMetadata
 from vllm_spyre.v1.worker.spyre_input_batch import (CachedRequestState,
                                                      InputBatch)
 
@@ -56,20 +58,27 @@ def _construct_expected_sampling_metadata(
     repetition_penalties = [1.0 for _ in range(num_reqs)]
     top_k = [0 for _ in range(num_reqs)]
     top_p = [0.0 for _ in range(num_reqs)]
-    min_p = [0.0 for _ in range(num_reqs)]
     temperature = [0.0 for _ in range(num_reqs)]
-    min_tokens = {}
-    logit_bias = [None] * num_reqs
     allowed_token_ids_mask = torch.zeros(num_reqs,
                                          VOCAB_SIZE,
                                          dtype=torch.bool,
                                          device=device)
 
+    batch_update_builder = BatchUpdateBuilder()
+    logitsprocs = init_builtin_logitsprocs(pin_memory_available=False,
+                                           max_num_reqs=len(reqs) + 1,
+                                           device=device)
+
     bad_words_token_ids = {}
     for req in reqs:
         if req.req_id not in req_ids_retained:
             continue
         index_in_input_batch = input_batch.req_id_to_dense_index(req.req_id)
+
+        params = req.sampling_params
+        batch_update_builder.added.append(
+            (index_in_input_batch, params, req.output_token_ids))
+
         output_token_ids[index_in_input_batch] = req.output_token_ids
         prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
         presence_penalties[
@@ -80,19 +89,18 @@ def _construct_expected_sampling_metadata(
             req.sampling_params.repetition_penalty)
         top_k[index_in_input_batch] = req.sampling_params.top_k
         top_p[index_in_input_batch] = req.sampling_params.top_p
-        min_p[index_in_input_batch] = req.sampling_params.min_p
         temperature[index_in_input_batch] = req.sampling_params.temperature
-        min_tokens[index_in_input_batch] = (
-            req.sampling_params.min_tokens,
-            req.sampling_params.all_stop_token_ids)
-        logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
         if req.sampling_params.allowed_token_ids:
             allowed_token_ids_mask[index_in_input_batch][
                 req.sampling_params.allowed_token_ids] = True
         if req.sampling_params.bad_words_token_ids:
             bad_words_token_ids[
                 index_in_input_batch] = req.sampling_params.bad_words_token_ids
 
+    batch_update = batch_update_builder.get_and_reset(num_reqs)
+    for logit_proc in logitsprocs.all:
+        logit_proc.update_state(batch_update)
+
     return SamplingMetadata(
         temperature=torch.tensor(temperature, dtype=torch.float,
                                  device=device),
@@ -102,8 +110,6 @@ def _construct_expected_sampling_metadata(
             top_p, dtype=torch.float, device=device),
         top_k=None if all(x == 0 for x in top_k) else torch.tensor(
             top_k, dtype=torch.int, device=device),
-        min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
-            min_p, dtype=torch.float, device=device),
         generators={},
         max_num_logprobs=0,
         prompt_token_ids=make_tensor_with_pad(
@@ -122,13 +128,12 @@ def _construct_expected_sampling_metadata(
             dtype=torch.float,
             device=device),
         output_token_ids=output_token_ids,
-        min_tokens=min_tokens,
         no_penalties=(all(x == 0 for x in presence_penalties)
                       and all(x == 0 for x in frequency_penalties)
                       and all(x == 1 for x in repetition_penalties)),
-        logit_bias=logit_bias,
         allowed_token_ids_mask=allowed_token_ids_mask,
         bad_words_token_ids=bad_words_token_ids,
+        logitsprocs=logitsprocs,
     )
 
 
@@ -196,10 +201,8 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
                 sampling_metadata.prompt_token_ids)
     assert (expected_sampling_metadata.output_token_ids ==
             sampling_metadata.output_token_ids)
-    assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
     assert expected_sampling_metadata.no_penalties == \
         sampling_metadata.no_penalties
-    assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
     if sampling_metadata.allowed_token_ids_mask:
         assert torch.allclose(
             expected_sampling_metadata.allowed_token_ids_mask,
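
A minimal usage sketch (not part of the diff) of the logits-processor path this change migrates to: it calls BatchUpdateBuilder, init_builtin_logitsprocs, get_and_reset and update_state the same way the test above does, while the SamplingParams values, the vocabulary and batch sizes, and the final apply() call on each processor are illustrative assumptions about vllm's v1 LogitsProcessor interface rather than code from this PR.

# Sketch only; mirrors the diff's use of BatchUpdateBuilder and
# init_builtin_logitsprocs with made-up request data (see note above).
import torch

from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
                                             init_builtin_logitsprocs)

VOCAB_SIZE = 1024          # illustrative vocabulary size
device = torch.device("cpu")

# Built-in processors (min-p, logit-bias, min-tokens) are created up front
# and later receive per-request sampling params through batch updates.
logitsprocs = init_builtin_logitsprocs(pin_memory_available=False,
                                       max_num_reqs=8,
                                       device=device)

# min_p / logit_bias now travel only inside SamplingParams; the processors
# read them from the batch update instead of dedicated SamplingMetadata fields.
params = SamplingParams(min_p=0.1, logit_bias={0: -5.0})
builder = BatchUpdateBuilder()
builder.added.append((0, params, []))  # (dense batch index, params, output token ids)

batch_update = builder.get_and_reset(1)  # one request in the dense batch
for proc in logitsprocs.all:
    proc.update_state(batch_update)

# Assumed LogitsProcessor.apply(logits) interface: each processor edits the
# per-request logits in turn.
logits = torch.randn(1, VOCAB_SIZE, device=device)
for proc in logitsprocs.all:
    logits = proc.apply(logits)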