7
7
from typing import Any , Optional
8
8
from unittest .mock import MagicMock
9
9
10
+ import pytest
11
+
10
12
from vllm .config import MultiModalConfig
11
13
from vllm .engine .multiprocessing .client import MQLLMEngineClient
12
14
from vllm .entrypoints .openai .protocol import ChatCompletionRequest
@@ -73,7 +75,8 @@ def test_async_serving_chat_init():
73
75
assert serving_completion .chat_template == CHAT_TEMPLATE
74
76
75
77
76
- def test_serving_chat_should_set_correct_max_tokens ():
78
+ @pytest .mark .asyncio
79
+ async def test_serving_chat_should_set_correct_max_tokens ():
77
80
mock_engine = MagicMock (spec = MQLLMEngineClient )
78
81
mock_engine .get_tokenizer .return_value = get_tokenizer (MODEL_NAME )
79
82
mock_engine .errored = False
@@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
88
91
chat_template = CHAT_TEMPLATE ,
89
92
chat_template_content_format = "auto" ,
90
93
request_logger = None )
94
+
91
95
req = ChatCompletionRequest (
92
96
model = MODEL_NAME ,
93
97
messages = [{
@@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
98
102
)
99
103
100
104
with suppress (Exception ):
101
- asyncio . run ( serving_chat .create_chat_completion (req ) )
105
+ await serving_chat .create_chat_completion (req )
102
106
103
107
assert mock_engine .generate .call_args .args [1 ].max_tokens == 93
104
108
105
109
req .max_tokens = 10
106
110
with suppress (Exception ):
107
- asyncio . run ( serving_chat .create_chat_completion (req ) )
111
+ await serving_chat .create_chat_completion (req )
108
112
109
113
assert mock_engine .generate .call_args .args [1 ].max_tokens == 10
110
114
@@ -143,23 +147,23 @@ def test_serving_chat_should_set_correct_max_tokens():
143
147
)
144
148
145
149
with suppress (Exception ):
146
- asyncio . run ( serving_chat .create_chat_completion (req ) )
150
+ await serving_chat .create_chat_completion (req )
147
151
148
152
assert mock_engine .generate .call_args .args [1 ].max_tokens == 10
149
153
150
154
# Test Case 2: Request's max_tokens set higher than server accepts
151
155
req .max_tokens = 15
152
156
153
157
with suppress (Exception ):
154
- asyncio . run ( serving_chat .create_chat_completion (req ) )
158
+ await serving_chat .create_chat_completion (req )
155
159
156
160
assert mock_engine .generate .call_args .args [1 ].max_tokens == 10
157
161
158
162
# Test Case 3: Request's max_tokens set lower than server accepts
159
163
req .max_tokens = 5
160
164
161
165
with suppress (Exception ):
162
- asyncio . run ( serving_chat .create_chat_completion (req ) )
166
+ await serving_chat .create_chat_completion (req )
163
167
164
168
assert mock_engine .generate .call_args .args [1 ].max_tokens == 5
165
169
@@ -198,28 +202,29 @@ def test_serving_chat_should_set_correct_max_tokens():
198
202
)
199
203
200
204
with suppress (Exception ):
201
- asyncio . run ( serving_chat .create_chat_completion (req ) )
205
+ await serving_chat .create_chat_completion (req )
202
206
203
207
assert mock_engine .generate .call_args .args [1 ].max_tokens == 93
204
208
205
209
# Test Case 2: Request's max_tokens set higher than server accepts
206
210
req .max_tokens = 100
207
211
208
212
with suppress (Exception ):
209
- asyncio . run ( serving_chat .create_chat_completion (req ) )
213
+ await serving_chat .create_chat_completion (req )
210
214
211
215
assert mock_engine .generate .call_args .args [1 ].max_tokens == 93
212
216
213
217
# Test Case 3: Request's max_tokens set lower than server accepts
214
218
req .max_tokens = 5
215
219
216
220
with suppress (Exception ):
217
- asyncio . run ( serving_chat .create_chat_completion (req ) )
221
+ await serving_chat .create_chat_completion (req )
218
222
219
223
assert mock_engine .generate .call_args .args [1 ].max_tokens == 5
220
224
221
225
222
- def test_serving_chat_could_load_correct_generation_config ():
226
+ @pytest .mark .asyncio
227
+ async def test_serving_chat_could_load_correct_generation_config ():
223
228
224
229
mock_model_config = MockModelConfig ()
225
230
mock_model_config .diff_sampling_param = {
@@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
242
247
chat_template = CHAT_TEMPLATE ,
243
248
chat_template_content_format = "auto" ,
244
249
request_logger = None )
250
+
245
251
req = ChatCompletionRequest (
246
252
model = MODEL_NAME ,
247
253
messages = [{
@@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
252
258
)
253
259
254
260
with suppress (Exception ):
255
- asyncio . run ( serving_chat .create_chat_completion (req ) )
261
+ await serving_chat .create_chat_completion (req )
256
262
257
263
assert mock_engine .generate .call_args .args [1 ].temperature == 0.5
258
264
assert mock_engine .generate .call_args .args [1 ].repetition_penalty == 1.05
@@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
261
267
req .temperature = 0.1
262
268
263
269
with suppress (Exception ):
264
- asyncio . run ( serving_chat .create_chat_completion (req ) )
270
+ await serving_chat .create_chat_completion (req )
265
271
266
272
assert mock_engine .generate .call_args .args [1 ].temperature == 0.1
267
273
assert mock_engine .generate .call_args .args [1 ].repetition_penalty == 1.05
@@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config():
270
276
req .temperature = 0.0
271
277
272
278
with suppress (Exception ):
273
- asyncio . run ( serving_chat .create_chat_completion (req ) )
279
+ await serving_chat .create_chat_completion (req )
274
280
275
281
assert mock_engine .generate .call_args .args [1 ].temperature == 0.0
276
282
assert mock_engine .generate .call_args .args [1 ].repetition_penalty == 1.05
277
283
278
284
279
- def test_serving_chat_did_set_correct_cache_salt ():
285
+ @pytest .mark .asyncio
286
+ async def test_serving_chat_did_set_correct_cache_salt ():
280
287
mock_model_config = MockModelConfig ()
281
288
282
289
mock_engine = MagicMock (spec = MQLLMEngineClient )
@@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt():
306
313
307
314
# By default cache_salt in the engine prompt is not set
308
315
with suppress (Exception ):
309
- asyncio . run ( serving_chat .create_chat_completion (req ) )
316
+ await serving_chat .create_chat_completion (req )
310
317
assert "cache_salt" not in mock_engine .generate .call_args .args [0 ]
311
318
312
319
# Test with certain cache_salt
313
320
req .cache_salt = "test_salt"
314
321
with suppress (Exception ):
315
- asyncio . run ( serving_chat .create_chat_completion (req ) )
322
+ await serving_chat .create_chat_completion (req )
316
323
assert mock_engine .generate .call_args .args [0 ]["cache_salt" ] == "test_salt"
0 commit comments