Commit d245d1c

[LLM] support send batch data and aggregate data (#2860)
* [LLM] support send batch data and aggregate data
* [LLM] fix ci bugs
* [LLM] fix ci bugs
* [LLM] fix ci bugs
* [LLM] fix ci bugs
* [LLM] update
1 parent 63d6e7c commit d245d1c
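
At a glance, the commit moves the token-streaming path from one ZMQ message per generated result to one message per scheduler poll: the engine sends a request's whole accumulated result list in a single send_multipart call, and the OpenAI serving layer decodes each frame with msgpack and iterates over the batch. Below is a minimal sketch of that round trip; the producer's serialization step is an assumption (only the consumer's msgpack.unpackb call appears in this diff), and the payload fields are illustrative.

```python
import msgpack

# Illustrative batch of per-token results for one request, shaped roughly like
# the dicts the serving layer reads ("outputs", "finished", ...).
contents = [
    {"outputs": {"text": "Hel", "token_ids": [1]}, "finished": False},
    {"outputs": {"text": "lo", "token_ids": [2]}, "finished": True},
]

# Producer side (assumed): aggregate the whole batch into one frame.
frame = msgpack.packb(contents)

# Consumer side (as in serving_chat.py after this commit): unpack once,
# then iterate over every aggregated result in the frame.
for res in msgpack.unpackb(frame):
    print(res["finished"], res["outputs"]["text"])
```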

File tree

11 files changed: +269 -210 lines changed

fastdeploy/engine/engine.py

Lines changed: 4 additions & 3 deletions

@@ -263,10 +263,11 @@ def _zmq_send_generated_tokens(self):
             try:
                 results = self.scheduler.get_results()
                 if len(results) == 0:
-                    time.sleep(0.001)
+                    time.sleep(0.005)
+                    continue
                 for request_id, contents in results.items():
-                    for result in contents:
-                        self.zmq_server.send_multipart(request_id, result)
+                    self.zmq_server.send_multipart(request_id, contents)
+
             except Exception as e:
                 llm_logger.error("Unexcepted error happend: {}, {}".format(
                     e, str(traceback.format_exc())))
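
Two things change in _zmq_send_generated_tokens: the idle poll sleeps 5 ms instead of 1 ms and now `continue`s rather than falling through to iterate an empty result dict, and each request's accumulated `contents` list goes out as one multipart message instead of one message per result. A standalone sketch of the batched send, assuming a plain pyzmq socket and msgpack serialization (FastDeploy's own ZmqServer.send_multipart wrapper is not part of this diff):

```python
import msgpack
import zmq  # pyzmq, used here only as a stand-in for the engine's ZMQ wrapper

def send_batch(socket: zmq.Socket, request_id: str, contents: list) -> None:
    """Send every result pending for one request as a single multipart frame."""
    payload = msgpack.packb(contents)
    # One message per request per poll, instead of one per generated result.
    socket.send_multipart([request_id.encode("utf-8"), payload])
```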

fastdeploy/engine/request.py

Lines changed: 13 additions & 8 deletions

@@ -20,7 +20,7 @@
 from dataclasses import asdict, dataclass, fields
 from typing import Any, Dict, Optional, Union
 
-import numpy
+import numpy as np
 
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.utils import data_processor_logger

@@ -181,7 +181,7 @@ def __repr__(self) -> str:
                 f"sampling_params={self.sampling_params})")
 
 
-@dataclass
+@dataclass(slots=True)
 class CompletionOutput:
     """The output data of one completion output of a request.
 

@@ -235,7 +235,7 @@ def __repr__(self) -> str:
                 f"reasoning_content={self.reasoning_content!r}")
 
 
-@dataclass
+@dataclass(slots=True)
 class RequestMetrics:
     """Metrics associated with a request.
 

@@ -310,6 +310,10 @@ class RequestOutput:
             None if decoder-only.
         num_cached_tokens: The number of tokens with prefix cache hit.
     """
+    __slots__ = (
+        'request_id', 'prompt', 'prompt_token_ids', 'outputs',
+        'finished', 'metrics', 'num_cached_tokens', 'error_code', 'error_msg'
+    )
 
     def __init__(
         self,

@@ -333,6 +337,12 @@ def __init__(
         self.error_code = error_code
         self.error_msg = error_msg
 
+
+        if prompt_token_ids is None:
+            self.prompt_token_ids = []
+        elif isinstance(self.prompt_token_ids, np.ndarray):
+            self.prompt_token_ids = self.prompt_token_ids.tolist()
+
     def add(self, next_output: "RequestOutput") -> None:
         """Merge RequestOutput into this one"""
 

@@ -365,11 +375,6 @@ def from_dict(cls, d: dict):
 
     def to_dict(self):
         """convert RequestOutput into a serializable dict """
-        if self.prompt_token_ids is None:
-            self.prompt_token_ids = []
-
-        if type(self.prompt_token_ids) is numpy.ndarray:
-            self.prompt_token_ids = self.prompt_token_ids.tolist()
 
         return {
             "request_id": self.request_id,

fastdeploy/entrypoints/llm.py

Lines changed: 2 additions & 0 deletions

@@ -169,6 +169,8 @@ def generate(
 
         # get output
         outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
+        for i in range(len(outputs)):
+            outputs[i].prompt = prompts[i]
         return outputs
 
     def chat(
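
generate() now backfills the original prompt string onto each returned RequestOutput by position, which assumes _run_engine returns outputs in the same order as the submitted prompts. An illustrative usage sketch; the constructor arguments are placeholders, not the exact LLM signature:

```python
from fastdeploy.entrypoints.llm import LLM

llm = LLM(model="path/to/your/model")  # placeholder arguments
prompts = ["Hello!", "What does FastDeploy do?"]
outputs = llm.generate(prompts)

for prompt, output in zip(prompts, outputs):
    # After this commit, output.prompt is populated from the input list.
    assert output.prompt == prompt
```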

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 127 additions & 115 deletions

@@ -21,6 +21,7 @@
 import uuid
 from typing import List, Optional
 
+import msgpack
 import aiozmq
 from aiozmq import zmq
 

@@ -143,6 +144,8 @@ async def chat_completion_stream_generator(
         dealer.write([b"", request_id.encode('utf-8')])
         choices = []
         current_waiting_time = 0
+        if request.metadata is not None:
+            enable_thinking = request.metadata.get("enable_thinking")
         while num_choices > 0:
             try:
                 raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
@@ -158,102 +161,106 @@
                         raise ValueError(f"Engine is not healthy: {msg}")
                     else:
                         current_waiting_time = 0
-                await asyncio.sleep(0.1)
+                await asyncio.sleep(0.01)
                 continue
+            response = msgpack.unpackb(raw_data[-1])
+            for res in response:
+                if res.get("error_code", 200) != 200:
+                    raise ValueError("{}".format(res["error_msg"]))
+
+                self.engine_client.data_processor.process_response_dict(
+                    res, stream=True, enable_thinking=enable_thinking)
 
-            res = json.loads(raw_data[-1].decode('utf-8'))
-            if res.get("error_code", 200) != 200:
-                raise ValueError("{}".format(res["error_msg"]))
-            if request.metadata is not None:
-                enable_thinking = request.metadata.get("enable_thinking")
-            self.engine_client.data_processor.process_response_dict(
-                res, stream=True, enable_thinking=enable_thinking)
-
-            if res['metrics']['first_token_time'] is not None:
-                arrival_time = res['metrics']['first_token_time']
-                inference_start_time = res['metrics']['inference_start_time']
-            else:
-                arrival_time = res['metrics']['arrival_time'] - inference_start_time
-            if first_iteration:
-                num_prompt_tokens = len(prompt_token_ids)
-                num_cached_tokens = res.get("num_cached_tokens", 0)
-                for i in range(num_choices):
-                    choice = ChatCompletionResponseStreamChoice(
-                        index=i,
-                        delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None)
+                if res['metrics']['first_token_time'] is not None:
+                    arrival_time = res['metrics']['first_token_time']
+                    inference_start_time = res['metrics']['inference_start_time']
+                else:
+                    arrival_time = res['metrics']['arrival_time'] - inference_start_time
+                if first_iteration:
+                    num_prompt_tokens = len(prompt_token_ids)
+                    num_cached_tokens = res.get("num_cached_tokens", 0)
+                    for i in range(num_choices):
+                        choice = ChatCompletionResponseStreamChoice(
+                            index=i,
+                            delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None)
+                        )
+                        if request.metadata is not None and request.metadata.get("training", False):
+                            choice.delta.token_ids = prompt_token_ids
+                        chunk = ChatCompletionStreamResponse(
+                            id=request_id,
+                            object=chunk_object_type,
+                            created=created_time,
+                            choices=[choice],
+                            model=model_name
+                        )
+                        if include_continuous_usage:
+                            chunk.usage = UsageInfo(
+                                prompt_tokens=num_prompt_tokens,
+                                completion_tokens=0,
+                                total_tokens=num_prompt_tokens,
+                                prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
+                            )
+                        yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n"
+                    first_iteration = False
+
+                output = res["outputs"]
+                delta_text = output["text"]
+                raw_top_logprobs = output["top_logprobs"]
+                logprobs_res = None
+                if raw_top_logprobs is not None:
+                    top_logprobs = LogprobsLists(
+                        logprob_token_ids=raw_top_logprobs[0],
+                        logprobs=raw_top_logprobs[1],
+                        sampled_token_ranks=raw_top_logprobs[2],
                     )
-                    if request.metadata is not None and request.metadata.get("training", False):
-                        choice.delta.token_ids = prompt_token_ids
-                    chunk = ChatCompletionStreamResponse(
-                        id=request_id,
-                        object=chunk_object_type,
-                        created=created_time,
-                        choices=[choice],
-                        model=model_name
+                    logprobs_res = self.build_logprobs_response(
+                        request_logprobs=request.logprobs,
+                        response_logprobs=top_logprobs,
+                        request_top_logprobs=request.top_logprobs,
                     )
-                    if include_continuous_usage:
-                        chunk.usage = UsageInfo(
-                            prompt_tokens=num_prompt_tokens,
-                            completion_tokens=0,
-                            total_tokens=num_prompt_tokens,
-                            prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
-                        )
-                    yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n"
-                first_iteration = False
-
-            output = res["outputs"]
-            delta_text = output["text"]
-            raw_top_logprobs = output["top_logprobs"]
-            logprobs_res = None
-            if raw_top_logprobs is not None:
-                top_logprobs = LogprobsLists(
-                    logprob_token_ids=raw_top_logprobs[0],
-                    logprobs=raw_top_logprobs[1],
-                    sampled_token_ranks=raw_top_logprobs[2],
-                )
-                logprobs_res = self.build_logprobs_response(
-                    request_logprobs=request.logprobs,
-                    response_logprobs=top_logprobs,
-                    request_top_logprobs=request.top_logprobs,
-                )
 
-            previous_num_tokens += len(output["token_ids"])
-            delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
-                token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", []))
+                previous_num_tokens += len(output["token_ids"])
+                delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
+                    token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", []))
 
-            choice = ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=delta_message,
-                logprobs=logprobs_res,
-                arrival_time=arrival_time
-            )
-            if res["finished"]:
-                num_choices -= 1
-                work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"])
-                has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
-                max_tokens = request.max_completion_tokens or request.max_tokens
-                if has_no_token_limit or previous_num_tokens != max_tokens:
-                    choice.finish_reason = "stop"
-                    if self.engine_client.reasoning_parser == "ernie_x1" and \
-                            output.get("finish_reason", "") == "tool_calls":
-                        choice.finish_reason = "tool_calls"
-                else:
-                    choice.finish_reason = "length"
-
-            if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
-                choice.finish_reason = "recover_stop"
-
-            if request.metadata is not None and request.metadata.get("training", False) and delta_text != "":
-                choice.delta.token_ids = output["token_ids"]
-            if include_continuous_usage:
-                chunk.usage = UsageInfo(
-                    prompt_tokens=num_prompt_tokens,
-                    completion_tokens=previous_num_tokens,
-                    total_tokens=num_prompt_tokens + previous_num_tokens
+                choice = ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=delta_message,
+                    logprobs=logprobs_res,
+                    arrival_time=arrival_time
                 )
-            choices.append(choice)
+                if res["finished"]:
+                    num_choices -= 1
+                    work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"])
+                    has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
+                    max_tokens = request.max_completion_tokens or request.max_tokens
+                    if has_no_token_limit or previous_num_tokens != max_tokens:
+                        choice.finish_reason = "stop"
+                        if self.engine_client.reasoning_parser == "ernie_x1" and \
+                                output.get("finish_reason", "") == "tool_calls":
+                            choice.finish_reason = "tool_calls"
+                    else:
+                        choice.finish_reason = "length"
+
+                if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
+                    choice.finish_reason = "recover_stop"
+
+                if request.metadata is not None and request.metadata.get("training", False) and delta_text != "":
+                    choice.delta.token_ids = output["token_ids"]
+                if include_continuous_usage:
+                    chunk.usage = UsageInfo(
+                        prompt_tokens=num_prompt_tokens,
+                        completion_tokens=previous_num_tokens,
+                        total_tokens=num_prompt_tokens + previous_num_tokens
+                    )
+                choices.append(choice)
+
+                if len(choices) == max_streaming_response_tokens or res["finished"]:
+                    chunk.choices = choices
+                    yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
+                    choices = []
 
-            if len(choices) == max_streaming_response_tokens or res["finished"]:
+            if choices:
                 chunk.choices = choices
                 yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
                 choices = []
@@ -321,33 +328,38 @@ async def chat_completion_full_generator(
                     await asyncio.sleep(0.1)
                     continue
 
-                data = json.loads(raw_data[-1].decode('utf-8'))
-                if data.get("error_code", 200) != 200:
-                    raise ValueError("{}".format(data["error_msg"]))
-                if request.metadata is not None:
-                    enable_thinking = request.metadata.get("enable_thinking")
-                data = self.engine_client.data_processor.process_response_dict(
-                    data, stream=False, enable_thinking=enable_thinking)
-                # api_server_logger.debug(f"Client {request_id} received: {data}")
-                previous_num_tokens += len(data["outputs"]["token_ids"])
-                # The logprob for handling the response
-                output = data["outputs"]
-                raw_top_logprobs = output["top_logprobs"]
-                if raw_top_logprobs is not None:
-                    top_logprobs = LogprobsLists(
-                        logprob_token_ids=raw_top_logprobs[0],
-                        logprobs=raw_top_logprobs[1],
-                        sampled_token_ranks=raw_top_logprobs[2],
-                    )
-                    logprobs_res = self.build_logprobs_response(
-                        request_logprobs=request.logprobs,
-                        response_logprobs=top_logprobs,
-                        request_top_logprobs=request.top_logprobs,
-                    )
-                    if logprobs_res and logprobs_res.content is not None:
-                        logprob_contents.extend(logprobs_res.content)
-                if data["finished"]:
-                    final_res = data
+                response = msgpack.unpackb(raw_data[-1])
+                task_is_finished = False
+                for data in response:
+                    if data.get("error_code", 200) != 200:
+                        raise ValueError("{}".format(data["error_msg"]))
+                    if request.metadata is not None:
+                        enable_thinking = request.metadata.get("enable_thinking")
+                    data = self.engine_client.data_processor.process_response_dict(
+                        data, stream=False, enable_thinking=enable_thinking)
+                    # api_server_logger.debug(f"Client {request_id} received: {data}")
+                    previous_num_tokens += len(data["outputs"]["token_ids"])
+                    # The logprob for handling the response
+                    output = data["outputs"]
+                    raw_top_logprobs = output["top_logprobs"]
+                    if raw_top_logprobs is not None:
+                        top_logprobs = LogprobsLists(
+                            logprob_token_ids=raw_top_logprobs[0],
+                            logprobs=raw_top_logprobs[1],
+                            sampled_token_ranks=raw_top_logprobs[2],
+                        )
+                        logprobs_res = self.build_logprobs_response(
+                            request_logprobs=request.logprobs,
+                            response_logprobs=top_logprobs,
+                            request_top_logprobs=request.top_logprobs,
+                        )
+                        if logprobs_res and logprobs_res.content is not None:
+                            logprob_contents.extend(logprobs_res.content)
+                    if data["finished"]:
+                        final_res = data
+                        task_is_finished = True
+                        break
+                if task_is_finished:
                     break
         finally:
             dealer.close()
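
On the serving side, both the streaming and the full-response generators now treat each ZMQ frame as a msgpack-encoded batch: msgpack.unpackb(raw_data[-1]) replaces json.loads, and the handlers loop over every aggregated result in the frame. The streaming path buffers per-result choices and flushes them either when len(choices) reaches max_streaming_response_tokens or when the request finishes, with a final `if choices:` flush after the loop; the full-response path uses a task_is_finished flag to break out of both loops once the finished result arrives. A minimal sketch of this decode-and-flush pattern, with hypothetical frame contents and a made-up batch size:

```python
import msgpack

MAX_CHUNK = 4  # stand-in for max_streaming_response_tokens

def iter_flush_batches(raw_frame: bytes):
    """Decode one msgpack frame and group its results into flushable batches."""
    batch = []
    for res in msgpack.unpackb(raw_frame):
        if res.get("error_code", 200) != 200:
            raise ValueError(res.get("error_msg", "unknown error"))
        batch.append(res)
        if len(batch) == MAX_CHUNK or res.get("finished"):
            yield batch
            batch = []
    if batch:  # trailing partial batch, mirroring the final `if choices:` flush
        yield batch

# Example: two aggregated results arriving in a single frame.
frame = msgpack.packb([
    {"outputs": {"text": "Hel"}, "finished": False},
    {"outputs": {"text": "lo"}, "finished": True},
])
for group in iter_flush_batches(frame):
    print([r["outputs"]["text"] for r in group])
```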
