     CreateBatchCompletionsRequestContent,
     TokenOutput,
 )
+from tqdm import tqdm

 CONFIG_FILE = os.getenv("CONFIG_FILE")
 AWS_REGION = os.getenv("AWS_REGION", "us-west-2")
@@ -123,11 +124,16 @@ async def batch_inference():

     results_generators = await generate_with_vllm(request, content, model, job_index)

+    bar = tqdm(total=len(content.prompts), desc="Processed prompts")
+
     outputs = []
     for generator in results_generators:
         last_output_text = ""
         tokens = []
         async for request_output in generator:
+            if request_output.finished:
+                bar.update(1)
+
             token_text = request_output.outputs[-1].text[len(last_output_text) :]
             log_probs = (
                 request_output.outputs[0].logprobs[-1] if content.return_token_log_probs else None
@@ -155,6 +161,8 @@ async def batch_inference():

         outputs.append(output.dict())

+    bar.close()
+
     if request.data_parallelism == 1:
         with smart_open.open(request.output_data_path, "w") as f:
             f.write(json.dumps(outputs))
@@ -178,6 +186,7 @@ async def generate_with_vllm(request, content, model, job_index):
         quantization=request.model_config.quantize,
         tensor_parallel_size=request.model_config.num_shards,
         seed=request.model_config.seed or 0,
+        disable_log_requests=True,
     )

     llm = AsyncLLMEngine.from_engine_args(engine_args)
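Below is a minimal, self-contained sketch of the progress-reporting pattern this change introduces: one tqdm bar sized to the number of prompts, incremented each time a streamed request output reports `finished=True`, and closed once all generators are drained. The `FakeRequestOutput` dataclass and `fake_generate` coroutine are hypothetical stand-ins for vLLM's `RequestOutput` and the per-prompt generators returned by `generate_with_vllm`; they exist only so the example runs without a GPU or vLLM installed.

```python
import asyncio
from dataclasses import dataclass

from tqdm import tqdm


@dataclass
class FakeRequestOutput:
    # Hypothetical stand-in for vllm.RequestOutput: partial text plus a
    # `finished` flag that flips to True on the final chunk of a request.
    text: str
    finished: bool


async def fake_generate(prompt: str):
    # Hypothetical stand-in for a streaming generator from AsyncLLMEngine:
    # yields a few partial outputs, then a final one with finished=True.
    for i in range(3):
        await asyncio.sleep(0.01)
        yield FakeRequestOutput(text=prompt[: i + 1], finished=False)
    yield FakeRequestOutput(text=prompt, finished=True)


async def main():
    prompts = ["hello", "world", "foo"]
    generators = [fake_generate(p) for p in prompts]

    # One bar for the whole batch, sized to the number of prompts.
    bar = tqdm(total=len(prompts), desc="Processed prompts")
    outputs = []
    for generator in generators:
        last_text = ""
        async for request_output in generator:
            if request_output.finished:
                bar.update(1)  # one prompt fully generated
            last_text = request_output.text
        outputs.append(last_text)
    bar.close()
    print(outputs)


asyncio.run(main())
```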