Skip to content

Commit 8e807cd

Browse files
authored
[Misc] feat output content in stream response (#19608)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
1 parent e601efc commit 8e807cd

File tree

1 file changed

+148
-2
lines changed

1 file changed

+148
-2
lines changed

vllm/entrypoints/openai/api_server.py

Lines changed: 148 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from typing import Annotated, Any, Optional
2222

2323
import prometheus_client
24+
import pydantic
2425
import regex as re
2526
import uvloop
2627
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
@@ -1203,6 +1204,142 @@ async def send_with_request_id(message: Message) -> None:
12031204
return self.app(scope, receive, send_with_request_id)
12041205

12051206

1207+
def _extract_content_from_chunk(chunk_data: dict) -> str:
1208+
"""Extract content from a streaming response chunk."""
1209+
try:
1210+
from vllm.entrypoints.openai.protocol import (
1211+
ChatCompletionStreamResponse, CompletionStreamResponse)
1212+
1213+
# Try using Completion types for type-safe parsing
1214+
if chunk_data.get('object') == 'chat.completion.chunk':
1215+
chat_response = ChatCompletionStreamResponse.model_validate(
1216+
chunk_data)
1217+
if chat_response.choices and chat_response.choices[0].delta.content:
1218+
return chat_response.choices[0].delta.content
1219+
elif chunk_data.get('object') == 'text_completion':
1220+
completion_response = CompletionStreamResponse.model_validate(
1221+
chunk_data)
1222+
if completion_response.choices and completion_response.choices[
1223+
0].text:
1224+
return completion_response.choices[0].text
1225+
except pydantic.ValidationError:
1226+
# Fallback to manual parsing
1227+
if 'choices' in chunk_data and chunk_data['choices']:
1228+
choice = chunk_data['choices'][0]
1229+
if 'delta' in choice and choice['delta'].get('content'):
1230+
return choice['delta']['content']
1231+
elif choice.get('text'):
1232+
return choice['text']
1233+
return ""
1234+
1235+
1236+
class SSEDecoder:
1237+
"""Robust Server-Sent Events decoder for streaming responses."""
1238+
1239+
def __init__(self):
1240+
self.buffer = ""
1241+
self.content_buffer = []
1242+
1243+
def decode_chunk(self, chunk: bytes) -> list[dict]:
1244+
"""Decode a chunk of SSE data and return parsed events."""
1245+
import json
1246+
1247+
try:
1248+
chunk_str = chunk.decode('utf-8')
1249+
except UnicodeDecodeError:
1250+
# Skip malformed chunks
1251+
return []
1252+
1253+
self.buffer += chunk_str
1254+
events = []
1255+
1256+
# Process complete lines
1257+
while '\n' in self.buffer:
1258+
line, self.buffer = self.buffer.split('\n', 1)
1259+
line = line.rstrip('\r') # Handle CRLF
1260+
1261+
if line.startswith('data: '):
1262+
data_str = line[6:].strip()
1263+
if data_str == '[DONE]':
1264+
events.append({'type': 'done'})
1265+
elif data_str:
1266+
try:
1267+
event_data = json.loads(data_str)
1268+
events.append({'type': 'data', 'data': event_data})
1269+
except json.JSONDecodeError:
1270+
# Skip malformed JSON
1271+
continue
1272+
1273+
return events
1274+
1275+
def extract_content(self, event_data: dict) -> str:
1276+
"""Extract content from event data."""
1277+
return _extract_content_from_chunk(event_data)
1278+
1279+
def add_content(self, content: str) -> None:
1280+
"""Add content to the buffer."""
1281+
if content:
1282+
self.content_buffer.append(content)
1283+
1284+
def get_complete_content(self) -> str:
1285+
"""Get the complete buffered content."""
1286+
return ''.join(self.content_buffer)
1287+
1288+
1289+
def _log_streaming_response(response, response_body: list) -> None:
1290+
"""Log streaming response with robust SSE parsing."""
1291+
from starlette.concurrency import iterate_in_threadpool
1292+
1293+
sse_decoder = SSEDecoder()
1294+
chunk_count = 0
1295+
1296+
def buffered_iterator():
1297+
nonlocal chunk_count
1298+
1299+
for chunk in response_body:
1300+
chunk_count += 1
1301+
yield chunk
1302+
1303+
# Parse SSE events from chunk
1304+
events = sse_decoder.decode_chunk(chunk)
1305+
1306+
for event in events:
1307+
if event['type'] == 'data':
1308+
content = sse_decoder.extract_content(event['data'])
1309+
sse_decoder.add_content(content)
1310+
elif event['type'] == 'done':
1311+
# Log complete content when done
1312+
full_content = sse_decoder.get_complete_content()
1313+
if full_content:
1314+
# Truncate if too long
1315+
if len(full_content) > 2048:
1316+
full_content = full_content[:2048] + ""
1317+
"...[truncated]"
1318+
logger.info(
1319+
"response_body={streaming_complete: " \
1320+
"content='%s', chunks=%d}",
1321+
full_content, chunk_count)
1322+
else:
1323+
logger.info(
1324+
"response_body={streaming_complete: " \
1325+
"no_content, chunks=%d}",
1326+
chunk_count)
1327+
return
1328+
1329+
response.body_iterator = iterate_in_threadpool(buffered_iterator())
1330+
logger.info("response_body={streaming_started: chunks=%d}",
1331+
len(response_body))
1332+
1333+
1334+
def _log_non_streaming_response(response_body: list) -> None:
1335+
"""Log non-streaming response."""
1336+
try:
1337+
decoded_body = response_body[0].decode()
1338+
logger.info("response_body={%s}", decoded_body)
1339+
except UnicodeDecodeError:
1340+
logger.info("response_body={<binary_data>}")
1341+
1342+
12061343
def build_app(args: Namespace) -> FastAPI:
12071344
if args.disable_fastapi_docs:
12081345
app = FastAPI(openapi_url=None,
@@ -1267,8 +1404,17 @@ async def log_response(request: Request, call_next):
12671404
section async for section in response.body_iterator
12681405
]
12691406
response.body_iterator = iterate_in_threadpool(iter(response_body))
1270-
logger.info("response_body={%s}",
1271-
response_body[0].decode() if response_body else None)
1407+
# Check if this is a streaming response by looking at content-type
1408+
content_type = response.headers.get("content-type", "")
1409+
is_streaming = content_type == "text/event-stream; charset=utf-8"
1410+
1411+
# Log response body based on type
1412+
if not response_body:
1413+
logger.info("response_body={<empty>}")
1414+
elif is_streaming:
1415+
_log_streaming_response(response, response_body)
1416+
else:
1417+
_log_non_streaming_response(response_body)
12721418
return response
12731419

12741420
for middleware in args.middleware:

0 commit comments

Comments
 (0)