Skip to content

Commit a37d75b

Browse files
authored
[Front-end] microbatch tokenization (#19334)
Signed-off-by: zt2370 <ztang2370@gmail.com>
1 parent edd270b commit a37d75b

File tree

3 files changed

+288
-64
lines changed

3 files changed

+288
-64
lines changed

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from typing import Any, Optional
88
from unittest.mock import MagicMock
99

10+
import pytest
11+
1012
from vllm.config import MultiModalConfig
1113
from vllm.engine.multiprocessing.client import MQLLMEngineClient
1214
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
@@ -73,7 +75,8 @@ def test_async_serving_chat_init():
7375
assert serving_completion.chat_template == CHAT_TEMPLATE
7476

7577

76-
def test_serving_chat_should_set_correct_max_tokens():
78+
@pytest.mark.asyncio
79+
async def test_serving_chat_should_set_correct_max_tokens():
7780
mock_engine = MagicMock(spec=MQLLMEngineClient)
7881
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
7982
mock_engine.errored = False
@@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
8891
chat_template=CHAT_TEMPLATE,
8992
chat_template_content_format="auto",
9093
request_logger=None)
94+
9195
req = ChatCompletionRequest(
9296
model=MODEL_NAME,
9397
messages=[{
@@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
98102
)
99103

100104
with suppress(Exception):
101-
asyncio.run(serving_chat.create_chat_completion(req))
105+
await serving_chat.create_chat_completion(req)
102106

103107
assert mock_engine.generate.call_args.args[1].max_tokens == 93
104108

105109
req.max_tokens = 10
106110
with suppress(Exception):
107-
asyncio.run(serving_chat.create_chat_completion(req))
111+
await serving_chat.create_chat_completion(req)
108112

109113
assert mock_engine.generate.call_args.args[1].max_tokens == 10
110114

@@ -143,23 +147,23 @@ def test_serving_chat_should_set_correct_max_tokens():
143147
)
144148

145149
with suppress(Exception):
146-
asyncio.run(serving_chat.create_chat_completion(req))
150+
await serving_chat.create_chat_completion(req)
147151

148152
assert mock_engine.generate.call_args.args[1].max_tokens == 10
149153

150154
# Test Case 2: Request's max_tokens set higher than server accepts
151155
req.max_tokens = 15
152156

153157
with suppress(Exception):
154-
asyncio.run(serving_chat.create_chat_completion(req))
158+
await serving_chat.create_chat_completion(req)
155159

156160
assert mock_engine.generate.call_args.args[1].max_tokens == 10
157161

158162
# Test Case 3: Request's max_tokens set lower than server accepts
159163
req.max_tokens = 5
160164

161165
with suppress(Exception):
162-
asyncio.run(serving_chat.create_chat_completion(req))
166+
await serving_chat.create_chat_completion(req)
163167

164168
assert mock_engine.generate.call_args.args[1].max_tokens == 5
165169

@@ -198,28 +202,29 @@ def test_serving_chat_should_set_correct_max_tokens():
198202
)
199203

200204
with suppress(Exception):
201-
asyncio.run(serving_chat.create_chat_completion(req))
205+
await serving_chat.create_chat_completion(req)
202206

203207
assert mock_engine.generate.call_args.args[1].max_tokens == 93
204208

205209
# Test Case 2: Request's max_tokens set higher than server accepts
206210
req.max_tokens = 100
207211

208212
with suppress(Exception):
209-
asyncio.run(serving_chat.create_chat_completion(req))
213+
await serving_chat.create_chat_completion(req)
210214

211215
assert mock_engine.generate.call_args.args[1].max_tokens == 93
212216

213217
# Test Case 3: Request's max_tokens set lower than server accepts
214218
req.max_tokens = 5
215219

216220
with suppress(Exception):
217-
asyncio.run(serving_chat.create_chat_completion(req))
221+
await serving_chat.create_chat_completion(req)
218222

219223
assert mock_engine.generate.call_args.args[1].max_tokens == 5
220224

221225

222-
def test_serving_chat_could_load_correct_generation_config():
226+
@pytest.mark.asyncio
227+
async def test_serving_chat_could_load_correct_generation_config():
223228

224229
mock_model_config = MockModelConfig()
225230
mock_model_config.diff_sampling_param = {
@@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
242247
chat_template=CHAT_TEMPLATE,
243248
chat_template_content_format="auto",
244249
request_logger=None)
250+
245251
req = ChatCompletionRequest(
246252
model=MODEL_NAME,
247253
messages=[{
@@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
252258
)
253259

254260
with suppress(Exception):
255-
asyncio.run(serving_chat.create_chat_completion(req))
261+
await serving_chat.create_chat_completion(req)
256262

257263
assert mock_engine.generate.call_args.args[1].temperature == 0.5
258264
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
261267
req.temperature = 0.1
262268

263269
with suppress(Exception):
264-
asyncio.run(serving_chat.create_chat_completion(req))
270+
await serving_chat.create_chat_completion(req)
265271

266272
assert mock_engine.generate.call_args.args[1].temperature == 0.1
267273
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config():
270276
req.temperature = 0.0
271277

272278
with suppress(Exception):
273-
asyncio.run(serving_chat.create_chat_completion(req))
279+
await serving_chat.create_chat_completion(req)
274280

275281
assert mock_engine.generate.call_args.args[1].temperature == 0.0
276282
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
277283

278284

279-
def test_serving_chat_did_set_correct_cache_salt():
285+
@pytest.mark.asyncio
286+
async def test_serving_chat_did_set_correct_cache_salt():
280287
mock_model_config = MockModelConfig()
281288

282289
mock_engine = MagicMock(spec=MQLLMEngineClient)
@@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt():
306313

307314
# By default cache_salt in the engine prompt is not set
308315
with suppress(Exception):
309-
asyncio.run(serving_chat.create_chat_completion(req))
316+
await serving_chat.create_chat_completion(req)
310317
assert "cache_salt" not in mock_engine.generate.call_args.args[0]
311318

312319
# Test with certain cache_salt
313320
req.cache_salt = "test_salt"
314321
with suppress(Exception):
315-
asyncio.run(serving_chat.create_chat_completion(req))
322+
await serving_chat.create_chat_completion(req)
316323
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"

0 commit comments

Comments (0)