 import uuid
 from typing import List, Optional
 
+import msgpack
 import aiozmq
 from aiozmq import zmq
 
@@ -143,6 +144,8 @@ async def chat_completion_stream_generator(
         dealer.write([b"", request_id.encode('utf-8')])
         choices = []
         current_waiting_time = 0
+        if request.metadata is not None:
+            enable_thinking = request.metadata.get("enable_thinking")
         while num_choices > 0:
             try:
                 raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
@@ -158,102 +161,106 @@ async def chat_completion_stream_generator(
                         raise ValueError(f"Engine is not healthy: {msg}")
                     else:
                         current_waiting_time = 0
-                await asyncio.sleep(0.1)
+                await asyncio.sleep(0.01)
                 continue
+            response = msgpack.unpackb(raw_data[-1])
+            for res in response:
+                if res.get("error_code", 200) != 200:
+                    raise ValueError("{}".format(res["error_msg"]))
+
+                self.engine_client.data_processor.process_response_dict(
+                    res, stream=True, enable_thinking=enable_thinking)
 
-            res = json.loads(raw_data[-1].decode('utf-8'))
-            if res.get("error_code", 200) != 200:
-                raise ValueError("{}".format(res["error_msg"]))
-            if request.metadata is not None:
-                enable_thinking = request.metadata.get("enable_thinking")
-            self.engine_client.data_processor.process_response_dict(
-                res, stream=True, enable_thinking=enable_thinking)
-
-            if res['metrics']['first_token_time'] is not None:
-                arrival_time = res['metrics']['first_token_time']
-                inference_start_time = res['metrics']['inference_start_time']
-            else:
-                arrival_time = res['metrics']['arrival_time'] - inference_start_time
-            if first_iteration:
-                num_prompt_tokens = len(prompt_token_ids)
-                num_cached_tokens = res.get("num_cached_tokens", 0)
-                for i in range(num_choices):
-                    choice = ChatCompletionResponseStreamChoice(
-                        index=i,
-                        delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None)
+                if res['metrics']['first_token_time'] is not None:
+                    arrival_time = res['metrics']['first_token_time']
+                    inference_start_time = res['metrics']['inference_start_time']
+                else:
+                    arrival_time = res['metrics']['arrival_time'] - inference_start_time
+                if first_iteration:
+                    num_prompt_tokens = len(prompt_token_ids)
+                    num_cached_tokens = res.get("num_cached_tokens", 0)
+                    for i in range(num_choices):
+                        choice = ChatCompletionResponseStreamChoice(
+                            index=i,
+                            delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None)
+                        )
+                        if request.metadata is not None and request.metadata.get("training", False):
+                            choice.delta.token_ids = prompt_token_ids
+                        chunk = ChatCompletionStreamResponse(
+                            id=request_id,
+                            object=chunk_object_type,
+                            created=created_time,
+                            choices=[choice],
+                            model=model_name
+                        )
+                        if include_continuous_usage:
+                            chunk.usage = UsageInfo(
+                                prompt_tokens=num_prompt_tokens,
+                                completion_tokens=0,
+                                total_tokens=num_prompt_tokens,
+                                prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
+                            )
+                        yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
+                    first_iteration = False
+
+                output = res["outputs"]
+                delta_text = output["text"]
+                raw_top_logprobs = output["top_logprobs"]
+                logprobs_res = None
+                if raw_top_logprobs is not None:
+                    top_logprobs = LogprobsLists(
+                        logprob_token_ids=raw_top_logprobs[0],
+                        logprobs=raw_top_logprobs[1],
+                        sampled_token_ranks=raw_top_logprobs[2],
                     )
-                    if request.metadata is not None and request.metadata.get("training", False):
-                        choice.delta.token_ids = prompt_token_ids
-                    chunk = ChatCompletionStreamResponse(
-                        id=request_id,
-                        object=chunk_object_type,
-                        created=created_time,
-                        choices=[choice],
-                        model=model_name
+                    logprobs_res = self.build_logprobs_response(
+                        request_logprobs=request.logprobs,
+                        response_logprobs=top_logprobs,
+                        request_top_logprobs=request.top_logprobs,
                     )
-                    if include_continuous_usage:
-                        chunk.usage = UsageInfo(
-                            prompt_tokens=num_prompt_tokens,
-                            completion_tokens=0,
-                            total_tokens=num_prompt_tokens,
-                            prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
-                        )
-                    yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
-                first_iteration = False
-
-            output = res["outputs"]
-            delta_text = output["text"]
-            raw_top_logprobs = output["top_logprobs"]
-            logprobs_res = None
-            if raw_top_logprobs is not None:
-                top_logprobs = LogprobsLists(
-                    logprob_token_ids=raw_top_logprobs[0],
-                    logprobs=raw_top_logprobs[1],
-                    sampled_token_ranks=raw_top_logprobs[2],
-                )
-                logprobs_res = self.build_logprobs_response(
-                    request_logprobs=request.logprobs,
-                    response_logprobs=top_logprobs,
-                    request_top_logprobs=request.top_logprobs,
-                )
 
-            previous_num_tokens += len(output["token_ids"])
-            delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
-                token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", []))
+                previous_num_tokens += len(output["token_ids"])
+                delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
+                    token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", []))
 
-            choice = ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=delta_message,
-                logprobs=logprobs_res,
-                arrival_time=arrival_time
-            )
-            if res["finished"]:
-                num_choices -= 1
-                work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"])
-                has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
-                max_tokens = request.max_completion_tokens or request.max_tokens
-                if has_no_token_limit or previous_num_tokens != max_tokens:
-                    choice.finish_reason = "stop"
-                    if self.engine_client.reasoning_parser == "ernie_x1" and \
-                        output.get("finish_reason", "") == "tool_calls":
-                        choice.finish_reason = "tool_calls"
-                else:
-                    choice.finish_reason = "length"
-
-            if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
-                choice.finish_reason = "recover_stop"
-
-            if request.metadata is not None and request.metadata.get("training", False) and delta_text != "":
-                choice.delta.token_ids = output["token_ids"]
-            if include_continuous_usage:
-                chunk.usage = UsageInfo(
-                    prompt_tokens=num_prompt_tokens,
-                    completion_tokens=previous_num_tokens,
-                    total_tokens=num_prompt_tokens + previous_num_tokens
+                choice = ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=delta_message,
+                    logprobs=logprobs_res,
+                    arrival_time=arrival_time
                 )
-            choices.append(choice)
+                if res["finished"]:
+                    num_choices -= 1
+                    work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"])
+                    has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
+                    max_tokens = request.max_completion_tokens or request.max_tokens
+                    if has_no_token_limit or previous_num_tokens != max_tokens:
+                        choice.finish_reason = "stop"
+                        if self.engine_client.reasoning_parser == "ernie_x1" and \
+                            output.get("finish_reason", "") == "tool_calls":
+                            choice.finish_reason = "tool_calls"
+                    else:
+                        choice.finish_reason = "length"
+
+                if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
+                    choice.finish_reason = "recover_stop"
+
+                if request.metadata is not None and request.metadata.get("training", False) and delta_text != "":
+                    choice.delta.token_ids = output["token_ids"]
+                if include_continuous_usage:
+                    chunk.usage = UsageInfo(
+                        prompt_tokens=num_prompt_tokens,
+                        completion_tokens=previous_num_tokens,
+                        total_tokens=num_prompt_tokens + previous_num_tokens
+                    )
+                choices.append(choice)
+
+                if len(choices) == max_streaming_response_tokens or res["finished"]:
+                    chunk.choices = choices
+                    yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
+                    choices = []
 
-            if len(choices) == max_streaming_response_tokens or res["finished"]:
+            if choices:
                 chunk.choices = choices
                 yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
                 choices = []
@@ -321,33 +328,38 @@ async def chat_completion_full_generator(
                     await asyncio.sleep(0.1)
                     continue
 
-                data = json.loads(raw_data[-1].decode('utf-8'))
-                if data.get("error_code", 200) != 200:
-                    raise ValueError("{}".format(data["error_msg"]))
-                if request.metadata is not None:
-                    enable_thinking = request.metadata.get("enable_thinking")
-                data = self.engine_client.data_processor.process_response_dict(
-                    data, stream=False, enable_thinking=enable_thinking)
-                # api_server_logger.debug(f"Client {request_id} received: {data}")
-                previous_num_tokens += len(data["outputs"]["token_ids"])
-                # The logprob for handling the response
-                output = data["outputs"]
-                raw_top_logprobs = output["top_logprobs"]
-                if raw_top_logprobs is not None:
-                    top_logprobs = LogprobsLists(
-                        logprob_token_ids=raw_top_logprobs[0],
-                        logprobs=raw_top_logprobs[1],
-                        sampled_token_ranks=raw_top_logprobs[2],
-                    )
-                    logprobs_res = self.build_logprobs_response(
-                        request_logprobs=request.logprobs,
-                        response_logprobs=top_logprobs,
-                        request_top_logprobs=request.top_logprobs,
-                    )
-                    if logprobs_res and logprobs_res.content is not None:
-                        logprob_contents.extend(logprobs_res.content)
-                if data["finished"]:
-                    final_res = data
+                response = msgpack.unpackb(raw_data[-1])
+                task_is_finished = False
+                for data in response:
+                    if data.get("error_code", 200) != 200:
+                        raise ValueError("{}".format(data["error_msg"]))
+                    if request.metadata is not None:
+                        enable_thinking = request.metadata.get("enable_thinking")
+                    data = self.engine_client.data_processor.process_response_dict(
+                        data, stream=False, enable_thinking=enable_thinking)
+                    # api_server_logger.debug(f"Client {request_id} received: {data}")
+                    previous_num_tokens += len(data["outputs"]["token_ids"])
+                    # The logprob for handling the response
+                    output = data["outputs"]
+                    raw_top_logprobs = output["top_logprobs"]
+                    if raw_top_logprobs is not None:
+                        top_logprobs = LogprobsLists(
+                            logprob_token_ids=raw_top_logprobs[0],
+                            logprobs=raw_top_logprobs[1],
+                            sampled_token_ranks=raw_top_logprobs[2],
+                        )
+                        logprobs_res = self.build_logprobs_response(
+                            request_logprobs=request.logprobs,
+                            response_logprobs=top_logprobs,
+                            request_top_logprobs=request.top_logprobs,
+                        )
+                        if logprobs_res and logprobs_res.content is not None:
+                            logprob_contents.extend(logprobs_res.content)
+                    if data["finished"]:
+                        final_res = data
+                        task_is_finished = True
+                        break
+                if task_is_finished:
                     break
         finally:
             dealer.close()
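Note on the wire-format change above: each ZMQ frame now carries a msgpack-encoded batch (a list of response dicts) instead of a single JSON object, which is why both generators gain a for-loop over the unpacked frame. Below is a minimal sketch of that framing, assuming the engine side packs a plain Python list with msgpack.packb; the dict keys mirror the ones read in this diff, while the concrete values and the producer code are made up for illustration only.

import msgpack

# Producer side (hypothetical): pack a batch of per-step results into one frame.
batch = [
    {
        "error_code": 200,
        "finished": False,
        "outputs": {"text": "Hello", "token_ids": [9906], "top_logprobs": None},
        "metrics": {"first_token_time": 0.03, "inference_start_time": 100.0, "arrival_time": 100.03},
    },
]
frame = msgpack.packb(batch)

# Consumer side, matching the pattern in the diff: unpack the frame and iterate the batch.
for res in msgpack.unpackb(frame):
    if res.get("error_code", 200) != 200:
        raise ValueError(res["error_msg"])
    print(res["outputs"]["text"], res["finished"])

Each iteration then feeds one response dict through the same per-result handling (error check, data_processor post-processing, delta construction) that previously ran once per frame.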