
Commit 13182ec

Small update to vllm batch (#419)
1 parent 242d62b · commit 13182ec

File tree

2 files changed (+11, −1 lines)

model-engine/model_engine_server/inference/batch_inference/requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 ray==2.6.3
-git+https://github.com/vllm-project/vllm.git@4b61c6b669e368c6850531815940d9a542b9f223#egg=vllm
+#git+https://github.com/vllm-project/vllm.git@4b61c6b669e368c6850531815940d9a542b9f223#egg=vllm
+vllm==0.2.5
 pydantic==1.10.13
 boto3==1.34.15
 smart-open==6.4.0

model-engine/model_engine_server/inference/batch_inference/vllm_batch.py

Lines changed: 9 additions & 0 deletions
@@ -13,6 +13,7 @@
     CreateBatchCompletionsRequestContent,
     TokenOutput,
 )
+from tqdm import tqdm

 CONFIG_FILE = os.getenv("CONFIG_FILE")
 AWS_REGION = os.getenv("AWS_REGION", "us-west-2")
@@ -123,11 +124,16 @@ async def batch_inference():

     results_generators = await generate_with_vllm(request, content, model, job_index)

+    bar = tqdm(total=len(content.prompts), desc="Processed prompts")
+
     outputs = []
     for generator in results_generators:
         last_output_text = ""
         tokens = []
         async for request_output in generator:
+            if request_output.finished:
+                bar.update(1)
+
             token_text = request_output.outputs[-1].text[len(last_output_text) :]
             log_probs = (
                 request_output.outputs[0].logprobs[-1] if content.return_token_log_probs else None
@@ -155,6 +161,8 @@ async def batch_inference():

         outputs.append(output.dict())

+    bar.close()
+
     if request.data_parallelism == 1:
         with smart_open.open(request.output_data_path, "w") as f:
             f.write(json.dumps(outputs))
@@ -178,6 +186,7 @@ async def generate_with_vllm(request, content, model, job_index):
         quantization=request.model_config.quantize,
         tensor_parallel_size=request.model_config.num_shards,
         seed=request.model_config.seed or 0,
+        disable_log_requests=True,
     )

     llm = AsyncLLMEngine.from_engine_args(engine_args)
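
The tqdm changes above add per-prompt progress reporting to the streaming loop. Below is a minimal, self-contained sketch (not the repository's code) of the same pattern: a bar sized to the number of prompts advances once whenever a streamed output reports finished=True. FakeRequestOutput and fake_generator are illustrative stand-ins for vLLM's RequestOutput objects and the per-prompt async generators returned by generate_with_vllm().

# Sketch of the progress-bar pattern added in this commit (illustrative only).
import asyncio
from dataclasses import dataclass

from tqdm import tqdm


@dataclass
class FakeRequestOutput:
    # Stand-in for vLLM's streaming RequestOutput: partial text plus a finished flag.
    text: str
    finished: bool


async def fake_generator(prompt: str):
    # Yields partial outputs for one prompt, marking the last one as finished.
    for i, chunk in enumerate(["Hello", "Hello world"]):
        yield FakeRequestOutput(text=chunk, finished=(i == 1))
        await asyncio.sleep(0)


async def main():
    prompts = ["prompt-1", "prompt-2", "prompt-3"]
    generators = [fake_generator(p) for p in prompts]

    # One bar for the whole batch; it ticks once per completed prompt.
    bar = tqdm(total=len(prompts), desc="Processed prompts")
    for generator in generators:
        async for request_output in generator:
            if request_output.finished:
                bar.update(1)
    bar.close()


if __name__ == "__main__":
    asyncio.run(main())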
