@@ -21,6 +21,7 @@
 from typing import Annotated, Any, Optional
 
 import prometheus_client
+import pydantic
 import regex as re
 import uvloop
 from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
@@ -1203,6 +1204,142 @@ async def send_with_request_id(message: Message) -> None:
         return self.app(scope, receive, send_with_request_id)
 
 
+def _extract_content_from_chunk(chunk_data: dict) -> str:
+    """Extract content from a streaming response chunk."""
+    try:
+        from vllm.entrypoints.openai.protocol import (
+            ChatCompletionStreamResponse, CompletionStreamResponse)
+
+        # Try using Completion types for type-safe parsing
+        if chunk_data.get('object') == 'chat.completion.chunk':
+            chat_response = ChatCompletionStreamResponse.model_validate(
+                chunk_data)
+            if chat_response.choices and chat_response.choices[0].delta.content:
+                return chat_response.choices[0].delta.content
+        elif chunk_data.get('object') == 'text_completion':
+            completion_response = CompletionStreamResponse.model_validate(
+                chunk_data)
+            if completion_response.choices and completion_response.choices[
+                    0].text:
+                return completion_response.choices[0].text
+    except pydantic.ValidationError:
+        # Fallback to manual parsing
+        if 'choices' in chunk_data and chunk_data['choices']:
+            choice = chunk_data['choices'][0]
+            if 'delta' in choice and choice['delta'].get('content'):
+                return choice['delta']['content']
+            elif choice.get('text'):
+                return choice['text']
+    return ""
+
+
+class SSEDecoder:
+    """Robust Server-Sent Events decoder for streaming responses."""
+
+    def __init__(self):
+        self.buffer = ""
+        self.content_buffer = []
+
+    def decode_chunk(self, chunk: bytes) -> list[dict]:
+        """Decode a chunk of SSE data and return parsed events."""
+        import json
+
+        try:
+            chunk_str = chunk.decode('utf-8')
+        except UnicodeDecodeError:
+            # Skip malformed chunks
+            return []
+
+        self.buffer += chunk_str
+        events = []
+
+        # Process complete lines
+        while '\n' in self.buffer:
+            line, self.buffer = self.buffer.split('\n', 1)
+            line = line.rstrip('\r')  # Handle CRLF
+
+            if line.startswith('data: '):
+                data_str = line[6:].strip()
+                if data_str == '[DONE]':
+                    events.append({'type': 'done'})
+                elif data_str:
+                    try:
+                        event_data = json.loads(data_str)
+                        events.append({'type': 'data', 'data': event_data})
+                    except json.JSONDecodeError:
+                        # Skip malformed JSON
+                        continue
+
+        return events
+
+    def extract_content(self, event_data: dict) -> str:
+        """Extract content from event data."""
+        return _extract_content_from_chunk(event_data)
+
+    def add_content(self, content: str) -> None:
+        """Add content to the buffer."""
+        if content:
+            self.content_buffer.append(content)
+
+    def get_complete_content(self) -> str:
+        """Get the complete buffered content."""
+        return ''.join(self.content_buffer)
+
+
+def _log_streaming_response(response, response_body: list) -> None:
+    """Log streaming response with robust SSE parsing."""
+    from starlette.concurrency import iterate_in_threadpool
+
+    sse_decoder = SSEDecoder()
+    chunk_count = 0
+
+    def buffered_iterator():
+        nonlocal chunk_count
+
+        for chunk in response_body:
+            chunk_count += 1
+            yield chunk
+
+            # Parse SSE events from chunk
+            events = sse_decoder.decode_chunk(chunk)
+
+            for event in events:
+                if event['type'] == 'data':
+                    content = sse_decoder.extract_content(event['data'])
+                    sse_decoder.add_content(content)
+                elif event['type'] == 'done':
+                    # Log complete content when done
+                    full_content = sse_decoder.get_complete_content()
+                    if full_content:
+                        # Truncate if too long
+                        if len(full_content) > 2048:
+                            full_content = full_content[:2048] + \
+                                "...[truncated]"
+                        logger.info(
+                            "response_body={streaming_complete: "
+                            "content='%s', chunks=%d}",
+                            full_content, chunk_count)
+                    else:
+                        logger.info(
+                            "response_body={streaming_complete: "
+                            "no_content, chunks=%d}",
+                            chunk_count)
+                    return
+
+    response.body_iterator = iterate_in_threadpool(buffered_iterator())
+    logger.info("response_body={streaming_started: chunks=%d}",
+                len(response_body))
+
+
+def _log_non_streaming_response(response_body: list) -> None:
+    """Log non-streaming response."""
+    try:
+        decoded_body = response_body[0].decode()
+        logger.info("response_body={%s}", decoded_body)
+    except UnicodeDecodeError:
+        logger.info("response_body={<binary_data>}")
+
+
 def build_app(args: Namespace) -> FastAPI:
     if args.disable_fastapi_docs:
         app = FastAPI(openapi_url=None,
@@ -1267,8 +1404,17 @@ async def log_response(request: Request, call_next):
             section async for section in response.body_iterator
         ]
         response.body_iterator = iterate_in_threadpool(iter(response_body))
-        logger.info("response_body={%s}",
-                    response_body[0].decode() if response_body else None)
+        # Check if this is a streaming response by looking at content-type
+        content_type = response.headers.get("content-type", "")
+        is_streaming = content_type == "text/event-stream; charset=utf-8"
+
+        # Log response body based on type
+        if not response_body:
+            logger.info("response_body={<empty>}")
+        elif is_streaming:
+            _log_streaming_response(response, response_body)
+        else:
+            _log_non_streaming_response(response_body)
         return response
 
     for middleware in args.middleware:
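Below is a minimal usage sketch (not part of the diff) showing how the new helpers fit together: it feeds a few hand-written SSE frames through SSEDecoder and accumulates the extracted delta content, mirroring what _log_streaming_response does per chunk. The import path and the sample payloads are illustrative assumptions, not taken from the change itself, and it only runs against a vllm tree with this change applied.

# Illustrative sketch; assumes a vllm installation that includes this change,
# so the helpers are importable from vllm.entrypoints.openai.api_server.
import json

from vllm.entrypoints.openai.api_server import (SSEDecoder,
                                                _extract_content_from_chunk)

decoder = SSEDecoder()

# Hand-written SSE frames in the shape an OpenAI-compatible chat stream uses,
# followed by the [DONE] sentinel. The payload fields are made up.
frames = [
    b'data: ' + json.dumps({
        "object": "chat.completion.chunk",
        "choices": [{"index": 0, "delta": {"content": "Hello"}}],
    }).encode() + b'\n',
    b'data: ' + json.dumps({
        "object": "chat.completion.chunk",
        "choices": [{"index": 0, "delta": {"content": ", world"}}],
    }).encode() + b'\n',
    b'data: [DONE]\n',
]

for frame in frames:
    for event in decoder.decode_chunk(frame):
        if event['type'] == 'data':
            # Falls back to manual dict parsing if pydantic validation fails.
            decoder.add_content(_extract_content_from_chunk(event['data']))

print(decoder.get_complete_content())  # -> Hello, world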