From f2263b76652c3bf166419a69f8f942337b4d4fb3 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Thu, 10 Jul 2025 00:26:41 +0400 Subject: [PATCH 01/13] feat: Add --enable-log-outputs flag for logging model generations Add optional output logging functionality to complement existing input logging. By default, vLLM only logs incoming requests but not model outputs. This feature adds comprehensive output logging controlled by a new CLI flag. Key features: - New --enable-log-outputs CLI flag (disabled by default) - Logs both streaming and non-streaming responses - Supports individual token deltas in streaming mode - Handles tool calls and function arguments - Respects existing --max-log-len truncation settings - Maintains full backward compatibility Implementation: - Added RequestLogger.log_outputs() method for output logging - Enhanced OpenAIServingChat with output logging in both generators - Enhanced OpenAIServingResponses with output logging support - Added comprehensive test coverage for all scenarios Usage: python -m vllm.entrypoints.openai.api_server --model MODEL_NAME --enable-log-outputs Docker: docker run --gpus all -p 8000:8000 vllm/vllm-openai:latest --model MODEL_NAME --enable-log-outputs This addresses the common need for debugging and monitoring model outputs while preserving the existing behavior by default. Signed-off-by: Adrian Garcia --- tests/test_logger.py | 202 ++++++++++++++++++- vllm/entrypoints/logger.py | 26 +++ vllm/entrypoints/openai/api_server.py | 2 + vllm/entrypoints/openai/cli_args.py | 8 + vllm/entrypoints/openai/serving_chat.py | 62 ++++++ vllm/entrypoints/openai/serving_responses.py | 20 ++ 6 files changed, 319 insertions(+), 1 deletion(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index 8f235f1474f..6e511522c16 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -10,11 +10,12 @@ from json.decoder import JSONDecodeError from tempfile import NamedTemporaryFile from typing import Any -from unittest.mock import patch +from unittest.mock import patch, MagicMock from uuid import uuid4 import pytest +from vllm.entrypoints.logger import RequestLogger from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger, enable_trace_function_call, init_logger) from vllm.logging_utils import NewLineFormatter @@ -253,3 +254,202 @@ class CustomClass: assert (prepare_object_to_dump(CustomClass( 1, 'b')) == "CustomClass(a=1, b='b')") + + +def test_request_logger_log_outputs(): + """Test the new log_outputs functionality.""" + # Create a mock logger to capture log calls + mock_logger = MagicMock() + + with patch('vllm.entrypoints.logger.logger', mock_logger): + request_logger = RequestLogger(max_log_len=None) + + # Test basic output logging + request_logger.log_outputs( + request_id="test-123", + outputs="Hello, world!", + output_token_ids=[1, 2, 3, 4], + finish_reason="stop", + is_streaming=False, + delta=False + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args[0] + assert "Generated response test-123" in call_args[0] + assert "Hello, world!" 
in call_args[1] + assert [1, 2, 3, 4] == call_args[2] + assert "stop" == call_args[3] + + +def test_request_logger_log_outputs_streaming_delta(): + """Test log_outputs with streaming delta mode.""" + mock_logger = MagicMock() + + with patch('vllm.entrypoints.logger.logger', mock_logger): + request_logger = RequestLogger(max_log_len=None) + + # Test streaming delta logging + request_logger.log_outputs( + request_id="test-456", + outputs="Hello", + output_token_ids=[1], + finish_reason=None, + is_streaming=True, + delta=True + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args[0] + assert "Generated response test-456 (streaming delta)" in call_args[0] + assert "Hello" == call_args[1] + assert [1] == call_args[2] + assert call_args[3] is None + + +def test_request_logger_log_outputs_streaming_complete(): + """Test log_outputs with streaming complete mode.""" + mock_logger = MagicMock() + + with patch('vllm.entrypoints.logger.logger', mock_logger): + request_logger = RequestLogger(max_log_len=None) + + # Test streaming complete logging + request_logger.log_outputs( + request_id="test-789", + outputs="Complete response", + output_token_ids=[1, 2, 3], + finish_reason="length", + is_streaming=True, + delta=False + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args[0] + assert "Generated response test-789 (streaming complete)" in call_args[0] + assert "Complete response" == call_args[1] + assert [1, 2, 3] == call_args[2] + assert "length" == call_args[3] + + +def test_request_logger_log_outputs_with_truncation(): + """Test log_outputs respects max_log_len setting.""" + mock_logger = MagicMock() + + with patch('vllm.entrypoints.logger.logger', mock_logger): + # Set max_log_len to 10 + request_logger = RequestLogger(max_log_len=10) + + # Test output truncation + long_output = "This is a very long output that should be truncated" + long_token_ids = list(range(20)) # 20 tokens + + request_logger.log_outputs( + request_id="test-truncate", + outputs=long_output, + output_token_ids=long_token_ids, + finish_reason="stop", + is_streaming=False, + delta=False + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args + + # Check that output was truncated to first 10 characters + logged_output = call_args[0][1] + assert logged_output == "This is a " + assert len(logged_output) == 10 + + # Check that token IDs were truncated to first 10 tokens + logged_token_ids = call_args[0][2] + assert logged_token_ids == list(range(10)) + assert len(logged_token_ids) == 10 + + +def test_request_logger_log_outputs_none_values(): + """Test log_outputs handles None values correctly.""" + mock_logger = MagicMock() + + with patch('vllm.entrypoints.logger.logger', mock_logger): + request_logger = RequestLogger(max_log_len=None) + + # Test with None output_token_ids + request_logger.log_outputs( + request_id="test-none", + outputs="Test output", + output_token_ids=None, + finish_reason="stop", + is_streaming=False, + delta=False + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args[0] + assert "Generated response test-none" in call_args[0] + assert "Test output" == call_args[1] + assert call_args[2] is None + assert "stop" == call_args[3] + + +def test_request_logger_log_outputs_empty_output(): + """Test log_outputs handles empty output correctly.""" + mock_logger = MagicMock() + + with patch('vllm.entrypoints.logger.logger', mock_logger): + request_logger = RequestLogger(max_log_len=5) + + # Test with 
empty output + request_logger.log_outputs( + request_id="test-empty", + outputs="", + output_token_ids=[], + finish_reason="stop", + is_streaming=False, + delta=False + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args[0] + assert "Generated response test-empty" in call_args[0] + assert "" == call_args[1] + assert [] == call_args[2] + assert "stop" == call_args[3] + + +def test_request_logger_log_outputs_integration(): + """Test that log_outputs can be called alongside log_inputs.""" + mock_logger = MagicMock() + + with patch('vllm.entrypoints.logger.logger', mock_logger): + request_logger = RequestLogger(max_log_len=None) + + # Test that both methods can be called without interference + request_logger.log_inputs( + request_id="test-integration", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_embeds=None, + params=None, + lora_request=None, + prompt_adapter_request=None + ) + + request_logger.log_outputs( + request_id="test-integration", + outputs="Test output", + output_token_ids=[4, 5, 6], + finish_reason="stop", + is_streaming=False, + delta=False + ) + + # Should have been called twice - once for inputs, once for outputs + assert mock_logger.info.call_count == 2 + + # Check that the calls were made with correct patterns + input_call = mock_logger.info.call_args_list[0][0] + output_call = mock_logger.info.call_args_list[1][0] + + assert "Received request test-integration" in input_call[0] + assert "Generated response test-integration" in output_call[0] diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index f3aee188dae..e2ea6c174fb 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -48,3 +48,29 @@ def log_inputs( prompt, params, prompt_token_ids, prompt_embeds.shape if prompt_embeds is not None else None, lora_request, prompt_adapter_request) + + def log_outputs( + self, + request_id: str, + outputs: str, + output_token_ids: Optional[list[int]], + finish_reason: Optional[str] = None, + is_streaming: bool = False, + delta: bool = False, + ) -> None: + max_log_len = self.max_log_len + if max_log_len is not None: + if outputs is not None: + outputs = outputs[:max_log_len] + + if output_token_ids is not None: + output_token_ids = output_token_ids[:max_log_len] + + stream_info = "" + if is_streaming: + stream_info = " (streaming delta)" if delta else " (streaming complete)" + + logger.info( + "Generated response %s%s: output: %r, " + "output_token_ids: %s, finish_reason: %s", + request_id, stream_info, outputs, output_token_ids, finish_reason) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2f8b31c8a7b..0bec3492eb3 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1504,6 +1504,7 @@ async def init_app_state( reasoning_parser=args.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, ) if model_config.runner_type == "generate" else None state.openai_serving_chat = OpenAIServingChat( engine_client, @@ -1521,6 +1522,7 @@ async def init_app_state( reasoning_parser=args.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, ) if model_config.runner_type == "generate" else None state.openai_serving_completion = OpenAIServingCompletion( 
engine_client, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 4f8aaab772f..b05efa6fd0f 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -295,6 +295,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help= "If set to True, enable tracking server_load_metrics in the app state." ) + parser.add_argument( + "--enable-log-outputs", + action='store_true', + default=False, + help= + "If set to True, enable logging of model outputs (generations) " + "in addition to the input logging that is enabled by default." + ) return parser diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 451241d3f9f..5bd22af8b21 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -67,6 +67,7 @@ def __init__( tool_parser: Optional[str] = None, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, + enable_log_outputs: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, @@ -78,6 +79,7 @@ def __init__( self.response_role = response_role self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.enable_log_outputs = enable_log_outputs # set up tool use self.enable_auto_tools: bool = enable_auto_tools @@ -823,6 +825,24 @@ async def chat_completion_stream_generator( if delta_message is None: continue + # Log individual streaming delta if output logging is enabled + if self.enable_log_outputs and self.request_logger: + delta_content = "" + if delta_message.content: + delta_content = delta_message.content + elif delta_message.tool_calls and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments: + delta_content = delta_message.tool_calls[0].function.arguments + + if delta_content: + self.request_logger.log_outputs( + request_id=request_id, + outputs=delta_content, + output_token_ids=list(output.token_ids), + finish_reason=output.finish_reason, + is_streaming=True, + delta=True, + ) + if output.finish_reason is None: # Send token-by-token response for each request.n choice_data = ChatCompletionResponseStreamChoice( @@ -943,6 +963,19 @@ async def chat_completion_stream_generator( completion_tokens=num_completion_tokens, total_tokens=num_prompt_tokens + num_completion_tokens) + # Log complete streaming response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + # Collect all generated text from the SSE decoder if available + # For now, we'll log the completion tokens count as final output + self.request_logger.log_outputs( + request_id=request_id, + outputs=f"", + output_token_ids=None, + finish_reason="streaming_complete", + is_streaming=True, + delta=False, + ) + except Exception as e: # TODO: Use a vllm-specific Validation Error logger.exception("Error in chat completion stream generator.") @@ -1156,6 +1189,35 @@ async def chat_completion_full_generator( kv_transfer_params=final_res.kv_transfer_params, ) + # Log complete response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + for choice in choices: + output_text = "" + if choice.message.content: + output_text = choice.message.content + elif choice.message.tool_calls: + # For tool calls, log the function name and arguments + tool_call_info = [] + for tool_call in choice.message.tool_calls: + if hasattr(tool_call.function, 
'name') and hasattr(tool_call.function, 'arguments'): + tool_call_info.append(f"{tool_call.function.name}({tool_call.function.arguments})") + output_text = f"[tool_calls: {', '.join(tool_call_info)}]" + + if output_text: + # Get the corresponding output token IDs + output_token_ids = None + if choice.index < len(final_res.outputs): + output_token_ids = final_res.outputs[choice.index].token_ids + + self.request_logger.log_outputs( + request_id=request_id, + outputs=output_text, + output_token_ids=output_token_ids, + finish_reason=choice.finish_reason, + is_streaming=False, + delta=False, + ) + return response def _get_top_logprobs( diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index ac2b3dfafec..31f55a8868a 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -55,6 +55,7 @@ def __init__( tool_parser: Optional[str] = None, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, + enable_log_outputs: bool = False, ) -> None: super().__init__( engine_client=engine_client, @@ -67,6 +68,7 @@ def __init__( self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.enable_log_outputs = enable_log_outputs self.reasoning_parser: Optional[Callable[[AnyTokenizer], ReasoningParser]] = None @@ -335,6 +337,24 @@ async def responses_full_generator( usage=usage, ) + # Log complete response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + output_text = "" + if content: + output_text = content + elif reasoning_content: + output_text = f"[reasoning: {reasoning_content}]" + + if output_text: + self.request_logger.log_outputs( + request_id=request.request_id, + outputs=output_text, + output_token_ids=final_output.token_ids, + finish_reason=final_output.finish_reason, + is_streaming=False, + delta=False, + ) + if request.store: async with self.response_store_lock: stored_response = self.response_store.get(response.id) From 4848f49417184b1c3ea1d976fdf14679d3dee62c Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Thu, 10 Jul 2025 12:34:01 +0400 Subject: [PATCH 02/13] fix: Resolve type checking issues in output logging Fix type annotation and variable naming issues identified by mypy: - Change output_token_ids parameter type from list[int] to Sequence[int] to handle compatibility with different sequence types from output objects - Fix variable naming conflict in tool call logging (tool_call_info -> tool_call_descriptions) - Add proper type conversion in log_outputs method for truncation - Update test imports to include Sequence type These fixes ensure the output logging feature passes mypy type checking while maintaining full functionality and backward compatibility. 
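
Illustrative usage sketch (not part of the diff below; it only assumes the
RequestLogger.log_outputs API added in patch 01 and the Sequence[int]
signature described above; all concrete values are placeholders):

    from vllm.entrypoints.logger import RequestLogger

    # max_log_len=10 truncates both the logged text and the token ids.
    request_logger = RequestLogger(max_log_len=10)
    request_logger.log_outputs(
        request_id="example-request",
        outputs="generated text that will be truncated to max_log_len",
        # Any Sequence[int] is accepted (e.g. token ids taken from an
        # output object); log_outputs converts it to a list before
        # applying the max_log_len truncation.
        output_token_ids=(1, 2, 3, 4, 5),
        finish_reason="stop",
        is_streaming=False,
        delta=False,
    )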
Signed-off-by: Adrian Garcia --- tests/test_logger.py | 2 +- vllm/entrypoints/logger.py | 7 ++++--- vllm/entrypoints/openai/serving_chat.py | 6 +++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index 6e511522c16..b15c3f585d2 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from json.decoder import JSONDecodeError from tempfile import NamedTemporaryFile -from typing import Any +from typing import Any, Sequence from unittest.mock import patch, MagicMock from uuid import uuid4 diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index e2ea6c174fb..99d0a4bc674 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Optional, Union, Sequence import torch @@ -53,7 +53,7 @@ def log_outputs( self, request_id: str, outputs: str, - output_token_ids: Optional[list[int]], + output_token_ids: Optional[Sequence[int]], finish_reason: Optional[str] = None, is_streaming: bool = False, delta: bool = False, @@ -64,7 +64,8 @@ def log_outputs( outputs = outputs[:max_log_len] if output_token_ids is not None: - output_token_ids = output_token_ids[:max_log_len] + # Convert to list and apply truncation + output_token_ids = list(output_token_ids)[:max_log_len] stream_info = "" if is_streaming: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5bd22af8b21..bdc4d052435 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1197,11 +1197,11 @@ async def chat_completion_full_generator( output_text = choice.message.content elif choice.message.tool_calls: # For tool calls, log the function name and arguments - tool_call_info = [] + tool_call_descriptions = [] for tool_call in choice.message.tool_calls: if hasattr(tool_call.function, 'name') and hasattr(tool_call.function, 'arguments'): - tool_call_info.append(f"{tool_call.function.name}({tool_call.function.arguments})") - output_text = f"[tool_calls: {', '.join(tool_call_info)}]" + tool_call_descriptions.append(f"{tool_call.function.name}({tool_call.function.arguments})") + output_text = f"[tool_calls: {', '.join(tool_call_descriptions)}]" if output_text: # Get the corresponding output token IDs From 8c0aa7830065afbbd437f4992480d989f5025cef Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Thu, 10 Jul 2025 13:04:33 +0400 Subject: [PATCH 03/13] Fix line length violations (E501) in logger and serving_chat - Break long conditional expressions into multiple lines - Fix tool call logging lines exceeding 80 characters - Remove trailing whitespace - Maintain code readability and functionality Signed-off-by: Adrian Garcia --- vllm/entrypoints/logger.py | 3 ++- vllm/entrypoints/openai/serving_chat.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index 99d0a4bc674..1b87113d12b 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -69,7 +69,8 @@ def log_outputs( stream_info = "" if is_streaming: - stream_info = " (streaming delta)" if delta else " (streaming complete)" + stream_info = (" (streaming delta)" if delta else + " (streaming complete)") logger.info( "Generated response %s%s: output: %r, " diff --git 
a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index bdc4d052435..2af576a2464 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -830,8 +830,11 @@ async def chat_completion_stream_generator( delta_content = "" if delta_message.content: delta_content = delta_message.content - elif delta_message.tool_calls and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments: - delta_content = delta_message.tool_calls[0].function.arguments + elif (delta_message.tool_calls and + delta_message.tool_calls[0].function and + delta_message.tool_calls[0].function.arguments): + func_args = delta_message.tool_calls[0].function.arguments + delta_content = func_args if delta_content: self.request_logger.log_outputs( @@ -1200,8 +1203,10 @@ async def chat_completion_full_generator( tool_call_descriptions = [] for tool_call in choice.message.tool_calls: if hasattr(tool_call.function, 'name') and hasattr(tool_call.function, 'arguments'): - tool_call_descriptions.append(f"{tool_call.function.name}({tool_call.function.arguments})") - output_text = f"[tool_calls: {', '.join(tool_call_descriptions)}]" + tool_call_descriptions.append( + f"{tool_call.function.name}({tool_call.function.arguments})") + tool_calls_str = ', '.join(tool_call_descriptions) + output_text = f"[tool_calls: {tool_calls_str}]" if output_text: # Get the corresponding output token IDs From 4a104607f97859d211097812313058a85d86a0ad Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Thu, 10 Jul 2025 15:00:40 +0400 Subject: [PATCH 04/13] Fix line length violation in streaming delta comment Shorten comment from 81 to 71 characters to comply with E501 line length limit. The comment 'Log individual streaming delta if output logging is enabled' was shortened to 'Log streaming delta if output logging is enabled' while maintaining clarity and meaning. 
Signed-off-by: Adrian Garcia --- vllm/entrypoints/openai/serving_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 2af576a2464..ba75dbc6b65 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -825,7 +825,7 @@ async def chat_completion_stream_generator( if delta_message is None: continue - # Log individual streaming delta if output logging is enabled + # Log streaming delta if output logging is enabled if self.enable_log_outputs and self.request_logger: delta_content = "" if delta_message.content: From 7d35afbcc24412eecf56f2308f55b0081fd3bc24 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Fri, 11 Jul 2025 10:21:05 +0400 Subject: [PATCH 05/13] Run isort and ruff locally to fix pre commit hooks issue Signed-off-by: Adrian Garcia --- tests/test_logger.py | 163 ++--- vllm/entrypoints/logger.py | 27 +- vllm/entrypoints/openai/api_server.py | 618 ++++++++++--------- vllm/entrypoints/openai/cli_args.py | 218 ++++--- vllm/entrypoints/openai/serving_chat.py | 466 ++++++++------ vllm/entrypoints/openai/serving_responses.py | 25 +- 6 files changed, 847 insertions(+), 670 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index b15c3f585d2..b8d0a5c2ffd 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -9,8 +9,8 @@ from dataclasses import dataclass from json.decoder import JSONDecodeError from tempfile import NamedTemporaryFile -from typing import Any, Sequence -from unittest.mock import patch, MagicMock +from typing import Any +from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest @@ -138,16 +138,19 @@ def test_an_error_is_raised_when_custom_logging_config_is_invalid_json(): @patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) -@pytest.mark.parametrize("unexpected_config", ( - "Invalid string", - [{ - "version": 1, - "loggers": [] - }], - 0, -)) +@pytest.mark.parametrize( + "unexpected_config", + ( + "Invalid string", + [{ + "version": 1, + "loggers": [] + }], + 0, + ), +) def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json( - unexpected_config: Any): + unexpected_config: Any, ): """This test calls _configure_vllm_root_logger again to test custom logging config behavior, however it fails before any change in behavior or configuration occurs.""" @@ -174,14 +177,16 @@ def test_custom_logging_config_is_parsed_and_used_when_provided(): "propagate": False, } }, - "version": 1 + "version": 1, } with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: logging_config_file.write(json.dumps(valid_logging_config)) logging_config_file.flush() - with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", - logging_config_file.name), patch( - "vllm.logger.dictConfig") as dict_config_mock: + with ( + patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name), + patch("vllm.logger.dictConfig") as dict_config_mock, + ): _configure_vllm_root_logger() dict_config_mock.assert_called_with(valid_logging_config) @@ -197,7 +202,7 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): "handlers": [], } }, - "version": 1 + "version": 1, } with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: logging_config_file.write(json.dumps(valid_logging_config)) @@ -223,21 +228,22 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): def test_prepare_object_to_dump(): - str_obj = 'str' + str_obj = "str" 
assert prepare_object_to_dump(str_obj) == "'str'" list_obj = [1, 2, 3] - assert prepare_object_to_dump(list_obj) == '[1, 2, 3]' + assert prepare_object_to_dump(list_obj) == "[1, 2, 3]" - dict_obj = {'a': 1, 'b': 'b'} + dict_obj = {"a": 1, "b": "b"} assert prepare_object_to_dump(dict_obj) in [ - "{a: 1, b: 'b'}", "{b: 'b', a: 1}" + "{a: 1, b: 'b'}", + "{b: 'b', a: 1}", ] set_obj = {1, 2, 3} - assert prepare_object_to_dump(set_obj) == '[1, 2, 3]' + assert prepare_object_to_dump(set_obj) == "[1, 2, 3]" - tuple_obj = ('a', 'b', 'c') + tuple_obj = ("a", "b", "c") assert prepare_object_to_dump(tuple_obj) == "['a', 'b', 'c']" class CustomEnum(enum.Enum): @@ -253,17 +259,17 @@ class CustomClass: b: str assert (prepare_object_to_dump(CustomClass( - 1, 'b')) == "CustomClass(a=1, b='b')") + 1, "b")) == "CustomClass(a=1, b='b')") def test_request_logger_log_outputs(): """Test the new log_outputs functionality.""" # Create a mock logger to capture log calls mock_logger = MagicMock() - - with patch('vllm.entrypoints.logger.logger', mock_logger): + + with patch("vllm.entrypoints.logger.logger", mock_logger): request_logger = RequestLogger(max_log_len=None) - + # Test basic output logging request_logger.log_outputs( request_id="test-123", @@ -271,24 +277,24 @@ def test_request_logger_log_outputs(): output_token_ids=[1, 2, 3, 4], finish_reason="stop", is_streaming=False, - delta=False + delta=False, ) - + mock_logger.info.assert_called_once() call_args = mock_logger.info.call_args[0] assert "Generated response test-123" in call_args[0] assert "Hello, world!" in call_args[1] - assert [1, 2, 3, 4] == call_args[2] - assert "stop" == call_args[3] + assert call_args[2] == [1, 2, 3, 4] + assert call_args[3] == "stop" def test_request_logger_log_outputs_streaming_delta(): """Test log_outputs with streaming delta mode.""" mock_logger = MagicMock() - - with patch('vllm.entrypoints.logger.logger', mock_logger): + + with patch("vllm.entrypoints.logger.logger", mock_logger): request_logger = RequestLogger(max_log_len=None) - + # Test streaming delta logging request_logger.log_outputs( request_id="test-456", @@ -296,24 +302,24 @@ def test_request_logger_log_outputs_streaming_delta(): output_token_ids=[1], finish_reason=None, is_streaming=True, - delta=True + delta=True, ) - + mock_logger.info.assert_called_once() call_args = mock_logger.info.call_args[0] assert "Generated response test-456 (streaming delta)" in call_args[0] - assert "Hello" == call_args[1] - assert [1] == call_args[2] + assert call_args[1] == "Hello" + assert call_args[2] == [1] assert call_args[3] is None def test_request_logger_log_outputs_streaming_complete(): """Test log_outputs with streaming complete mode.""" mock_logger = MagicMock() - - with patch('vllm.entrypoints.logger.logger', mock_logger): + + with patch("vllm.entrypoints.logger.logger", mock_logger): request_logger = RequestLogger(max_log_len=None) - + # Test streaming complete logging request_logger.log_outputs( request_id="test-789", @@ -321,46 +327,47 @@ def test_request_logger_log_outputs_streaming_complete(): output_token_ids=[1, 2, 3], finish_reason="length", is_streaming=True, - delta=False + delta=False, ) - + mock_logger.info.assert_called_once() call_args = mock_logger.info.call_args[0] - assert "Generated response test-789 (streaming complete)" in call_args[0] - assert "Complete response" == call_args[1] - assert [1, 2, 3] == call_args[2] - assert "length" == call_args[3] + assert ("Generated response test-789 (streaming complete)" + in call_args[0]) + assert call_args[1] == 
"Complete response" + assert call_args[2] == [1, 2, 3] + assert call_args[3] == "length" def test_request_logger_log_outputs_with_truncation(): """Test log_outputs respects max_log_len setting.""" mock_logger = MagicMock() - - with patch('vllm.entrypoints.logger.logger', mock_logger): + + with patch("vllm.entrypoints.logger.logger", mock_logger): # Set max_log_len to 10 request_logger = RequestLogger(max_log_len=10) - + # Test output truncation long_output = "This is a very long output that should be truncated" long_token_ids = list(range(20)) # 20 tokens - + request_logger.log_outputs( request_id="test-truncate", outputs=long_output, output_token_ids=long_token_ids, finish_reason="stop", is_streaming=False, - delta=False + delta=False, ) - + mock_logger.info.assert_called_once() call_args = mock_logger.info.call_args - + # Check that output was truncated to first 10 characters logged_output = call_args[0][1] assert logged_output == "This is a " assert len(logged_output) == 10 - + # Check that token IDs were truncated to first 10 tokens logged_token_ids = call_args[0][2] assert logged_token_ids == list(range(10)) @@ -370,10 +377,10 @@ def test_request_logger_log_outputs_with_truncation(): def test_request_logger_log_outputs_none_values(): """Test log_outputs handles None values correctly.""" mock_logger = MagicMock() - - with patch('vllm.entrypoints.logger.logger', mock_logger): + + with patch("vllm.entrypoints.logger.logger", mock_logger): request_logger = RequestLogger(max_log_len=None) - + # Test with None output_token_ids request_logger.log_outputs( request_id="test-none", @@ -381,24 +388,24 @@ def test_request_logger_log_outputs_none_values(): output_token_ids=None, finish_reason="stop", is_streaming=False, - delta=False + delta=False, ) - + mock_logger.info.assert_called_once() call_args = mock_logger.info.call_args[0] assert "Generated response test-none" in call_args[0] - assert "Test output" == call_args[1] + assert call_args[1] == "Test output" assert call_args[2] is None - assert "stop" == call_args[3] + assert call_args[3] == "stop" def test_request_logger_log_outputs_empty_output(): """Test log_outputs handles empty output correctly.""" mock_logger = MagicMock() - - with patch('vllm.entrypoints.logger.logger', mock_logger): + + with patch("vllm.entrypoints.logger.logger", mock_logger): request_logger = RequestLogger(max_log_len=5) - + # Test with empty output request_logger.log_outputs( request_id="test-empty", @@ -406,24 +413,24 @@ def test_request_logger_log_outputs_empty_output(): output_token_ids=[], finish_reason="stop", is_streaming=False, - delta=False + delta=False, ) - + mock_logger.info.assert_called_once() call_args = mock_logger.info.call_args[0] assert "Generated response test-empty" in call_args[0] - assert "" == call_args[1] - assert [] == call_args[2] - assert "stop" == call_args[3] + assert call_args[1] == "" + assert call_args[2] == [] + assert call_args[3] == "stop" def test_request_logger_log_outputs_integration(): """Test that log_outputs can be called alongside log_inputs.""" mock_logger = MagicMock() - - with patch('vllm.entrypoints.logger.logger', mock_logger): + + with patch("vllm.entrypoints.logger.logger", mock_logger): request_logger = RequestLogger(max_log_len=None) - + # Test that both methods can be called without interference request_logger.log_inputs( request_id="test-integration", @@ -432,24 +439,24 @@ def test_request_logger_log_outputs_integration(): prompt_embeds=None, params=None, lora_request=None, - prompt_adapter_request=None + 
prompt_adapter_request=None, ) - + request_logger.log_outputs( request_id="test-integration", outputs="Test output", output_token_ids=[4, 5, 6], finish_reason="stop", is_streaming=False, - delta=False + delta=False, ) - + # Should have been called twice - once for inputs, once for outputs assert mock_logger.info.call_count == 2 - + # Check that the calls were made with correct patterns input_call = mock_logger.info.call_args_list[0][0] output_call = mock_logger.info.call_args_list[1][0] - + assert "Received request test-integration" in input_call[0] assert "Generated response test-integration" in output_call[0] diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index 1b87113d12b..6c2acbacd39 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union, Sequence +from collections.abc import Sequence +from typing import Optional, Union import torch @@ -44,10 +45,15 @@ def log_inputs( "Received request %s: prompt: %r, " "params: %s, prompt_token_ids: %s, " "prompt_embeds shape: %s, " - "lora_request: %s, prompt_adapter_request: %s.", request_id, - prompt, params, prompt_token_ids, + "lora_request: %s, prompt_adapter_request: %s.", + request_id, + prompt, + params, + prompt_token_ids, prompt_embeds.shape if prompt_embeds is not None else None, - lora_request, prompt_adapter_request) + lora_request, + prompt_adapter_request, + ) def log_outputs( self, @@ -69,10 +75,15 @@ def log_outputs( stream_info = "" if is_streaming: - stream_info = (" (streaming delta)" if delta else - " (streaming complete)") + stream_info = (" (streaming delta)" + if delta else " (streaming complete)") logger.info( "Generated response %s%s: output: %r, " - "output_token_ids: %s, finish_reason: %s", - request_id, stream_info, outputs, output_token_ids, finish_reason) + "output_token_ids: %s, finish_reason: %s", + request_id, + stream_info, + outputs, + output_token_ids, + finish_reason, + ) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0bec3492eb3..1f99c3f2eb8 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -112,7 +112,7 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) -logger = init_logger('vllm.entrypoints.openai.api_server') +logger = init_logger("vllm.entrypoints.openai.api_server") _running_tasks: set[asyncio.Task] = set() @@ -125,7 +125,7 @@ async def lifespan(app: FastAPI): async def _force_log(): while True: - await asyncio.sleep(10.) 
+ await asyncio.sleep(10.0) await engine_client.do_log_stats() task = asyncio.create_task(_force_log()) @@ -153,7 +153,6 @@ async def build_async_engine_client( args: Namespace, client_config: Optional[dict[str, Any]] = None, ) -> AsyncIterator[EngineClient]: - # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit engine_args = AsyncEngineArgs.from_cli_args(args) @@ -190,6 +189,7 @@ async def build_async_engine_client_from_engine_args( "To disable frontend multiprocessing, set VLLM_USE_V1=0.") from vllm.v1.engine.async_llm import AsyncLLM + async_llm: Optional[AsyncLLM] = None client_index = client_config.pop( "client_index") if client_config else 0 @@ -200,7 +200,8 @@ async def build_async_engine_client_from_engine_args( disable_log_requests=engine_args.disable_log_requests, disable_log_stats=engine_args.disable_log_stats, client_addresses=client_config, - client_index=client_index) + client_index=client_index, + ) # Don't keep the dummy data in memory await async_llm.reset_mm_cache() @@ -213,14 +214,14 @@ async def build_async_engine_client_from_engine_args( # V0 AsyncLLM. elif (MQLLMEngineClient.is_unsupported_config(vllm_config) or disable_frontend_multiprocessing): - engine_client: Optional[EngineClient] = None try: engine_client = AsyncLLMEngine.from_vllm_config( vllm_config=vllm_config, usage_context=usage_context, disable_log_requests=engine_args.disable_log_requests, - disable_log_stats=engine_args.disable_log_stats) + disable_log_stats=engine_args.disable_log_stats, + ) yield engine_client finally: if engine_client and hasattr(engine_client, "shutdown"): @@ -234,8 +235,8 @@ async def build_async_engine_client_from_engine_args( # cleaned up upon exit. global prometheus_multiproc_dir prometheus_multiproc_dir = tempfile.TemporaryDirectory() - os.environ[ - "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + os.environ["PROMETHEUS_MULTIPROC_DIR"] = ( + prometheus_multiproc_dir.name) else: logger.warning( "Found PROMETHEUS_MULTIPROC_DIR was set by user. " @@ -259,12 +260,18 @@ async def build_async_engine_client_from_engine_args( # The Process can raise an exception during startup, which may # not actually result in an exitcode being reported. As a result # we use a shared variable to communicate the information. - engine_alive = multiprocessing.Value('b', True, lock=False) + engine_alive = multiprocessing.Value("b", True, lock=False) engine_process = context.Process( target=run_mp_engine, - args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path, - engine_args.disable_log_stats, - engine_args.disable_log_requests, engine_alive)) + args=( + vllm_config, + UsageContext.OPENAI_API_SERVER, + ipc_path, + engine_args.disable_log_stats, + engine_args.disable_log_requests, + engine_alive, + ), + ) engine_process.start() engine_pid = engine_process.pid assert engine_pid is not None, "Engine process failed to start." @@ -289,8 +296,7 @@ def _cleanup_ipc_path(): await mq_engine_client.setup() break except TimeoutError: - if (not engine_process.is_alive() - or not engine_alive.value): + if not engine_process.is_alive() or not engine_alive.value: raise RuntimeError( "Engine process failed to start. See stack " "trace for the root cause.") from None @@ -314,6 +320,7 @@ def _cleanup_ipc_path(): # before prometheus_client is imported. 
# See https://prometheus.github.io/client_python/multiprocess/ from prometheus_client import multiprocess + multiprocess.mark_process_dead(engine_process.pid) @@ -442,7 +449,7 @@ async def get_server_load_metrics(request: Request): # - /v1/rerank # - /v2/rerank return JSONResponse( - content={'server_load': request.app.state.server_load_metrics}) + content={"server_load": request.app.state.server_load_metrics}) @router.get("/ping", response_class=Response) @@ -452,22 +459,24 @@ async def ping(raw_request: Request) -> Response: return await health(raw_request) -@router.post("/tokenize", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_IMPLEMENTED.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/tokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_IMPLEMENTED.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -490,19 +499,21 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): assert_never(generator) -@router.post("/detokenize", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/detokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation async def detokenize(request: DetokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -538,24 +549,26 @@ async def show_version(): return JSONResponse(content=ver) -@router.post("/v1/responses", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/responses", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation async def create_responses(request: ResponsesRequest, raw_request: Request): handler = responses(raw_request) @@ -603,24 +616,26 @@ async def cancel_responses(response_id: str, raw_request: Request): return JSONResponse(content=response.model_dump()) -@router.post("/v1/chat/completions", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - 
HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - } - }) +@router.post( + "/v1/chat/completions", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation @load_aware_call async def create_chat_completion(request: ChatCompletionRequest, @@ -642,24 +657,26 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/completions", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/completions", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation @load_aware_call async def create_completion(request: CompletionRequest, raw_request: Request): @@ -686,16 +703,18 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/embeddings", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/embeddings", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation @load_aware_call async def create_embedding(request: EmbeddingRequest, raw_request: Request): @@ -715,16 +734,18 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): assert_never(generator) -@router.post("/pooling", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/pooling", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation @load_aware_call async def create_pooling(request: PoolingRequest, raw_request: Request): @@ -764,16 +785,18 @@ async def create_classify(request: ClassificationRequest, assert_never(generator) -@router.post("/score", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + 
"/score", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation @load_aware_call async def create_score(request: ScoreRequest, raw_request: Request): @@ -792,16 +815,18 @@ async def create_score(request: ScoreRequest, raw_request: Request): assert_never(generator) -@router.post("/v1/score", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/score", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation @load_aware_call async def create_score_v1(request: ScoreRequest, raw_request: Request): @@ -812,23 +837,25 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) -@router.post("/v1/audio/transcriptions", - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.UNPROCESSABLE_ENTITY.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/audio/transcriptions", + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.UNPROCESSABLE_ENTITY.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation @load_aware_call async def create_transcriptions(raw_request: Request, @@ -853,23 +880,25 @@ async def create_transcriptions(raw_request: Request, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/audio/translations", - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.UNPROCESSABLE_ENTITY.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/audio/translations", + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.UNPROCESSABLE_ENTITY.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation @load_aware_call async def create_translations(request: Annotated[TranslationRequest, @@ -894,16 +923,18 @@ async def create_translations(request: Annotated[TranslationRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/rerank", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) 
@with_cancellation @load_aware_call async def do_rerank(request: RerankRequest, raw_request: Request): @@ -921,16 +952,18 @@ async def do_rerank(request: RerankRequest, raw_request: Request): assert_never(generator) -@router.post("/v1/rerank", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation async def do_rerank_v1(request: RerankRequest, raw_request: Request): logger.warning_once( @@ -941,16 +974,18 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) -@router.post("/v2/rerank", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v2/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) @with_cancellation async def do_rerank_v2(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) @@ -1032,19 +1067,21 @@ async def is_sleeping(raw_request: Request): return JSONResponse(content={"is_sleeping": is_sleeping}) -@router.post("/invocations", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/invocations", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }, +) async def invocations(raw_request: Request): """ For SageMaker, routes requests to other handlers based on model `task`. @@ -1052,8 +1089,10 @@ async def invocations(raw_request: Request): try: body = await raw_request.json() except json.JSONDecodeError as e: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}") from e + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e task = raw_request.app.state.task @@ -1061,7 +1100,8 @@ async def invocations(raw_request: Request): raise HTTPException( status_code=400, detail=f"Unsupported task: '{task}' for '/invocations'. 
" - f"Expected one of {set(TASK_HANDLERS.keys())}") + f"Expected one of {set(TASK_HANDLERS.keys())}", + ) handler_config = TASK_HANDLERS[task] if "messages" in body: @@ -1131,8 +1171,11 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]: with open(log_config_file) as f: return json.load(f) except Exception as e: - logger.warning("Failed to load log config from file %s: error %s", - log_config_file, e) + logger.warning( + "Failed to load log config from file %s: error %s", + log_config_file, + e, + ) return None @@ -1154,8 +1197,8 @@ def __init__(self, app: ASGIApp, api_token: str) -> None: def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: - if scope["type"] not in ("http", - "websocket") or scope["method"] == "OPTIONS": + if (scope["type"] not in ("http", "websocket") + or scope["method"] == "OPTIONS"): # scope["type"] can be "lifespan" or "startup" for example, # in which case we don't need to do anything return self.app(scope, receive, send) @@ -1163,8 +1206,8 @@ def __call__(self, scope: Scope, receive: Receive, url_path = URL(scope=scope).path.removeprefix(root_path) headers = Headers(scope=scope) # Type narrow to satisfy mypy. - if url_path.startswith("/v1") and headers.get( - "Authorization") != f"Bearer {self.api_token}": + if (url_path.startswith("/v1") and headers.get("Authorization") + != f"Bearer {self.api_token}"): response = JSONResponse(content={"error": "Unauthorized"}, status_code=401) return response(scope, receive, send) @@ -1211,25 +1254,25 @@ def _extract_content_from_chunk(chunk_data: dict) -> str: ChatCompletionStreamResponse, CompletionStreamResponse) # Try using Completion types for type-safe parsing - if chunk_data.get('object') == 'chat.completion.chunk': + if chunk_data.get("object") == "chat.completion.chunk": chat_response = ChatCompletionStreamResponse.model_validate( chunk_data) if chat_response.choices and chat_response.choices[0].delta.content: return chat_response.choices[0].delta.content - elif chunk_data.get('object') == 'text_completion': + elif chunk_data.get("object") == "text_completion": completion_response = CompletionStreamResponse.model_validate( chunk_data) - if completion_response.choices and completion_response.choices[ - 0].text: + if (completion_response.choices + and completion_response.choices[0].text): return completion_response.choices[0].text except pydantic.ValidationError: # Fallback to manual parsing - if 'choices' in chunk_data and chunk_data['choices']: - choice = chunk_data['choices'][0] - if 'delta' in choice and choice['delta'].get('content'): - return choice['delta']['content'] - elif choice.get('text'): - return choice['text'] + if "choices" in chunk_data and chunk_data["choices"]: + choice = chunk_data["choices"][0] + if "delta" in choice and choice["delta"].get("content"): + return choice["delta"]["content"] + elif choice.get("text"): + return choice["text"] return "" @@ -1245,7 +1288,7 @@ def decode_chunk(self, chunk: bytes) -> list[dict]: import json try: - chunk_str = chunk.decode('utf-8') + chunk_str = chunk.decode("utf-8") except UnicodeDecodeError: # Skip malformed chunks return [] @@ -1254,18 +1297,18 @@ def decode_chunk(self, chunk: bytes) -> list[dict]: events = [] # Process complete lines - while '\n' in self.buffer: - line, self.buffer = self.buffer.split('\n', 1) - line = line.rstrip('\r') # Handle CRLF + while "\n" in self.buffer: + line, self.buffer = self.buffer.split("\n", 1) + line = line.rstrip("\r") # Handle CRLF - if line.startswith('data: '): + if 
line.startswith("data: "): data_str = line[6:].strip() - if data_str == '[DONE]': - events.append({'type': 'done'}) + if data_str == "[DONE]": + events.append({"type": "done"}) elif data_str: try: event_data = json.loads(data_str) - events.append({'type': 'data', 'data': event_data}) + events.append({"type": "data", "data": event_data}) except json.JSONDecodeError: # Skip malformed JSON continue @@ -1283,7 +1326,7 @@ def add_content(self, content: str) -> None: def get_complete_content(self) -> str: """Get the complete buffered content.""" - return ''.join(self.content_buffer) + return "".join(self.content_buffer) def _log_streaming_response(response, response_body: list) -> None: @@ -1304,10 +1347,10 @@ def buffered_iterator(): events = sse_decoder.decode_chunk(chunk) for event in events: - if event['type'] == 'data': - content = sse_decoder.extract_content(event['data']) + if event["type"] == "data": + content = sse_decoder.extract_content(event["data"]) sse_decoder.add_content(content) - elif event['type'] == 'done': + elif event["type"] == "done": # Log complete content when done full_content = sse_decoder.get_complete_content() if full_content: @@ -1316,14 +1359,17 @@ def buffered_iterator(): full_content = full_content[:2048] + "" "...[truncated]" logger.info( - "response_body={streaming_complete: " \ + "response_body={streaming_complete: " "content='%s', chunks=%d}", - full_content, chunk_count) + full_content, + chunk_count, + ) else: logger.info( - "response_body={streaming_complete: " \ + "response_body={streaming_complete: " "no_content, chunks=%d}", - chunk_count) + chunk_count, + ) return response.body_iterator = iterate_in_threadpool(buffered_iterator()) @@ -1363,9 +1409,11 @@ def build_app(args: Namespace) -> FastAPI: @app.exception_handler(HTTPException) async def http_exception_handler(_: Request, exc: HTTPException): - err = ErrorResponse(message=exc.detail, - type=HTTPStatus(exc.status_code).phrase, - code=exc.status_code) + err = ErrorResponse( + message=exc.detail, + type=HTTPStatus(exc.status_code).phrase, + code=exc.status_code, + ) return JSONResponse(err.model_dump(), status_code=exc.status_code) @app.exception_handler(RequestValidationError) @@ -1379,9 +1427,11 @@ async def validation_exception_handler(_: Request, else: message = exc_str - err = ErrorResponse(message=message, - type=HTTPStatus.BAD_REQUEST.phrase, - code=HTTPStatus.BAD_REQUEST) + err = ErrorResponse( + message=message, + type=HTTPStatus.BAD_REQUEST.phrase, + code=HTTPStatus.BAD_REQUEST, + ) return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) @@ -1479,7 +1529,9 @@ async def init_app_state( "Using supplied chat template: %s\n" "It is different from official chat template '%s'. 
" "This discrepancy may lead to performance degradation.", - resolved_chat_template, args.model) + resolved_chat_template, + args.model, + ) state.openai_serving_models = OpenAIServingModels( engine_client=engine_client, @@ -1489,7 +1541,7 @@ async def init_app_state( prompt_adapters=args.prompt_adapters, ) await state.openai_serving_models.init_static_loras() - state.openai_serving_responses = OpenAIServingResponses( + state.openai_serving_responses = (OpenAIServingResponses( engine_client, model_config, state.openai_serving_models, @@ -1505,8 +1557,8 @@ async def init_app_state( enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, enable_log_outputs=args.enable_log_outputs, - ) if model_config.runner_type == "generate" else None - state.openai_serving_chat = OpenAIServingChat( + ) if model_config.runner_type == "generate" else None) + state.openai_serving_chat = (OpenAIServingChat( engine_client, model_config, state.openai_serving_models, @@ -1523,46 +1575,46 @@ async def init_app_state( enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, enable_log_outputs=args.enable_log_outputs, - ) if model_config.runner_type == "generate" else None - state.openai_serving_completion = OpenAIServingCompletion( + ) if model_config.runner_type == "generate" else None) + state.openai_serving_completion = (OpenAIServingCompletion( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_force_include_usage=args.enable_force_include_usage, - ) if model_config.runner_type == "generate" else None - state.openai_serving_pooling = OpenAIServingPooling( + ) if model_config.runner_type == "generate" else None) + state.openai_serving_pooling = (OpenAIServingPooling( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, - ) if model_config.runner_type == "pooling" else None - state.openai_serving_embedding = OpenAIServingEmbedding( + ) if model_config.runner_type == "pooling" else None) + state.openai_serving_embedding = (OpenAIServingEmbedding( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, - ) if model_config.task == "embed" else None - state.openai_serving_classification = ServingClassification( + ) if model_config.task == "embed" else None) + state.openai_serving_classification = (ServingClassification( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, - ) if model_config.task == "classify" else None + ) if model_config.task == "classify" else None) enable_serving_reranking = (model_config.task == "classify" and getattr( model_config.hf_config, "num_labels", 0) == 1) - state.openai_serving_scores = ServingScores( + state.openai_serving_scores = (ServingScores( engine_client, model_config, state.openai_serving_models, - request_logger=request_logger) if ( - model_config.task == "embed" or enable_serving_reranking) else None + request_logger=request_logger, + ) if (model_config.task == "embed" or enable_serving_reranking) else None) state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, @@ -1572,18 +1624,18 @@ async def init_app_state( 
chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) - state.openai_serving_transcription = OpenAIServingTranscription( + state.openai_serving_transcription = (OpenAIServingTranscription( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, - ) if model_config.runner_type == "transcription" else None - state.openai_serving_translation = OpenAIServingTranslation( + ) if model_config.runner_type == "transcription" else None) + state.openai_serving_translation = (OpenAIServingTranslation( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, - ) if model_config.runner_type == "transcription" else None + ) if model_config.runner_type == "transcription" else None) state.task = model_config.task state.enable_server_load_tracking = args.enable_server_load_tracking @@ -1605,14 +1657,14 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket: def validate_api_server_args(args): valid_tool_parses = ToolParserManager.tool_parsers.keys() - if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valid_tool_parses: + if (args.enable_auto_tool_choice + and args.tool_call_parser not in valid_tool_parses): raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " f"(chose from {{ {','.join(valid_tool_parses)} }})") valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys() - if args.reasoning_parser \ - and args.reasoning_parser not in valid_reasoning_parses: + if (args.reasoning_parser + and args.reasoning_parser not in valid_reasoning_parses): raise KeyError( f"invalid reasoning parser: {args.reasoning_parser} " f"(chose from {{ {','.join(valid_reasoning_parses)} }})") @@ -1648,8 +1700,8 @@ def signal_handler(*_) -> None: addr, port = sock_addr is_ssl = args.ssl_keyfile and args.ssl_certfile - host_part = f"[{addr}]" if is_valid_ipv6_address( - addr) else addr or "0.0.0.0" + host_part = (f"[{addr}]" + if is_valid_ipv6_address(addr) else addr or "0.0.0.0") listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" return listen_address, sock @@ -1676,7 +1728,7 @@ async def run_server_worker(listen_address, # Load logging config for uvicorn if specified log_config = load_log_config(args.log_config_file) if log_config is not None: - uvicorn_kwargs['log_config'] = log_config + uvicorn_kwargs["log_config"] = log_config async with build_async_engine_client(args, client_config) as engine_client: app = build_app(args) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index b05efa6fd0f..1097627eea4 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -41,10 +41,10 @@ def __call__( lora_list: list[LoRAModulePath] = [] for item in values: - if item in [None, '']: # Skip if item is None or empty string + if item in [None, ""]: # Skip if item is None or empty string continue - if '=' in item and ',' not in item: # Old format: name=path - name, path = item.split('=') + if "=" in item and "," not in item: # Old format: name=path + name, path = item.split("=") lora_list.append(LoRAModulePath(name, path)) else: # Assume JSON format try: @@ -77,7 +77,7 @@ def __call__( adapter_list: list[PromptAdapterPath] = [] for item in values: - name, path = item.split('=') + name, path = item.split("=") adapter_list.append(PromptAdapterPath(name, path)) setattr(namespace, self.dest, adapter_list) @@ -92,102 +92,128 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> 
FlexibleArgumentParser: "--uvicorn-log-level", type=str, default="info", - choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'], - help="Log level for uvicorn.") - parser.add_argument("--disable-uvicorn-access-log", - action="store_true", - help="Disable uvicorn access log.") + choices=["debug", "info", "warning", "error", "critical", "trace"], + help="Log level for uvicorn.", + ) + parser.add_argument( + "--disable-uvicorn-access-log", + action="store_true", + help="Disable uvicorn access log.", + ) parser.add_argument("--allow-credentials", action="store_true", help="Allow credentials.") - parser.add_argument("--allowed-origins", - type=json.loads, - default=["*"], - help="Allowed origins.") - parser.add_argument("--allowed-methods", - type=json.loads, - default=["*"], - help="Allowed methods.") - parser.add_argument("--allowed-headers", - type=json.loads, - default=["*"], - help="Allowed headers.") - parser.add_argument("--api-key", - type=optional_type(str), - default=None, - help="If provided, the server will require this key " - "to be presented in the header.") + parser.add_argument( + "--allowed-origins", + type=json.loads, + default=["*"], + help="Allowed origins.", + ) + parser.add_argument( + "--allowed-methods", + type=json.loads, + default=["*"], + help="Allowed methods.", + ) + parser.add_argument( + "--allowed-headers", + type=json.loads, + default=["*"], + help="Allowed headers.", + ) + parser.add_argument( + "--api-key", + type=optional_type(str), + default=None, + help="If provided, the server will require this key " + "to be presented in the header.", + ) parser.add_argument( "--lora-modules", type=optional_type(str), default=None, - nargs='+', + nargs="+", action=LoRAParserAction, help="LoRA module configurations in either 'name=path' format" "or JSON format. " "Example (old format): ``'name=path'`` " "Example (new format): " - "``{\"name\": \"name\", \"path\": \"lora_path\", " - "\"base_model_name\": \"id\"}``") + '``{"name": "name", "path": "lora_path", ' + '"base_model_name": "id"}``', + ) parser.add_argument( "--prompt-adapters", type=optional_type(str), default=None, - nargs='+', + nargs="+", action=PromptAdapterParserAction, help="Prompt adapter configurations in the format name=path. " - "Multiple adapters can be specified.") - parser.add_argument("--chat-template", - type=optional_type(str), - default=None, - help="The file path to the chat template, " - "or the template in single-line form " - "for the specified model.") + "Multiple adapters can be specified.", + ) parser.add_argument( - '--chat-template-content-format', + "--chat-template", + type=optional_type(str), + default=None, + help="The file path to the chat template, " + "or the template in single-line form " + "for the specified model.", + ) + parser.add_argument( + "--chat-template-content-format", type=str, default="auto", choices=get_args(ChatTemplateContentFormatOption), - help='The format to render message content within a chat template.' - '\n\n' + help="The format to render message content within a chat template." + "\n\n" '* "string" will render the content as a string. ' 'Example: ``"Hello World"``\n' '* "openai" will render the content as a list of dictionaries, ' - 'similar to OpenAI schema. 
' - 'Example: ``[{"type": "text", "text": "Hello world!"}]``') - parser.add_argument("--response-role", - type=optional_type(str), - default="assistant", - help="The role name to return if " - "``request.add_generation_prompt=true``.") - parser.add_argument("--ssl-keyfile", - type=optional_type(str), - default=None, - help="The file path to the SSL key file.") - parser.add_argument("--ssl-certfile", - type=optional_type(str), - default=None, - help="The file path to the SSL cert file.") - parser.add_argument("--ssl-ca-certs", - type=optional_type(str), - default=None, - help="The CA certificates file.") + "similar to OpenAI schema. " + 'Example: ``[{"type": "text", "text": "Hello world!"}]``', + ) + parser.add_argument( + "--response-role", + type=optional_type(str), + default="assistant", + help="The role name to return if " + "``request.add_generation_prompt=true``.", + ) + parser.add_argument( + "--ssl-keyfile", + type=optional_type(str), + default=None, + help="The file path to the SSL key file.", + ) + parser.add_argument( + "--ssl-certfile", + type=optional_type(str), + default=None, + help="The file path to the SSL cert file.", + ) + parser.add_argument( + "--ssl-ca-certs", + type=optional_type(str), + default=None, + help="The CA certificates file.", + ) parser.add_argument( "--enable-ssl-refresh", action="store_true", default=False, - help="Refresh SSL Context when SSL certificate files change") + help="Refresh SSL Context when SSL certificate files change", + ) parser.add_argument( "--ssl-cert-reqs", type=int, default=int(ssl.CERT_NONE), - help="Whether client certificate is required (see stdlib ssl module's)." + help= + "Whether client certificate is required (see stdlib ssl module's).", ) parser.add_argument( "--root-path", type=optional_type(str), default=None, - help="FastAPI root_path when app is behind a path based routing proxy." + help="FastAPI root_path when app is behind a path based routing proxy.", ) parser.add_argument( "--middleware", @@ -200,29 +226,34 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "If a function is provided, vLLM will add it to the server " "using ``@app.middleware('http')``. " "If a class is provided, vLLM will add it to the server " - "using ``app.add_middleware()``. ") + "using ``app.add_middleware()``. ", + ) parser.add_argument( "--return-tokens-as-token-ids", action="store_true", help="When ``--max-logprobs`` is specified, represents single tokens " " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified.") + "that are not JSON-encodable can be identified.", + ) parser.add_argument( "--disable-frontend-multiprocessing", action="store_true", help="If specified, will run the OpenAI frontend server in the same " - "process as the model serving engine.") + "process as the model serving engine.", + ) parser.add_argument( "--enable-request-id-headers", action="store_true", help="If specified, API server will add X-Request-Id header to " - "responses.") + "responses.", + ) parser.add_argument( "--enable-auto-tool-choice", action="store_true", default=False, help="Enable auto tool choice for supported models. 
Use " - "``--tool-call-parser`` to specify which parser to use.") + "``--tool-call-parser`` to specify which parser to use.", + ) parser.add_argument( "--expand-tools-even-if-tool-choice-none", action="store_true", @@ -233,7 +264,8 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "This is a transitional option that will be removed in v0.10.0. " "In v0.10.0, tool definitions will always be included regardless of " "tool_choice setting. Use this flag now to test the new behavior " - "before the breaking change.") + "before the breaking change.", + ) valid_tool_parsers = ToolParserManager.tool_parsers.keys() parser.add_argument( @@ -245,7 +277,8 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help= "Select the tool call parser depending on the model that you're using." " This is used to parse the model-generated tool call into OpenAI API " - "format. Required for ``--enable-auto-tool-choice``.") + "format. Required for ``--enable-auto-tool-choice``.", + ) parser.add_argument( "--tool-parser-plugin", @@ -254,7 +287,8 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help= "Special the tool parser plugin write to parse the model-generated tool" " into OpenAI API format, the name register in this plugin can be used " - "in ``--tool-call-parser``.") + "in ``--tool-call-parser``.", + ) parser.add_argument( "--log-config-file", @@ -265,43 +299,47 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser = AsyncEngineArgs.add_cli_args(parser) - parser.add_argument('--max-log-len', - type=int, - default=None, - help='Max number of prompt characters or prompt ' - 'ID numbers being printed in log.' - ' The default of None means unlimited.') + parser.add_argument( + "--max-log-len", + type=int, + default=None, + help="Max number of prompt characters or prompt " + "ID numbers being printed in log." + " The default of None means unlimited.", + ) parser.add_argument( "--disable-fastapi-docs", - action='store_true', + action="store_true", default=False, - help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint." + help= + "Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint.", ) parser.add_argument( "--enable-prompt-tokens-details", - action='store_true', + action="store_true", default=False, - help="If set to True, enable prompt_tokens_details in usage.") + help="If set to True, enable prompt_tokens_details in usage.", + ) parser.add_argument( "--enable-force-include-usage", - action='store_true', + action="store_true", default=False, - help="If set to True, including usage on every request.") + help="If set to True, including usage on every request.", + ) parser.add_argument( "--enable-server-load-tracking", - action='store_true', + action="store_true", default=False, help= - "If set to True, enable tracking server_load_metrics in the app state." + "If set to True, enable tracking server_load_metrics in the app state.", ) parser.add_argument( "--enable-log-outputs", - action='store_true', + action="store_true", default=False, - help= - "If set to True, enable logging of model outputs (generations) " - "in addition to the input logging that is enabled by default." 
+ help="If set to True, enable logging of model outputs (generations) " + "in addition to the input logging that is enabled by default.", ) return parser @@ -317,8 +355,8 @@ def validate_parsed_serve_args(args: argparse.Namespace): # Enable auto tool needs a tool call parser to be valid if args.enable_auto_tool_choice and not args.tool_call_parser: - raise TypeError("Error: --enable-auto-tool-choice requires " - "--tool-call-parser") + raise TypeError( + "Error: --enable-auto-tool-choice requires --tool-call-parser") if args.enable_prompt_embeds and args.enable_prompt_adapter: raise ValueError( "Cannot use prompt embeds and prompt adapter at the same time.") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index ba75dbc6b65..b910dcbb753 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -69,12 +69,14 @@ def __init__( enable_force_include_usage: bool = False, enable_log_outputs: bool = False, ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage) + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + enable_force_include_usage=enable_force_include_usage, + ) self.response_role = response_role self.chat_template = chat_template @@ -85,7 +87,7 @@ def __init__( self.enable_auto_tools: bool = enable_auto_tools if self.enable_auto_tools: logger.info( - "\"auto\" tool choice has been enabled please note that while" + '"auto" tool choice has been enabled please note that while' " the parallel_tool_calls client option is preset for " "compatibility reasons, it will be ignored.") @@ -103,8 +105,8 @@ def __init__( self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None if self.enable_auto_tools: try: - if (tool_parser == "pythonic" and - model_config.model.startswith("meta-llama/Llama-3.2")): + if tool_parser == "pythonic" and model_config.model.startswith( + "meta-llama/Llama-3.2"): logger.warning( "Llama3.2 models may struggle to emit valid pythonic" " tool calls") @@ -124,8 +126,11 @@ def __init__( if self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source - logger.info("Using default chat sampling params from %s: %s", - source, self.default_sampling_params) + logger.info( + "Using default chat sampling params from %s: %s", + source, + self.default_sampling_params, + ) async def create_chat_completion( self, @@ -177,7 +182,7 @@ async def create_chat_completion( # for hf tokenizers, "auto" tools requires # --enable-auto-tool-choice and --tool-call-parser return self.create_error_response( - "\"auto\" tool choice requires " + '"auto" tool choice requires ' "--enable-auto-tool-choice and --tool-call-parser to be set" ) @@ -224,8 +229,9 @@ async def create_chat_completion( logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") - request_id = "chatcmpl-" \ - f"{self._base_request_id(raw_request, request.request_id)}" + request_id = ( + f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}" + ) request_metadata = RequestResponseMetadata(request_id=request_id) if raw_request: @@ -244,21 +250,26 @@ async def create_chat_completion( 
max_model_len=self.max_model_len, request=request, input_length=len(engine_prompt["prompt_token_ids"]), - default_sampling_params=self.default_sampling_params) + default_sampling_params=self.default_sampling_params, + ) if request.use_beam_search: sampling_params = request.to_beam_search_params( max_tokens, self.default_sampling_params) else: sampling_params = request.to_sampling_params( - max_tokens, self.model_config.logits_processor_pattern, - self.default_sampling_params) + max_tokens, + self.model_config.logits_processor_pattern, + self.default_sampling_params, + ) - self._log_inputs(request_id, - request_prompts[i], - params=sampling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + self._log_inputs( + request_id, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) @@ -287,7 +298,7 @@ async def create_chat_completion( return self.create_error_response(str(e)) assert len(generators) == 1 - result_generator, = generators + (result_generator, ) = generators # Streaming response if request.stream: @@ -299,12 +310,19 @@ async def create_chat_completion( conversation, tokenizer, request_metadata, - enable_force_include_usage=self.enable_force_include_usage) + enable_force_include_usage=self.enable_force_include_usage, + ) try: return await self.chat_completion_full_generator( - request, result_generator, request_id, model_name, - conversation, tokenizer, request_metadata) + request, + result_generator, + request_id, + model_name, + conversation, + tokenizer, + request_metadata, + ) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -315,7 +333,7 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: return request.messages[-1]["role"] @staticmethod - def _bracket_level(s: str, opening='{', closing='}') -> int: + def _bracket_level(s: str, opening="{", closing="}") -> int: """ Calculate the current level of nested brackets in a given string. 
""" @@ -339,10 +357,10 @@ def _filter_delta_text(delta_text: str, bracket_level = OpenAIServingChat._bracket_level(previous_text) updated_delta, passed_zero = "", False for c in delta_text: - if c == '{': + if c == "{": bracket_level += 1 passed_zero = bracket_level == 0 - elif c == '}': + elif c == "}": bracket_level -= 1 passed_zero = bracket_level == 0 @@ -350,7 +368,7 @@ def _filter_delta_text(delta_text: str, updated_delta += c else: # if a comma is reached at level 0 we can stop - if c == ',': + if c == ",": break return updated_delta, passed_zero @@ -367,7 +385,7 @@ def extract_tool_call_required_streaming( try: obj = partial_json_parser.loads(current_text) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") obj = None # check if the current text is a valid array @@ -406,12 +424,15 @@ def extract_tool_call_required_streaming( function_name_returned = True delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(id=random_tool_call_id(), - function=DeltaFunctionCall( - name=current_tool_call["name"], - arguments=arguments), - index=len(obj) - 1, - type="function") + DeltaToolCall( + id=random_tool_call_id(), + function=DeltaFunctionCall( + name=current_tool_call["name"], + arguments=arguments, + ), + index=len(obj) - 1, + type="function", + ) ]) else: @@ -425,8 +446,10 @@ def extract_tool_call_required_streaming( # OpenAI API returns None # instead of name every time name=None, - arguments=delta_text), - index=len(obj) - 1) + arguments=delta_text, + ), + index=len(obj) - 1, + ) ]) else: delta_message = None @@ -509,10 +532,10 @@ async def chat_completion_stream_generator( stream_options = request.stream_options if stream_options: - include_usage = stream_options.include_usage \ - or enable_force_include_usage - include_continuous_usage = include_usage and \ - stream_options.continuous_usage_stats + include_usage = (stream_options.include_usage + or enable_force_include_usage) + include_continuous_usage = (include_usage and + stream_options.continuous_usage_stats) else: include_usage, include_continuous_usage = False, False @@ -542,20 +565,23 @@ async def chat_completion_stream_generator( content="", ), logprobs=None, - finish_reason=None) + finish_reason=None, + ) chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name) + model=model_name, + ) # if continuous usage stats are requested, add it if include_continuous_usage: chunk.usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=0, - total_tokens=num_prompt_tokens) + total_tokens=num_prompt_tokens, + ) data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" @@ -564,8 +590,8 @@ async def chat_completion_stream_generator( # last message if request.echo: last_msg_content: Union[str, list[dict[str, str]]] = "" - if conversation and "content" in conversation[ - -1] and conversation[-1].get("role") == role: + if (conversation and "content" in conversation[-1] + and conversation[-1].get("role") == role): last_msg_content = conversation[-1]["content"] or "" if last_msg_content: @@ -576,18 +602,21 @@ async def chat_completion_stream_generator( delta=DeltaMessage( content=last_msg_content), logprobs=None, - finish_reason=None)) + finish_reason=None, + )) chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name) + 
model=model_name, + ) if include_continuous_usage: chunk.usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=0, - total_tokens=num_prompt_tokens) + total_tokens=num_prompt_tokens, + ) data = chunk.model_dump_json( exclude_unset=True) @@ -617,8 +646,8 @@ async def chat_completion_stream_generator( delta_text = output.text - if not delta_text and not output.token_ids and \ - not previous_num_tokens[i]: + if (not delta_text and not output.token_ids + and not previous_num_tokens[i]): # Chunked prefill case, don't return empty chunks continue @@ -640,16 +669,14 @@ async def chat_completion_stream_generator( and not reasoning_parser.is_reasoning_end( previous_token_ids)): assert reasoning_parser is not None - delta_message = ( - reasoning_parser. - extract_reasoning_content_streaming( - previous_text, - current_text, - delta_text, - previous_token_ids, - current_token_ids, - output.token_ids, - )) + delta_message = reasoning_parser.extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + ) # When encountering think end id in delta_token_ids, # process the `content`. Only keep 'content', # remove 'reasoning_content' @@ -671,15 +698,18 @@ async def chat_completion_stream_generator( delta_tool_call = DeltaToolCall( function=DeltaFunctionCall( arguments=delta_text), - index=i) + index=i, + ) else: delta_tool_call = DeltaToolCall( id=random_tool_call_id(), type="function", function=DeltaFunctionCall( name=tool_choice_function_name, - arguments=delta_text), - index=i) + arguments=delta_text, + ), + index=i, + ) function_name_returned[i] = True delta_message = DeltaMessage(tool_calls=[ @@ -693,11 +723,9 @@ async def chat_completion_stream_generator( fn_name_returned = function_name_returned[i] if self.reasoning_parser: - _, content = \ + _, content = ( reasoning_parser.extract_reasoning_content( - current_text, - request - ) + current_text, request)) else: content = current_text delta_message, function_name_returned[i] = ( @@ -705,7 +733,8 @@ async def chat_completion_stream_generator( previous_text=previous_text, current_text=content, delta_text=delta_text, - function_name_returned=fn_name_returned)) + function_name_returned=fn_name_returned, + )) # update the previous values for the next iteration previous_texts[i] = current_text @@ -718,24 +747,22 @@ async def chat_completion_stream_generator( assert added_content_delta_arr is not None assert reasoning_end_arr is not None if not reasoning_end_arr[i]: - delta_message = ( - reasoning_parser. - extract_reasoning_content_streaming( - previous_text, - current_text, - delta_text, - previous_token_ids, - current_token_ids, - output.token_ids, - )) + delta_message = reasoning_parser.extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + ) # When encountering think end id in prompt_token_ids # i.e {"enable_thinking": False}, # set reasoning status to end. # Remove the text and token ids related # to 'reasoning_content'. 
- if res.prompt_token_ids and \ - reasoning_parser.is_reasoning_end( - list(res.prompt_token_ids)): + if (res.prompt_token_ids + and reasoning_parser.is_reasoning_end( + list(res.prompt_token_ids))): reasoning_end_arr[i] = True current_token_ids = list(output.token_ids) if delta_message and delta_message.content: @@ -750,9 +777,9 @@ async def chat_completion_stream_generator( if reasoning_parser.is_reasoning_end( list(output.token_ids)): reasoning_end_arr[i] = True - current_token_ids = \ + current_token_ids = ( reasoning_parser.extract_content_ids( - list(output.token_ids)) + list(output.token_ids))) if delta_message and delta_message.content: current_text = delta_message.content delta_message.content = None @@ -780,7 +807,8 @@ async def chat_completion_stream_generator( previous_token_ids=previous_token_ids, current_token_ids=current_token_ids, delta_token_ids=delta_token_ids, - request=request)) + request=request, + )) # when only tool calls elif tool_choice_auto: assert tool_parser is not None @@ -792,18 +820,18 @@ async def chat_completion_stream_generator( previous_token_ids=previous_token_ids, current_token_ids=current_token_ids, delta_token_ids=output.token_ids, - request=request)) + request=request, + )) # when only reasoning elif self.reasoning_parser: - delta_message = (reasoning_parser. - extract_reasoning_content_streaming( - previous_text, - current_text, - delta_text, - previous_token_ids, - current_token_ids, - output.token_ids, - )) + delta_message = reasoning_parser.extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + ) # handle streaming just a content delta else: delta_message = DeltaMessage(content=delta_text) @@ -830,12 +858,13 @@ async def chat_completion_stream_generator( delta_content = "" if delta_message.content: delta_content = delta_message.content - elif (delta_message.tool_calls and - delta_message.tool_calls[0].function and + elif (delta_message.tool_calls + and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments): - func_args = delta_message.tool_calls[0].function.arguments + func_args = delta_message.tool_calls[ + 0].function.arguments delta_content = func_args - + if delta_content: self.request_logger.log_outputs( request_id=request_id, @@ -852,7 +881,8 @@ async def chat_completion_stream_generator( index=i, delta=delta_message, logprobs=logprobs, - finish_reason=None) + finish_reason=None, + ) # if the model is finished generating else: @@ -862,21 +892,24 @@ async def chat_completion_stream_generator( # only happens if we are NOT using guided decoding auto_tools_called = False if tool_parser: - auto_tools_called = len( - tool_parser.prev_tool_call_arr) > 0 - index = len(tool_parser.prev_tool_call_arr - ) - 1 if auto_tools_called else 0 + auto_tools_called = (len( + tool_parser.prev_tool_call_arr) > 0) + index = (len(tool_parser.prev_tool_call_arr) - + 1 if auto_tools_called else 0) else: index = 0 - if self._should_check_for_unstreamed_tool_arg_tokens( - delta_message, output) and tool_parser: + if (self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output) and tool_parser): latest_delta_len = 0 - if ((isinstance( + if (isinstance( delta_message.tool_calls[0].function, - DeltaFunctionCall)) and isinstance( - delta_message.tool_calls[0].function. - arguments, str)): + DeltaFunctionCall, + )) and isinstance( + delta_message.tool_calls[0].function. 
+ arguments, + str, + ): latest_delta_len = len( delta_message.tool_calls[0].function. arguments) @@ -886,13 +919,14 @@ async def chat_completion_stream_generator( expected_call = json.dumps( tool_parser.prev_tool_call_arr[index].get( "arguments", {}), - ensure_ascii=False) + ensure_ascii=False, + ) # get what we've streamed so far for arguments # for the current tool actual_call = tool_parser.streamed_args_for_tool[ index] - if (latest_delta_len > 0): + if latest_delta_len > 0: actual_call = actual_call[:-latest_delta_len] # check to see if there's anything left to stream @@ -900,10 +934,12 @@ async def chat_completion_stream_generator( actual_call, "", 1) # set that as a delta message delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(index=index, - function=DeltaFunctionCall( - arguments=remaining_call). - model_dump(exclude_none=True)) + DeltaToolCall( + index=index, + function=DeltaFunctionCall( + arguments=remaining_call).model_dump( + exclude_none=True), + ) ]) # Send the finish response for each request.n only once @@ -913,7 +949,8 @@ async def chat_completion_stream_generator( logprobs=logprobs, finish_reason=output.finish_reason if not auto_tools_called else "tool_calls", - stop_reason=output.stop_reason) + stop_reason=output.stop_reason, + ) finish_reason_sent[i] = True @@ -922,7 +959,8 @@ async def chat_completion_stream_generator( object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name) + model=model_name, + ) # handle usage stats if requested & if continuous if include_continuous_usage: @@ -940,10 +978,11 @@ async def chat_completion_stream_generator( # is sent, send the usage if include_usage: completion_tokens = sum(previous_num_tokens) - final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + - completion_tokens) + final_usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) if self.enable_prompt_tokens_details and num_cached_tokens: final_usage.prompt_tokens_details = PromptTokenUsageInfo( cached_tokens=num_cached_tokens) @@ -954,9 +993,10 @@ async def chat_completion_stream_generator( created=created_time, choices=[], model=model_name, - usage=final_usage) - final_usage_data = (final_usage_chunk.model_dump_json( - exclude_unset=True, exclude_none=True)) + usage=final_usage, + ) + final_usage_data = final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True) yield f"data: {final_usage_data}\n\n" # report to FastAPI middleware aggregate usage across all choices @@ -964,7 +1004,8 @@ async def chat_completion_stream_generator( request_metadata.final_usage_info = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=num_completion_tokens, - total_tokens=num_prompt_tokens + num_completion_tokens) + total_tokens=num_prompt_tokens + num_completion_tokens, + ) # Log complete streaming response if output logging is enabled if self.enable_log_outputs and self.request_logger: @@ -972,7 +1013,8 @@ async def chat_completion_stream_generator( # For now, we'll log the completion tokens count as final output self.request_logger.log_outputs( request_id=request_id, - outputs=f"", + outputs= + f"", output_token_ids=None, finish_reason="streaming_complete", is_streaming=True, @@ -997,7 +1039,6 @@ async def chat_completion_full_generator( tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, ) -> Union[ErrorResponse, ChatCompletionResponse]: - 
created_time = int(time.time()) final_res: Optional[RequestOutput] = None @@ -1049,20 +1090,21 @@ async def chat_completion_full_generator( # if auto tools are not enabled, and a named tool choice using # outlines is not being used - if (not self.enable_auto_tools or not self.tool_parser) and \ - (not isinstance(request.tool_choice, - ChatCompletionNamedToolChoiceParam - ) and request.tool_choice != "required"): - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=content) + if (not self.enable_auto_tools or not self.tool_parser) and ( + not isinstance(request.tool_choice, + ChatCompletionNamedToolChoiceParam) + and request.tool_choice != "required"): + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=content, + ) # if the request uses tools and specified a tool choice - elif request.tool_choice and type( - request.tool_choice) is ChatCompletionNamedToolChoiceParam: - - tool_call_class = MistralToolCall if isinstance( - tokenizer, MistralTokenizer) else ToolCall + elif (request.tool_choice and type(request.tool_choice) + is ChatCompletionNamedToolChoiceParam): + tool_call_class = (MistralToolCall if isinstance( + tokenizer, MistralTokenizer) else ToolCall) message = ChatMessage( role=role, reasoning_content=reasoning_content, @@ -1070,12 +1112,14 @@ async def chat_completion_full_generator( tool_calls=[ tool_call_class(function=FunctionCall( name=request.tool_choice.function.name, - arguments=content)) - ]) + arguments=content, + )) + ], + ) elif request.tool_choice and request.tool_choice == "required": - tool_call_class = MistralToolCall if isinstance( - tokenizer, MistralTokenizer) else ToolCall + tool_call_class = (MistralToolCall if isinstance( + tokenizer, MistralTokenizer) else ToolCall) # the fields of FunctionDefinition are a superset of the # tool call outputs and can be used for parsing @@ -1090,24 +1134,25 @@ async def chat_completion_full_generator( tool_call_class(function=FunctionCall( name=tool_call.name, arguments=json.dumps(tool_call.parameters, - ensure_ascii=False))) - for tool_call in tool_calls - ]) + ensure_ascii=False), + )) for tool_call in tool_calls + ], + ) # if the request doesn't use tool choice # OR specifies to not use a tool elif not request.tool_choice or request.tool_choice == "none": - - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=content) + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=content, + ) # handle when there are tools and tool choice is auto - elif request.tools and ( - request.tool_choice == "auto" - or request.tool_choice is None) and self.enable_auto_tools \ - and self.tool_parser: - + elif ( + request.tools and + (request.tool_choice == "auto" or request.tool_choice is None) + and self.enable_auto_tools and self.tool_parser): try: tool_parser = self.tool_parser(tokenizer) except RuntimeError as e: @@ -1121,17 +1166,21 @@ async def chat_completion_full_generator( # call. The same is not true for named function calls auto_tools_called = tool_call_info.tools_called if tool_call_info.tools_called: - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=tool_call_info.content, - tool_calls=tool_call_info.tool_calls) + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls, + ) else: # FOR NOW make it a chat message; we will have to detect # the type to make it later. 
- message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=content) + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=content, + ) # undetermined case that is still important to handle else: @@ -1139,9 +1188,11 @@ async def chat_completion_full_generator( "Error in chat_completion_full_generator - cannot determine" " if tools should be extracted. Returning a standard chat " "completion.") - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=content) + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=content, + ) choice_data = ChatCompletionResponseChoice( index=output.index, @@ -1149,16 +1200,17 @@ async def chat_completion_full_generator( logprobs=logprobs, finish_reason="tool_calls" if auto_tools_called else output.finish_reason if output.finish_reason else "stop", - stop_reason=output.stop_reason) + stop_reason=output.stop_reason, + ) choices.append(choice_data) if request.echo: last_msg_content: Union[str, list[dict[str, str]]] = "" - if conversation and "content" in conversation[-1] and conversation[ - -1].get("role") == role: + if (conversation and "content" in conversation[-1] + and conversation[-1].get("role") == role): last_msg_content = conversation[-1]["content"] or "" if isinstance(last_msg_content, list): - last_msg_content = "\n".join(msg['text'] + last_msg_content = "\n".join(msg["text"] for msg in last_msg_content) for choice in choices: @@ -1172,10 +1224,11 @@ async def chat_completion_full_generator( num_prompt_tokens += len(final_res.encoder_prompt_token_ids) num_generated_tokens = sum( len(output.token_ids) for output in final_res.outputs) - usage = UsageInfo(prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + - num_generated_tokens) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) if self.enable_prompt_tokens_details and final_res.num_cached_tokens: usage.prompt_tokens_details = PromptTokenUsageInfo( cached_tokens=final_res.num_cached_tokens) @@ -1202,18 +1255,21 @@ async def chat_completion_full_generator( # For tool calls, log the function name and arguments tool_call_descriptions = [] for tool_call in choice.message.tool_calls: - if hasattr(tool_call.function, 'name') and hasattr(tool_call.function, 'arguments'): + if hasattr(tool_call.function, "name") and hasattr( + tool_call.function, "arguments"): tool_call_descriptions.append( - f"{tool_call.function.name}({tool_call.function.arguments})") - tool_calls_str = ', '.join(tool_call_descriptions) + f"{tool_call.function.name}({tool_call.function.arguments})" + ) + tool_calls_str = ", ".join(tool_call_descriptions) output_text = f"[tool_calls: {tool_calls_str}]" - + if output_text: # Get the corresponding output token IDs output_token_ids = None if choice.index < len(final_res.outputs): - output_token_ids = final_res.outputs[choice.index].token_ids - + output_token_ids = final_res.outputs[ + choice.index].token_ids + self.request_logger.log_outputs( request_id=request_id, outputs=output_text, @@ -1226,19 +1282,23 @@ async def chat_completion_full_generator( return response def _get_top_logprobs( - self, logprobs: dict[int, Logprob], top_logprobs: Optional[int], - tokenizer: AnyTokenizer, - should_return_as_token_id: bool) -> list[ChatCompletionLogProb]: + self, + logprobs: dict[int, Logprob], + top_logprobs: Optional[int], + 
tokenizer: AnyTokenizer, + should_return_as_token_id: bool, + ) -> list[ChatCompletionLogProb]: return [ - ChatCompletionLogProb(token=(token := self._get_decoded_token( - p[1], - p[0], - tokenizer, - return_as_token_id=should_return_as_token_id)), - logprob=max(p[1].logprob, -9999.0), - bytes=list( - token.encode("utf-8", errors="replace"))) - for i, p in enumerate(logprobs.items()) + ChatCompletionLogProb( + token=(token := self._get_decoded_token( + p[1], + p[0], + tokenizer, + return_as_token_id=should_return_as_token_id, + )), + logprob=max(p[1].logprob, -9999.0), + bytes=list(token.encode("utf-8", errors="replace")), + ) for i, p in enumerate(logprobs.items()) if top_logprobs and i < top_logprobs ] @@ -1253,12 +1313,13 @@ def _create_chat_logprobs( """Create OpenAI-style logprobs.""" logprobs_content: list[ChatCompletionLogProbsContent] = [] - should_return_as_token_id = return_as_token_id if \ - return_as_token_id is not None else self.return_tokens_as_token_ids + should_return_as_token_id = (return_as_token_id + if return_as_token_id is not None else + self.return_tokens_as_token_ids) for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] - if step_top_logprobs is None or step_top_logprobs.get( - token_id) is None: + if (step_top_logprobs is None + or step_top_logprobs.get(token_id) is None): token = tokenizer.decode(token_id) if should_return_as_token_id: token = f"token_id:{token_id}" @@ -1284,8 +1345,11 @@ def _create_chat_logprobs( bytes=None if step_decoded is None else list( step_decoded.encode("utf-8", errors="replace")), top_logprobs=self._get_top_logprobs( - step_top_logprobs, num_output_top_logprobs, - tokenizer, should_return_as_token_id), + step_top_logprobs, + num_output_top_logprobs, + tokenizer, + should_return_as_token_id, + ), )) return ChatCompletionLogProbs(content=logprobs_content) @@ -1301,7 +1365,7 @@ def _should_stream_with_auto_tool_parsing(self, choice field indicates that "auto" tool choice should be used. """ return (request.tools and self.tool_parser and self.enable_auto_tools - and request.tool_choice in ['auto', None]) + and request.tool_choice in ["auto", None]) def _should_check_for_unstreamed_tool_arg_tokens( self, diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 31f55a8868a..757554e848c 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -89,8 +89,11 @@ def __init__( if self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source - logger.info("Using default chat sampling params from %s: %s", - source, self.default_sampling_params) + logger.info( + "Using default chat sampling params from %s: %s", + source, + self.default_sampling_params, + ) # HACK(woosuk): This is a hack. We should use a better store. 
# FIXME: This causes a memory leak since we never remove responses @@ -169,11 +172,13 @@ async def create_responses( sampling_params = request.to_sampling_params( default_max_tokens, self.default_sampling_params) - self._log_inputs(request.request_id, - request_prompts[i], - params=sampling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + self._log_inputs( + request.request_id, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) @@ -193,7 +198,7 @@ async def create_responses( return self.create_error_response(str(e)) assert len(generators) == 1 - result_generator, = generators + (result_generator, ) = generators # Store the input messages. if request.store: @@ -344,7 +349,7 @@ async def responses_full_generator( output_text = content elif reasoning_content: output_text = f"[reasoning: {reasoning_content}]" - + if output_text: self.request_logger.log_outputs( request_id=request.request_id, @@ -460,7 +465,7 @@ async def cancel_responses( response.status = "cancelled" # Abort the request. - if (task := self.background_tasks.get(response_id)): + if task := self.background_tasks.get(response_id): task.cancel() try: await task From fb13841bd220b13e639b06f4330e53571cfb3b6a Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Fri, 11 Jul 2025 10:50:35 +0400 Subject: [PATCH 06/13] Shortened lines to meet rule of line length < 80 characters Signed-off-by: Adrian Garcia --- vllm/entrypoints/openai/serving_chat.py | 53 ++++++++++++++----------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index b910dcbb753..7d7fae1d498 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -669,14 +669,16 @@ async def chat_completion_stream_generator( and not reasoning_parser.is_reasoning_end( previous_token_ids)): assert reasoning_parser is not None - delta_message = reasoning_parser.extract_reasoning_content_streaming( - previous_text, - current_text, - delta_text, - previous_token_ids, - current_token_ids, - output.token_ids, - ) + delta_message = ( + reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) # When encountering think end id in delta_token_ids, # process the `content`. Only keep 'content', # remove 'reasoning_content' @@ -747,14 +749,16 @@ async def chat_completion_stream_generator( assert added_content_delta_arr is not None assert reasoning_end_arr is not None if not reasoning_end_arr[i]: - delta_message = reasoning_parser.extract_reasoning_content_streaming( - previous_text, - current_text, - delta_text, - previous_token_ids, - current_token_ids, - output.token_ids, - ) + delta_message = ( + reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) # When encountering think end id in prompt_token_ids # i.e {"enable_thinking": False}, # set reasoning status to end. 
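The serving-side changes above call request_logger.log_outputs(), while the RequestLogger method itself is added earlier in this series and is not reproduced in these hunks. As a reading aid, a minimal sketch of the expected behavior, inferred from the tests/test_logger.py assertions later in the series; the format string beyond the asserted "Generated response %s%s" prefix and the reuse of --max-log-len truncation are assumptions, not the shipped code (logger is the module logger patched in the tests, vllm.entrypoints.logger.logger):

    # Sketch only: behavior inferred from the test assertions, not the
    # actual implementation added by PATCH 01/13.
    def log_outputs(self, request_id, outputs, output_token_ids,
                    finish_reason=None, is_streaming=False, delta=False):
        # Assumed to respect --max-log-len for both text and token ids,
        # matching the truncation test (first N characters / first N ids).
        if self.max_log_len is not None:
            if outputs is not None:
                outputs = outputs[:self.max_log_len]
            if output_token_ids is not None:
                output_token_ids = output_token_ids[:self.max_log_len]
        stream_info = ""
        if is_streaming:
            stream_info = (" (streaming delta)" if delta
                           else " (streaming complete)")
        # Positional layout asserted by the tests:
        # (format, request_id, stream_info, outputs, token_ids, finish_reason)
        logger.info(
            "Generated response %s%s: output: %r, "
            "output_token_ids: %s, finish_reason: %s",
            request_id, stream_info, outputs, output_token_ids, finish_reason)
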
@@ -824,14 +828,15 @@ async def chat_completion_stream_generator( )) # when only reasoning elif self.reasoning_parser: - delta_message = reasoning_parser.extract_reasoning_content_streaming( - previous_text, - current_text, - delta_text, - previous_token_ids, - current_token_ids, - output.token_ids, - ) + delta_message = (reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) # handle streaming just a content delta else: delta_message = DeltaMessage(content=delta_text) From 42085ad6c0d60b143744fdb1789aa9bb02a91117 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Fri, 11 Jul 2025 18:32:36 +0400 Subject: [PATCH 07/13] tests: fix incorrect assertion indices in log_outputs tests The assertions in log_outputs test methods were checking wrong argument indices from mocked logger calls, causing tests to validate incorrect behavior and pass incorrectly. The logger.info call signature is: logger.info(format_string, request_id, stream_info, outputs, output_token_ids, finish_reason) Fixed argument index assertions in all affected test methods: - test_request_logger_log_outputs - test_request_logger_log_outputs_streaming_delta - test_request_logger_log_outputs_streaming_complete - test_request_logger_log_outputs_with_truncation - test_request_logger_log_outputs_none_values - test_request_logger_log_outputs_empty_output - test_request_logger_log_outputs_integration Tests now correctly validate outputs at index 3, output_token_ids at index 4, and finish_reason at index 5, instead of the previous incorrect indices 1, 2, and 3 respectively. Signed-off-by: Adrian Garcia --- tests/test_logger.py | 76 +++++++++++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 30 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index b8d0a5c2ffd..cd66198937b 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -281,11 +281,13 @@ def test_request_logger_log_outputs(): ) mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0] - assert "Generated response test-123" in call_args[0] - assert "Hello, world!" in call_args[1] - assert call_args[2] == [1, 2, 3, 4] - assert call_args[3] == "stop" + call_args = mock_logger.info.call_args.args + # logger.info(format_string, request_id, stream_info, outputs, output_token_ids, finish_reason) + assert "Generated response %s%s" in call_args[0] + assert call_args[1] == "test-123" + assert call_args[3] == "Hello, world!" 
+ assert call_args[4] == [1, 2, 3, 4] + assert call_args[5] == "stop" def test_request_logger_log_outputs_streaming_delta(): @@ -306,11 +308,14 @@ def test_request_logger_log_outputs_streaming_delta(): ) mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0] - assert "Generated response test-456 (streaming delta)" in call_args[0] - assert call_args[1] == "Hello" - assert call_args[2] == [1] - assert call_args[3] is None + call_args = mock_logger.info.call_args.args + # logger.info(format_string, request_id, stream_info, outputs, output_token_ids, finish_reason) + assert "Generated response %s%s" in call_args[0] + assert call_args[1] == "test-456" + assert call_args[2] == " (streaming delta)" + assert call_args[3] == "Hello" + assert call_args[4] == [1] + assert call_args[5] is None def test_request_logger_log_outputs_streaming_complete(): @@ -331,12 +336,14 @@ def test_request_logger_log_outputs_streaming_complete(): ) mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0] - assert ("Generated response test-789 (streaming complete)" - in call_args[0]) - assert call_args[1] == "Complete response" - assert call_args[2] == [1, 2, 3] - assert call_args[3] == "length" + call_args = mock_logger.info.call_args.args + # logger.info(format_string, request_id, stream_info, outputs, output_token_ids, finish_reason) + assert "Generated response %s%s" in call_args[0] + assert call_args[1] == "test-789" + assert call_args[2] == " (streaming complete)" + assert call_args[3] == "Complete response" + assert call_args[4] == [1, 2, 3] + assert call_args[5] == "length" def test_request_logger_log_outputs_with_truncation(): @@ -364,12 +371,12 @@ def test_request_logger_log_outputs_with_truncation(): call_args = mock_logger.info.call_args # Check that output was truncated to first 10 characters - logged_output = call_args[0][1] + logged_output = call_args[0][3] assert logged_output == "This is a " assert len(logged_output) == 10 # Check that token IDs were truncated to first 10 tokens - logged_token_ids = call_args[0][2] + logged_token_ids = call_args[0][4] assert logged_token_ids == list(range(10)) assert len(logged_token_ids) == 10 @@ -392,11 +399,13 @@ def test_request_logger_log_outputs_none_values(): ) mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0] - assert "Generated response test-none" in call_args[0] - assert call_args[1] == "Test output" - assert call_args[2] is None - assert call_args[3] == "stop" + call_args = mock_logger.info.call_args.args + # logger.info(format_string, request_id, stream_info, outputs, output_token_ids, finish_reason) + assert "Generated response %s%s" in call_args[0] + assert call_args[1] == "test-none" + assert call_args[3] == "Test output" + assert call_args[4] is None + assert call_args[5] == "stop" def test_request_logger_log_outputs_empty_output(): @@ -417,11 +426,13 @@ def test_request_logger_log_outputs_empty_output(): ) mock_logger.info.assert_called_once() - call_args = mock_logger.info.call_args[0] - assert "Generated response test-empty" in call_args[0] - assert call_args[1] == "" - assert call_args[2] == [] - assert call_args[3] == "stop" + call_args = mock_logger.info.call_args.args + # logger.info(format_string, request_id, stream_info, outputs, output_token_ids, finish_reason) + assert "Generated response %s%s" in call_args[0] + assert call_args[1] == "test-empty" + assert call_args[3] == "" + assert call_args[4] == [] + assert call_args[5] == "stop" def 
test_request_logger_log_outputs_integration(): @@ -458,5 +469,10 @@ def test_request_logger_log_outputs_integration(): input_call = mock_logger.info.call_args_list[0][0] output_call = mock_logger.info.call_args_list[1][0] - assert "Received request test-integration" in input_call[0] - assert "Generated response test-integration" in output_call[0] + # Check input call: logger.info(format_string, request_id, prompt, params, ...) + assert "Received request %s" in input_call[0] + assert input_call[1] == "test-integration" + + # Check output call: logger.info(format_string, request_id, stream_info, outputs, ...) + assert "Generated response %s%s" in output_call[0] + assert output_call[1] == "test-integration" From 6185d666652d9b83b5e611d6aa2002ee34ca8835 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Sat, 12 Jul 2025 00:14:53 +0400 Subject: [PATCH 08/13] logger: Include full text in streaming response logs Previously, the log message for a completed streaming response only included the number of generated tokens, which limited debugging and auditing capabilities. This change: - Modifies the streaming response logging to include the full concatenated text instead of just token counts - Adds test coverage to verify the full text logging behavior - Ensures all logger.info call argument indices are correct in tests The change improves the utility of logs for debugging and auditing by providing complete output records. Signed-off-by: Adrian Garcia --- tests/test_logger.py | 33 +++++++++++++++++++++++ vllm/entrypoints/openai/serving_chat.py | 36 +++++++++++++++---------- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index cd66198937b..2c748d519b8 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -476,3 +476,36 @@ def test_request_logger_log_outputs_integration(): # Check output call: logger.info(format_string, request_id, stream_info, outputs, ...) 
assert "Generated response %s%s" in output_call[0] assert output_call[1] == "test-integration" + + +def test_streaming_complete_logs_full_text_content(): + """Test that streaming complete logging includes full accumulated text, not just token count.""" + mock_logger = MagicMock() + + with patch("vllm.entrypoints.logger.logger", mock_logger): + request_logger = RequestLogger(max_log_len=None) + + # Test with actual content instead of token count format + full_response = "This is a complete response from streaming" + request_logger.log_outputs( + request_id="test-streaming-full-text", + outputs=full_response, + output_token_ids=None, + finish_reason="streaming_complete", + is_streaming=True, + delta=False, + ) + + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args.args + + # Verify the logged output is the full text, not a token count format + logged_output = call_args[3] + assert logged_output == full_response + assert "tokens>" not in logged_output # Ensure it's not the old token count format + assert "streaming_complete" not in logged_output # Ensure it's not the fallback format + + # Verify other parameters + assert call_args[1] == "test-streaming-full-text" + assert call_args[2] == " (streaming complete)" + assert call_args[5] == "streaming_complete" diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 7d7fae1d498..816145cdb53 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -491,20 +491,21 @@ async def chat_completion_stream_generator( all_previous_token_ids: Optional[list[list[int]]] function_name_returned = [False] * num_choices + # Always track previous_texts for comprehensive output logging + previous_texts = [""] * num_choices + # Only one of these will be used, thus previous_texts and # all_previous_token_ids will not be used twice in the same iteration. 
if tool_choice_auto or self.reasoning_parser: # These are only required in "auto" tool choice case - previous_texts = [""] * num_choices all_previous_token_ids = [[]] * num_choices # For reasoning parser and tool call all enabled added_content_delta_arr = [False] * num_choices reasoning_end_arr = [False] * num_choices elif request.tool_choice == "required": - previous_texts = [""] * num_choices all_previous_token_ids = None else: - previous_texts, all_previous_token_ids = None, None + all_previous_token_ids = None try: if self.reasoning_parser: @@ -847,6 +848,10 @@ async def chat_completion_stream_generator( assert all_previous_token_ids is not None previous_texts[i] = current_text all_previous_token_ids[i] = current_token_ids + else: + # Update previous_texts for comprehensive logging even in simple content case + assert previous_texts is not None + previous_texts[i] += delta_text # set the previous values for the next iteration previous_num_tokens[i] += len(output.token_ids) @@ -1014,17 +1019,20 @@ async def chat_completion_stream_generator( # Log complete streaming response if output logging is enabled if self.enable_log_outputs and self.request_logger: - # Collect all generated text from the SSE decoder if available - # For now, we'll log the completion tokens count as final output - self.request_logger.log_outputs( - request_id=request_id, - outputs= - f"", - output_token_ids=None, - finish_reason="streaming_complete", - is_streaming=True, - delta=False, - ) + # Log the complete response for each choice + for i in range(num_choices): + full_text = (previous_texts[i] if previous_texts + and i < len(previous_texts) else + f"" + ) + self.request_logger.log_outputs( + request_id=request_id, + outputs=full_text, + output_token_ids=None, # Consider also logging all token IDs + finish_reason="streaming_complete", + is_streaming=True, + delta=False, + ) except Exception as e: # TODO: Use a vllm-specific Validation Error From e864415c90ec96ca63a5e54530fb90691f3a97f2 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Sat, 12 Jul 2025 00:40:57 +0400 Subject: [PATCH 09/13] openai/serving_chat: log all tool-call arguments in streaming deltas MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously only the first tool call’s arguments were captured when logging streaming delta content, which could miss information if multiple tool calls were present in a single delta. The extraction logic now concatenates the arguments from *all* tool calls ensuring complete logging. Additional changes: * Updated unit tests to remain within Ruff line-length limits (E501). * Auto-formatted touched files via project pre-commit hooks. Signed-off-by: Adrian Garcia --- tests/test_logger.py | 6 +++--- vllm/entrypoints/openai/serving_chat.py | 25 +++++++++++++------------ 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index 2c748d519b8..bbc81a00a3b 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -472,7 +472,7 @@ def test_request_logger_log_outputs_integration(): # Check input call: logger.info(format_string, request_id, prompt, params, ...) assert "Received request %s" in input_call[0] assert input_call[1] == "test-integration" - + # Check output call: logger.info(format_string, request_id, stream_info, outputs, ...) 
assert "Generated response %s%s" in output_call[0] assert output_call[1] == "test-integration" @@ -498,13 +498,13 @@ def test_streaming_complete_logs_full_text_content(): mock_logger.info.assert_called_once() call_args = mock_logger.info.call_args.args - + # Verify the logged output is the full text, not a token count format logged_output = call_args[3] assert logged_output == full_response assert "tokens>" not in logged_output # Ensure it's not the old token count format assert "streaming_complete" not in logged_output # Ensure it's not the fallback format - + # Verify other parameters assert call_args[1] == "test-streaming-full-text" assert call_args[2] == " (streaming complete)" diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 816145cdb53..3f4f4fd0479 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -493,7 +493,7 @@ async def chat_completion_stream_generator( # Always track previous_texts for comprehensive output logging previous_texts = [""] * num_choices - + # Only one of these will be used, thus previous_texts and # all_previous_token_ids will not be used twice in the same iteration. if tool_choice_auto or self.reasoning_parser: @@ -868,12 +868,11 @@ async def chat_completion_stream_generator( delta_content = "" if delta_message.content: delta_content = delta_message.content - elif (delta_message.tool_calls - and delta_message.tool_calls[0].function and - delta_message.tool_calls[0].function.arguments): - func_args = delta_message.tool_calls[ - 0].function.arguments - delta_content = func_args + elif delta_message.tool_calls: + delta_content = "".join( + tc.function.arguments + for tc in delta_message.tool_calls + if tc.function and tc.function.arguments) if delta_content: self.request_logger.log_outputs( @@ -1021,14 +1020,16 @@ async def chat_completion_stream_generator( if self.enable_log_outputs and self.request_logger: # Log the complete response for each choice for i in range(num_choices): - full_text = (previous_texts[i] if previous_texts - and i < len(previous_texts) else - f"" - ) + full_text = ( + previous_texts[i] + if previous_texts and i < len(previous_texts) else + f"" + ) self.request_logger.log_outputs( request_id=request_id, outputs=full_text, - output_token_ids=None, # Consider also logging all token IDs + output_token_ids= + None, # Consider also logging all token IDs finish_reason="streaming_complete", is_streaming=True, delta=False, From 1ed2689ea2b8721860ab08bb3b8bf08ca92d37f5 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Sat, 12 Jul 2025 12:50:39 +0400 Subject: [PATCH 10/13] Removed comments that broke the line length constraint: Signed-off-by: Adrian Garcia --- tests/test_logger.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index bbc81a00a3b..0f9eb7b5785 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -282,7 +282,6 @@ def test_request_logger_log_outputs(): mock_logger.info.assert_called_once() call_args = mock_logger.info.call_args.args - # logger.info(format_string, request_id, stream_info, outputs, output_token_ids, finish_reason) assert "Generated response %s%s" in call_args[0] assert call_args[1] == "test-123" assert call_args[3] == "Hello, world!" 
@@ -309,7 +308,6 @@ def test_request_logger_log_outputs_streaming_delta(): mock_logger.info.assert_called_once() call_args = mock_logger.info.call_args.args - # logger.info(format_string, request_id, stream_info, outputs, output_token_ids, finish_reason) assert "Generated response %s%s" in call_args[0] assert call_args[1] == "test-456" assert call_args[2] == " (streaming delta)" @@ -337,7 +335,6 @@ def test_request_logger_log_outputs_streaming_complete(): mock_logger.info.assert_called_once() call_args = mock_logger.info.call_args.args - # logger.info(format_string, request_id, stream_info, outputs, output_token_ids, finish_reason) assert "Generated response %s%s" in call_args[0] assert call_args[1] == "test-789" assert call_args[2] == " (streaming complete)" @@ -400,7 +397,6 @@ def test_request_logger_log_outputs_none_values(): mock_logger.info.assert_called_once() call_args = mock_logger.info.call_args.args - # logger.info(format_string, request_id, stream_info, outputs, output_token_ids, finish_reason) assert "Generated response %s%s" in call_args[0] assert call_args[1] == "test-none" assert call_args[3] == "Test output" @@ -427,7 +423,6 @@ def test_request_logger_log_outputs_empty_output(): mock_logger.info.assert_called_once() call_args = mock_logger.info.call_args.args - # logger.info(format_string, request_id, stream_info, outputs, output_token_ids, finish_reason) assert "Generated response %s%s" in call_args[0] assert call_args[1] == "test-empty" assert call_args[3] == "" @@ -469,17 +464,16 @@ def test_request_logger_log_outputs_integration(): input_call = mock_logger.info.call_args_list[0][0] output_call = mock_logger.info.call_args_list[1][0] - # Check input call: logger.info(format_string, request_id, prompt, params, ...) assert "Received request %s" in input_call[0] assert input_call[1] == "test-integration" - # Check output call: logger.info(format_string, request_id, stream_info, outputs, ...) 
assert "Generated response %s%s" in output_call[0] assert output_call[1] == "test-integration" def test_streaming_complete_logs_full_text_content(): - """Test that streaming complete logging includes full accumulated text, not just token count.""" + """Test that streaming complete logging includes + full accumulated text, not just token count.""" mock_logger = MagicMock() with patch("vllm.entrypoints.logger.logger", mock_logger): @@ -502,8 +496,8 @@ def test_streaming_complete_logs_full_text_content(): # Verify the logged output is the full text, not a token count format logged_output = call_args[3] assert logged_output == full_response - assert "tokens>" not in logged_output # Ensure it's not the old token count format - assert "streaming_complete" not in logged_output # Ensure it's not the fallback format + assert "tokens>" not in logged_output + assert "streaming_complete" not in logged_output # Verify other parameters assert call_args[1] == "test-streaming-full-text" From 11543a02a795d8deb1a94ee60886db0765d0d6ff Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Sat, 12 Jul 2025 13:30:26 +0400 Subject: [PATCH 11/13] Fixed comment length violation Signed-off-by: Adrian Garcia --- vllm/entrypoints/openai/serving_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index e6aaef08504..f58efbc4be1 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -850,7 +850,7 @@ async def chat_completion_stream_generator( previous_texts[i] = current_text all_previous_token_ids[i] = current_token_ids else: - # Update previous_texts for comprehensive logging even in simple content case + # Update for comprehensive logging even in simple case assert previous_texts is not None previous_texts[i] += delta_text From b41c263cb65baf8caca8264dbdd428521964788b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Sun, 13 Jul 2025 18:17:47 +0400 Subject: [PATCH 12/13] Reverted unnecessary formatting changes Signed-off-by: Adrian Garcia --- tests/test_logger.py | 41 +-- vllm/entrypoints/logger.py | 12 +- vllm/entrypoints/openai/cli_args.py | 128 +++---- vllm/entrypoints/openai/serving_chat.py | 346 ++++++++----------- vllm/entrypoints/openai/serving_responses.py | 23 +- 5 files changed, 231 insertions(+), 319 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index 0f9eb7b5785..c29bc47cc9d 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -138,19 +138,16 @@ def test_an_error_is_raised_when_custom_logging_config_is_invalid_json(): @patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) -@pytest.mark.parametrize( - "unexpected_config", - ( - "Invalid string", - [{ - "version": 1, - "loggers": [] - }], - 0, - ), -) +@pytest.mark.parametrize("unexpected_config", ( + "Invalid string", + [{ + "version": 1, + "loggers": [] + }], + 0, +)) def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json( - unexpected_config: Any, ): + unexpected_config: Any): """This test calls _configure_vllm_root_logger again to test custom logging config behavior, however it fails before any change in behavior or configuration occurs.""" @@ -177,16 +174,14 @@ def test_custom_logging_config_is_parsed_and_used_when_provided(): "propagate": False, } }, - "version": 1, + "version": 1 } with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: logging_config_file.write(json.dumps(valid_logging_config)) logging_config_file.flush() - with ( - 
patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", - logging_config_file.name), - patch("vllm.logger.dictConfig") as dict_config_mock, - ): + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name), patch( + "vllm.logger.dictConfig") as dict_config_mock: _configure_vllm_root_logger() dict_config_mock.assert_called_with(valid_logging_config) @@ -202,7 +197,7 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): "handlers": [], } }, - "version": 1, + "version": 1 } with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: logging_config_file.write(json.dumps(valid_logging_config)) @@ -228,11 +223,11 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): def test_prepare_object_to_dump(): - str_obj = "str" + str_obj = 'str' assert prepare_object_to_dump(str_obj) == "'str'" list_obj = [1, 2, 3] - assert prepare_object_to_dump(list_obj) == "[1, 2, 3]" + assert prepare_object_to_dump(list_obj) == '[1, 2, 3]' dict_obj = {"a": 1, "b": "b"} assert prepare_object_to_dump(dict_obj) in [ @@ -241,9 +236,9 @@ def test_prepare_object_to_dump(): ] set_obj = {1, 2, 3} - assert prepare_object_to_dump(set_obj) == "[1, 2, 3]" + assert prepare_object_to_dump(set_obj) == '[1, 2, 3]' - tuple_obj = ("a", "b", "c") + tuple_obj = ('a', 'b', 'c') assert prepare_object_to_dump(tuple_obj) == "['a', 'b', 'c']" class CustomEnum(enum.Enum): diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index 6c2acbacd39..7b6852b063b 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -45,15 +45,11 @@ def log_inputs( "Received request %s: prompt: %r, " "params: %s, prompt_token_ids: %s, " "prompt_embeds shape: %s, " - "lora_request: %s, prompt_adapter_request: %s.", - request_id, - prompt, - params, - prompt_token_ids, + "lora_request: %s, prompt_adapter_request: %s.", request_id, + prompt, params, prompt_token_ids, prompt_embeds.shape if prompt_embeds is not None else None, - lora_request, - prompt_adapter_request, - ) + lora_request, prompt_adapter_request) + def log_outputs( self, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 1097627eea4..5776ed5db89 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -92,60 +92,56 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--uvicorn-log-level", type=str, default="info", - choices=["debug", "info", "warning", "error", "critical", "trace"], - help="Log level for uvicorn.", - ) - parser.add_argument( - "--disable-uvicorn-access-log", - action="store_true", - help="Disable uvicorn access log.", - ) + choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'], + help="Log level for uvicorn.") + parser.add_argument("--disable-uvicorn-access-log", + action='store_true', + help="Disable uvicorn access log.") parser.add_argument("--allow-credentials", - action="store_true", + action='store_true', help="Allow credentials.") parser.add_argument( "--allowed-origins", type=json.loads, default=["*"], - help="Allowed origins.", + help="Allowed origins." ) parser.add_argument( "--allowed-methods", type=json.loads, default=["*"], - help="Allowed methods.", + help="Allowed methods." ) parser.add_argument( "--allowed-headers", type=json.loads, default=["*"], - help="Allowed headers.", + help="Allowed headers." 
) parser.add_argument( "--api-key", type=optional_type(str), default=None, help="If provided, the server will require this key " - "to be presented in the header.", + "to be presented in the header." ) parser.add_argument( "--lora-modules", type=optional_type(str), default=None, - nargs="+", + nargs='+', action=LoRAParserAction, help="LoRA module configurations in either 'name=path' format" "or JSON format. " "Example (old format): ``'name=path'`` " "Example (new format): " - '``{"name": "name", "path": "lora_path", ' - '"base_model_name": "id"}``', - ) + "``{\"name\": \"name\", \"path\": \"lora_path\", " + "\"base_model_name\": \"id\"}``") parser.add_argument( "--prompt-adapters", type=optional_type(str), default=None, - nargs="+", + nargs='+', action=PromptAdapterParserAction, help="Prompt adapter configurations in the format name=path. " "Multiple adapters can be specified.", @@ -163,57 +159,57 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, default="auto", choices=get_args(ChatTemplateContentFormatOption), - help="The format to render message content within a chat template." - "\n\n" + help='The format to render message content within a chat template.' + '\n\n' '* "string" will render the content as a string. ' 'Example: ``"Hello World"``\n' '* "openai" will render the content as a list of dictionaries, ' "similar to OpenAI schema. " - 'Example: ``[{"type": "text", "text": "Hello world!"}]``', + 'Example: ``[{"type": "text", "text": "Hello world!"}]``' ) parser.add_argument( "--response-role", type=optional_type(str), default="assistant", help="The role name to return if " - "``request.add_generation_prompt=true``.", + "``request.add_generation_prompt=true``." ) parser.add_argument( "--ssl-keyfile", type=optional_type(str), default=None, - help="The file path to the SSL key file.", + help="The file path to the SSL key file." ) parser.add_argument( "--ssl-certfile", type=optional_type(str), default=None, - help="The file path to the SSL cert file.", + help="The file path to the SSL cert file." ) parser.add_argument( "--ssl-ca-certs", type=optional_type(str), default=None, - help="The CA certificates file.", + help="The CA certificates file." ) parser.add_argument( "--enable-ssl-refresh", - action="store_true", + action='store_true', default=False, - help="Refresh SSL Context when SSL certificate files change", + help="Refresh SSL Context when SSL certificate files change" ) parser.add_argument( "--ssl-cert-reqs", type=int, default=int(ssl.CERT_NONE), help= - "Whether client certificate is required (see stdlib ssl module's).", + "Whether client certificate is required (see stdlib ssl module's)." ) parser.add_argument( "--root-path", type=optional_type(str), default=None, - help="FastAPI root_path when app is behind a path based routing proxy.", + help="FastAPI root_path when app is behind a path based routing proxy." ) parser.add_argument( "--middleware", @@ -226,37 +222,32 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "If a function is provided, vLLM will add it to the server " "using ``@app.middleware('http')``. " "If a class is provided, vLLM will add it to the server " - "using ``app.add_middleware()``. ", - ) + "using ``app.add_middleware()``. 
") parser.add_argument( "--return-tokens-as-token-ids", - action="store_true", + action='store_true', help="When ``--max-logprobs`` is specified, represents single tokens " " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified.", - ) + "that are not JSON-encodable can be identified.") parser.add_argument( "--disable-frontend-multiprocessing", - action="store_true", + action='store_true', help="If specified, will run the OpenAI frontend server in the same " - "process as the model serving engine.", - ) + "process as the model serving engine.") parser.add_argument( "--enable-request-id-headers", - action="store_true", + action='store_true', help="If specified, API server will add X-Request-Id header to " - "responses.", - ) + "responses.") parser.add_argument( "--enable-auto-tool-choice", - action="store_true", + action='store_true', default=False, help="Enable auto tool choice for supported models. Use " - "``--tool-call-parser`` to specify which parser to use.", - ) + "``--tool-call-parser`` to specify which parser to use.") parser.add_argument( "--expand-tools-even-if-tool-choice-none", - action="store_true", + action='store_true', default=False, deprecated=True, help="Include tool definitions in prompts " @@ -264,8 +255,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "This is a transitional option that will be removed in v0.10.0. " "In v0.10.0, tool definitions will always be included regardless of " "tool_choice setting. Use this flag now to test the new behavior " - "before the breaking change.", - ) + "before the breaking change.") valid_tool_parsers = ToolParserManager.tool_parsers.keys() parser.add_argument( @@ -277,8 +267,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help= "Select the tool call parser depending on the model that you're using." " This is used to parse the model-generated tool call into OpenAI API " - "format. Required for ``--enable-auto-tool-choice``.", - ) + "format. Required for ``--enable-auto-tool-choice``.") parser.add_argument( "--tool-parser-plugin", @@ -287,56 +276,49 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help= "Special the tool parser plugin write to parse the model-generated tool" " into OpenAI API format, the name register in this plugin can be used " - "in ``--tool-call-parser``.", - ) + "in ``--tool-call-parser``.") parser.add_argument( "--log-config-file", type=str, default=envs.VLLM_LOGGING_CONFIG_PATH, - help="Path to logging config JSON file for both vllm and uvicorn", - ) + help="Path to logging config JSON file for both vllm and uvicorn") parser = AsyncEngineArgs.add_cli_args(parser) - parser.add_argument( - "--max-log-len", - type=int, - default=None, - help="Max number of prompt characters or prompt " - "ID numbers being printed in log." - " The default of None means unlimited.", - ) + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + ' The default of None means unlimited.') parser.add_argument( "--disable-fastapi-docs", - action="store_true", + action='store_true', default=False, - help= - "Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint.", + help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint." 
) parser.add_argument( "--enable-prompt-tokens-details", - action="store_true", + action='store_true', default=False, - help="If set to True, enable prompt_tokens_details in usage.", - ) + help="If set to True, enable prompt_tokens_details in usage.") parser.add_argument( "--enable-force-include-usage", - action="store_true", + action='store_true', default=False, help="If set to True, including usage on every request.", ) parser.add_argument( "--enable-server-load-tracking", - action="store_true", + action='store_true', default=False, help= - "If set to True, enable tracking server_load_metrics in the app state.", - ) + "If set to True, enable tracking server_load_metrics in the app state.") parser.add_argument( "--enable-log-outputs", - action="store_true", + action='store_true', default=False, help="If set to True, enable logging of model outputs (generations) " "in addition to the input logging that is enabled by default.", @@ -355,8 +337,8 @@ def validate_parsed_serve_args(args: argparse.Namespace): # Enable auto tool needs a tool call parser to be valid if args.enable_auto_tool_choice and not args.tool_call_parser: - raise TypeError( - "Error: --enable-auto-tool-choice requires --tool-call-parser") + raise TypeError("Error: --enable-auto-tool-choice requires " + "--tool-call-parser") if args.enable_prompt_embeds and args.enable_prompt_adapter: raise ValueError( "Cannot use prompt embeds and prompt adapter at the same time.") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index f58efbc4be1..939653df074 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -75,8 +75,7 @@ def __init__( models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage, - ) + enable_force_include_usage=enable_force_include_usage) self.response_role = response_role self.chat_template = chat_template @@ -87,7 +86,7 @@ def __init__( self.enable_auto_tools: bool = enable_auto_tools if self.enable_auto_tools: logger.info( - '"auto" tool choice has been enabled please note that while' + "\"auto\" tool choice has been enabled please note that while" " the parallel_tool_calls client option is preset for " "compatibility reasons, it will be ignored.") @@ -105,8 +104,8 @@ def __init__( self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None if self.enable_auto_tools: try: - if tool_parser == "pythonic" and model_config.model.startswith( - "meta-llama/Llama-3.2"): + if (tool_parser == "pythonic" and + model_config.model.startswith("meta-llama/Llama-3.2")): logger.warning( "Llama3.2 models may struggle to emit valid pythonic" " tool calls") @@ -126,11 +125,8 @@ def __init__( if self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source - logger.info( - "Using default chat sampling params from %s: %s", - source, - self.default_sampling_params, - ) + logger.info("Using default chat sampling params from %s: %s", + source, self.default_sampling_params) async def create_chat_completion( self, @@ -183,7 +179,7 @@ async def create_chat_completion( # for hf tokenizers, "auto" tools requires # --enable-auto-tool-choice and --tool-call-parser return self.create_error_response( - '"auto" tool choice requires ' + "\"auto\" tool choice requires " "--enable-auto-tool-choice and --tool-call-parser to be set" ) @@ -230,9 +226,8 @@ async def create_chat_completion( 
logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") - request_id = ( - f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}" - ) + request_id = "chatcmpl-" \ + f"{self._base_request_id(raw_request, request.request_id)}" request_metadata = RequestResponseMetadata(request_id=request_id) if raw_request: @@ -251,26 +246,21 @@ async def create_chat_completion( max_model_len=self.max_model_len, request=request, input_length=len(engine_prompt["prompt_token_ids"]), - default_sampling_params=self.default_sampling_params, - ) + default_sampling_params=self.default_sampling_params) if request.use_beam_search: sampling_params = request.to_beam_search_params( max_tokens, self.default_sampling_params) else: sampling_params = request.to_sampling_params( - max_tokens, - self.model_config.logits_processor_pattern, - self.default_sampling_params, - ) + max_tokens, self.model_config.logits_processor_pattern, + self.default_sampling_params) - self._log_inputs( - request_id, - request_prompts[i], - params=sampling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) + self._log_inputs(request_id, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) @@ -299,7 +289,7 @@ async def create_chat_completion( return self.create_error_response(str(e)) assert len(generators) == 1 - (result_generator, ) = generators + result_generator, = generators # Streaming response if request.stream: @@ -311,19 +301,13 @@ async def create_chat_completion( conversation, tokenizer, request_metadata, - enable_force_include_usage=self.enable_force_include_usage, - ) + enable_force_include_usage=self.enable_force_include_usage) + try: return await self.chat_completion_full_generator( - request, - result_generator, - request_id, - model_name, - conversation, - tokenizer, - request_metadata, - ) + request, result_generator, request_id, model_name, + conversation, tokenizer, request_metadata) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -334,7 +318,7 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: return request.messages[-1]["role"] @staticmethod - def _bracket_level(s: str, opening="{", closing="}") -> int: + def _bracket_level(s: str, opening='{', closing='}') -> int: """ Calculate the current level of nested brackets in a given string. 
""" @@ -358,10 +342,10 @@ def _filter_delta_text(delta_text: str, bracket_level = OpenAIServingChat._bracket_level(previous_text) updated_delta, passed_zero = "", False for c in delta_text: - if c == "{": + if c == '{': bracket_level += 1 passed_zero = bracket_level == 0 - elif c == "}": + elif c == '}': bracket_level -= 1 passed_zero = bracket_level == 0 @@ -369,7 +353,7 @@ def _filter_delta_text(delta_text: str, updated_delta += c else: # if a comma is reached at level 0 we can stop - if c == ",": + if c == ',': break return updated_delta, passed_zero @@ -386,7 +370,7 @@ def extract_tool_call_required_streaming( try: obj = partial_json_parser.loads(current_text) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug("not enough tokens to parse into JSON yet") + logger.debug('not enough tokens to parse into JSON yet') obj = None # check if the current text is a valid array @@ -425,15 +409,12 @@ def extract_tool_call_required_streaming( function_name_returned = True delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall( - id=random_tool_call_id(), - function=DeltaFunctionCall( - name=current_tool_call["name"], - arguments=arguments, - ), - index=len(obj) - 1, - type="function", - ) + DeltaToolCall(id=random_tool_call_id(), + function=DeltaFunctionCall( + name=current_tool_call["name"], + arguments=arguments), + index=len(obj) - 1, + type="function") ]) else: @@ -447,10 +428,8 @@ def extract_tool_call_required_streaming( # OpenAI API returns None # instead of name every time name=None, - arguments=delta_text, - ), - index=len(obj) - 1, - ) + arguments=delta_text), + index=len(obj) - 1) ]) else: delta_message = None @@ -534,10 +513,10 @@ async def chat_completion_stream_generator( stream_options = request.stream_options if stream_options: - include_usage = (stream_options.include_usage - or enable_force_include_usage) - include_continuous_usage = (include_usage and - stream_options.continuous_usage_stats) + include_usage = stream_options.include_usage \ + or enable_force_include_usage + include_continuous_usage = include_usage and \ + stream_options.continuous_usage_stats else: include_usage, include_continuous_usage = False, False @@ -567,23 +546,20 @@ async def chat_completion_stream_generator( content="", ), logprobs=None, - finish_reason=None, - ) + finish_reason=None) chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name, - ) + model=model_name) # if continuous usage stats are requested, add it if include_continuous_usage: chunk.usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=0, - total_tokens=num_prompt_tokens, - ) + total_tokens=num_prompt_tokens) data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" @@ -592,8 +568,8 @@ async def chat_completion_stream_generator( # last message if request.echo: last_msg_content: Union[str, list[dict[str, str]]] = "" - if (conversation and "content" in conversation[-1] - and conversation[-1].get("role") == role): + if conversation and "content" in conversation[ + -1] and conversation[-1].get("role") == role: last_msg_content = conversation[-1]["content"] or "" if last_msg_content: @@ -604,21 +580,18 @@ async def chat_completion_stream_generator( delta=DeltaMessage( content=last_msg_content), logprobs=None, - finish_reason=None, - )) + finish_reason=None)) chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name, - ) 
+ model=model_name) if include_continuous_usage: chunk.usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=0, - total_tokens=num_prompt_tokens, - ) + total_tokens=num_prompt_tokens) data = chunk.model_dump_json( exclude_unset=True) @@ -648,8 +621,8 @@ async def chat_completion_stream_generator( delta_text = output.text - if (not delta_text and not output.token_ids - and not previous_num_tokens[i]): + if not delta_text and not output.token_ids and \ + not previous_num_tokens[i]: # Chunked prefill case, don't return empty chunks continue @@ -702,18 +675,15 @@ async def chat_completion_stream_generator( delta_tool_call = DeltaToolCall( function=DeltaFunctionCall( arguments=delta_text), - index=i, - ) + index=i) else: delta_tool_call = DeltaToolCall( id=random_tool_call_id(), type="function", function=DeltaFunctionCall( name=tool_choice_function_name, - arguments=delta_text, - ), - index=i, - ) + arguments=delta_text), + index=i) function_name_returned[i] = True delta_message = DeltaMessage(tool_calls=[ @@ -727,9 +697,11 @@ async def chat_completion_stream_generator( fn_name_returned = function_name_returned[i] if self.reasoning_parser: - _, content = ( + _, content = \ reasoning_parser.extract_reasoning_content( - current_text, request)) + current_text, + request + ) else: content = current_text delta_message, function_name_returned[i] = ( @@ -737,8 +709,7 @@ async def chat_completion_stream_generator( previous_text=previous_text, current_text=content, delta_text=delta_text, - function_name_returned=fn_name_returned, - )) + function_name_returned=fn_name_returned)) # update the previous values for the next iteration previous_texts[i] = current_text @@ -766,9 +737,9 @@ async def chat_completion_stream_generator( # set reasoning status to end. # Remove the text and token ids related # to 'reasoning_content'. - if (res.prompt_token_ids - and reasoning_parser.is_reasoning_end( - list(res.prompt_token_ids))): + if res.prompt_token_ids and \ + reasoning_parser.is_reasoning_end( + list(res.prompt_token_ids)): reasoning_end_arr[i] = True current_token_ids = list(output.token_ids) if delta_message and delta_message.content: @@ -783,9 +754,9 @@ async def chat_completion_stream_generator( if reasoning_parser.is_reasoning_end( list(output.token_ids)): reasoning_end_arr[i] = True - current_token_ids = ( + current_token_ids = \ reasoning_parser.extract_content_ids( - list(output.token_ids))) + list(output.token_ids)) if delta_message and delta_message.content: current_text = delta_message.content delta_message.content = None @@ -813,8 +784,7 @@ async def chat_completion_stream_generator( previous_token_ids=previous_token_ids, current_token_ids=current_token_ids, delta_token_ids=delta_token_ids, - request=request, - )) + request=request)) # when only tool calls elif tool_choice_auto: assert tool_parser is not None @@ -826,8 +796,8 @@ async def chat_completion_stream_generator( previous_token_ids=previous_token_ids, current_token_ids=current_token_ids, delta_token_ids=output.token_ids, - request=request, - )) + request=request)) + # when only reasoning elif self.reasoning_parser: delta_message = (reasoning_parser. 
@@ -891,8 +861,7 @@ async def chat_completion_stream_generator( index=i, delta=delta_message, logprobs=logprobs, - finish_reason=None, - ) + finish_reason=None) # if the model is finished generating else: @@ -902,24 +871,21 @@ async def chat_completion_stream_generator( # only happens if we are NOT using guided decoding auto_tools_called = False if tool_parser: - auto_tools_called = (len( - tool_parser.prev_tool_call_arr) > 0) - index = (len(tool_parser.prev_tool_call_arr) - - 1 if auto_tools_called else 0) + auto_tools_called = len( + tool_parser.prev_tool_call_arr) > 0 + index = len(tool_parser.prev_tool_call_arr + ) - 1 if auto_tools_called else 0 else: index = 0 - if (self._should_check_for_unstreamed_tool_arg_tokens( - delta_message, output) and tool_parser): + if self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output) and tool_parser: latest_delta_len = 0 - if (isinstance( + if ((isinstance( delta_message.tool_calls[0].function, - DeltaFunctionCall, - )) and isinstance( - delta_message.tool_calls[0].function. - arguments, - str, - ): + DeltaFunctionCall)) and isinstance( + delta_message.tool_calls[0].function. + arguments, str)): latest_delta_len = len( delta_message.tool_calls[0].function. arguments) @@ -929,14 +895,13 @@ async def chat_completion_stream_generator( expected_call = json.dumps( tool_parser.prev_tool_call_arr[index].get( "arguments", {}), - ensure_ascii=False, - ) + ensure_ascii=False) # get what we've streamed so far for arguments # for the current tool actual_call = tool_parser.streamed_args_for_tool[ index] - if latest_delta_len > 0: + if (latest_delta_len > 0): actual_call = actual_call[:-latest_delta_len] # check to see if there's anything left to stream @@ -944,12 +909,10 @@ async def chat_completion_stream_generator( actual_call, "", 1) # set that as a delta message delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=index, - function=DeltaFunctionCall( - arguments=remaining_call).model_dump( - exclude_none=True), - ) + DeltaToolCall(index=index, + function=DeltaFunctionCall( + arguments=remaining_call). 
+ model_dump(exclude_none=True)) ]) # Send the finish response for each request.n only once @@ -959,8 +922,7 @@ async def chat_completion_stream_generator( logprobs=logprobs, finish_reason=output.finish_reason if not auto_tools_called else "tool_calls", - stop_reason=output.stop_reason, - ) + stop_reason=output.stop_reason) finish_reason_sent[i] = True @@ -969,8 +931,7 @@ async def chat_completion_stream_generator( object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name, - ) + model=model_name) # handle usage stats if requested & if continuous if include_continuous_usage: @@ -988,11 +949,10 @@ async def chat_completion_stream_generator( # is sent, send the usage if include_usage: completion_tokens = sum(previous_num_tokens) - final_usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + completion_tokens, - ) + final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens) if self.enable_prompt_tokens_details and num_cached_tokens: final_usage.prompt_tokens_details = PromptTokenUsageInfo( cached_tokens=num_cached_tokens) @@ -1003,10 +963,9 @@ async def chat_completion_stream_generator( created=created_time, choices=[], model=model_name, - usage=final_usage, - ) - final_usage_data = final_usage_chunk.model_dump_json( - exclude_unset=True, exclude_none=True) + usage=final_usage) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) yield f"data: {final_usage_data}\n\n" # report to FastAPI middleware aggregate usage across all choices @@ -1054,6 +1013,7 @@ async def chat_completion_full_generator( tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, ) -> Union[ErrorResponse, ChatCompletionResponse]: + created_time = int(time.time()) final_res: Optional[RequestOutput] = None @@ -1105,21 +1065,20 @@ async def chat_completion_full_generator( # if auto tools are not enabled, and a named tool choice using # outlines is not being used - if (not self.enable_auto_tools or not self.tool_parser) and ( - not isinstance(request.tool_choice, - ChatCompletionNamedToolChoiceParam) - and request.tool_choice != "required"): - message = ChatMessage( - role=role, - reasoning_content=reasoning_content, - content=content, - ) + if (not self.enable_auto_tools or not self.tool_parser) and \ + (not isinstance(request.tool_choice, + ChatCompletionNamedToolChoiceParam + ) and request.tool_choice != "required"): + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) # if the request uses tools and specified a tool choice - elif (request.tool_choice and type(request.tool_choice) - is ChatCompletionNamedToolChoiceParam): - tool_call_class = (MistralToolCall if isinstance( - tokenizer, MistralTokenizer) else ToolCall) + elif request.tool_choice and type( + request.tool_choice) is ChatCompletionNamedToolChoiceParam: + + tool_call_class = MistralToolCall if isinstance( + tokenizer, MistralTokenizer) else ToolCall message = ChatMessage( role=role, reasoning_content=reasoning_content, @@ -1133,8 +1092,8 @@ async def chat_completion_full_generator( ) elif request.tool_choice and request.tool_choice == "required": - tool_call_class = (MistralToolCall if isinstance( - tokenizer, MistralTokenizer) else ToolCall) + tool_call_class = MistralToolCall if isinstance( + tokenizer, MistralTokenizer) else ToolCall # the fields of FunctionDefinition 
are a superset of the # tool call outputs and can be used for parsing @@ -1149,25 +1108,24 @@ async def chat_completion_full_generator( tool_call_class(function=FunctionCall( name=tool_call.name, arguments=json.dumps(tool_call.parameters, - ensure_ascii=False), - )) for tool_call in tool_calls - ], - ) + ensure_ascii=False))) + for tool_call in tool_calls + ]) # if the request doesn't use tool choice # OR specifies to not use a tool elif not request.tool_choice or request.tool_choice == "none": - message = ChatMessage( - role=role, - reasoning_content=reasoning_content, - content=content, - ) + + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) # handle when there are tools and tool choice is auto - elif ( - request.tools and - (request.tool_choice == "auto" or request.tool_choice is None) - and self.enable_auto_tools and self.tool_parser): + elif request.tools and ( + request.tool_choice == "auto" + or request.tool_choice is None) and self.enable_auto_tools \ + and self.tool_parser: + try: tool_parser = self.tool_parser(tokenizer) except RuntimeError as e: @@ -1181,21 +1139,17 @@ async def chat_completion_full_generator( # call. The same is not true for named function calls auto_tools_called = tool_call_info.tools_called if tool_call_info.tools_called: - message = ChatMessage( - role=role, - reasoning_content=reasoning_content, - content=tool_call_info.content, - tool_calls=tool_call_info.tool_calls, - ) + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls) else: # FOR NOW make it a chat message; we will have to detect # the type to make it later. - message = ChatMessage( - role=role, - reasoning_content=reasoning_content, - content=content, - ) + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) # undetermined case that is still important to handle else: @@ -1203,11 +1157,9 @@ async def chat_completion_full_generator( "Error in chat_completion_full_generator - cannot determine" " if tools should be extracted. 
Returning a standard chat " "completion.") - message = ChatMessage( - role=role, - reasoning_content=reasoning_content, - content=content, - ) + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) choice_data = ChatCompletionResponseChoice( index=output.index, @@ -1215,8 +1167,8 @@ async def chat_completion_full_generator( logprobs=logprobs, finish_reason="tool_calls" if auto_tools_called else output.finish_reason if output.finish_reason else "stop", - stop_reason=output.stop_reason, - ) + stop_reason=output.stop_reason) + choices.append(choice_data) if request.echo: @@ -1225,7 +1177,7 @@ async def chat_completion_full_generator( and conversation[-1].get("role") == role): last_msg_content = conversation[-1]["content"] or "" if isinstance(last_msg_content, list): - last_msg_content = "\n".join(msg["text"] + last_msg_content = "\n".join(msg['text'] for msg in last_msg_content) for choice in choices: @@ -1239,11 +1191,10 @@ async def chat_completion_full_generator( num_prompt_tokens += len(final_res.encoder_prompt_token_ids) num_generated_tokens = sum( len(output.token_ids) for output in final_res.outputs) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - ) + usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + + num_generated_tokens) if self.enable_prompt_tokens_details and final_res.num_cached_tokens: usage.prompt_tokens_details = PromptTokenUsageInfo( cached_tokens=final_res.num_cached_tokens) @@ -1297,12 +1248,9 @@ async def chat_completion_full_generator( return response def _get_top_logprobs( - self, - logprobs: dict[int, Logprob], - top_logprobs: Optional[int], - tokenizer: AnyTokenizer, - should_return_as_token_id: bool, - ) -> list[ChatCompletionLogProb]: + self, logprobs: dict[int, Logprob], top_logprobs: Optional[int], + tokenizer: AnyTokenizer, + should_return_as_token_id: bool) -> list[ChatCompletionLogProb]: return [ ChatCompletionLogProb( token=(token := self._get_decoded_token( @@ -1328,13 +1276,12 @@ def _create_chat_logprobs( """Create OpenAI-style logprobs.""" logprobs_content: list[ChatCompletionLogProbsContent] = [] - should_return_as_token_id = (return_as_token_id - if return_as_token_id is not None else - self.return_tokens_as_token_ids) + should_return_as_token_id = return_as_token_id if \ + return_as_token_id is not None else self.return_tokens_as_token_ids for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] - if (step_top_logprobs is None - or step_top_logprobs.get(token_id) is None): + if step_top_logprobs is None or step_top_logprobs.get( + token_id) is None: token = tokenizer.decode(token_id) if should_return_as_token_id: token = f"token_id:{token_id}" @@ -1360,11 +1307,8 @@ def _create_chat_logprobs( bytes=None if step_decoded is None else list( step_decoded.encode("utf-8", errors="replace")), top_logprobs=self._get_top_logprobs( - step_top_logprobs, - num_output_top_logprobs, - tokenizer, - should_return_as_token_id, - ), + step_top_logprobs, num_output_top_logprobs, + tokenizer, should_return_as_token_id), )) return ChatCompletionLogProbs(content=logprobs_content) @@ -1380,7 +1324,7 @@ def _should_stream_with_auto_tool_parsing(self, choice field indicates that "auto" tool choice should be used. 
""" return (request.tools and self.tool_parser and self.enable_auto_tools - and request.tool_choice in ["auto", None]) + and request.tool_choice in ['auto', None]) def _should_check_for_unstreamed_tool_arg_tokens( self, diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 757554e848c..0096c37280a 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -89,11 +89,8 @@ def __init__( if self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source - logger.info( - "Using default chat sampling params from %s: %s", - source, - self.default_sampling_params, - ) + logger.info("Using default chat sampling params from %s: %s", + source, self.default_sampling_params) # HACK(woosuk): This is a hack. We should use a better store. # FIXME: This causes a memory leak since we never remove responses @@ -172,13 +169,11 @@ async def create_responses( sampling_params = request.to_sampling_params( default_max_tokens, self.default_sampling_params) - self._log_inputs( - request.request_id, - request_prompts[i], - params=sampling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) + self._log_inputs(request.request_id, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) @@ -198,7 +193,7 @@ async def create_responses( return self.create_error_response(str(e)) assert len(generators) == 1 - (result_generator, ) = generators + result_generator, = generators # Store the input messages. if request.store: @@ -465,7 +460,7 @@ async def cancel_responses( response.status = "cancelled" # Abort the request. - if task := self.background_tasks.get(response_id): + if (task := self.background_tasks.get(response_id)): task.cancel() try: await task From aaa2579a286bb80500dbc8690353519368d5cd77 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Date: Mon, 14 Jul 2025 10:25:35 +0400 Subject: [PATCH 13/13] Run pre-commit hooks to fix a PR issue with some format changes Signed-off-by: Adrian Garcia --- vllm/entrypoints/logger.py | 1 - vllm/entrypoints/openai/cli_args.py | 96 ++++++++++--------------- vllm/entrypoints/openai/serving_chat.py | 16 ++--- 3 files changed, 46 insertions(+), 67 deletions(-) diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index 7b6852b063b..d90163d051d 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -49,7 +49,6 @@ def log_inputs( prompt, params, prompt_token_ids, prompt_embeds.shape if prompt_embeds is not None else None, lora_request, prompt_adapter_request) - def log_outputs( self, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 5776ed5db89..1651639917f 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -100,31 +100,23 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--allow-credentials", action='store_true', help="Allow credentials.") - parser.add_argument( - "--allowed-origins", - type=json.loads, - default=["*"], - help="Allowed origins." - ) - parser.add_argument( - "--allowed-methods", - type=json.loads, - default=["*"], - help="Allowed methods." 
- ) - parser.add_argument( - "--allowed-headers", - type=json.loads, - default=["*"], - help="Allowed headers." - ) - parser.add_argument( - "--api-key", - type=optional_type(str), - default=None, - help="If provided, the server will require this key " - "to be presented in the header." - ) + parser.add_argument("--allowed-origins", + type=json.loads, + default=["*"], + help="Allowed origins.") + parser.add_argument("--allowed-methods", + type=json.loads, + default=["*"], + help="Allowed methods.") + parser.add_argument("--allowed-headers", + type=json.loads, + default=["*"], + help="Allowed headers.") + parser.add_argument("--api-key", + type=optional_type(str), + default=None, + help="If provided, the server will require this key " + "to be presented in the header.") parser.add_argument( "--lora-modules", type=optional_type(str), @@ -165,45 +157,34 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'Example: ``"Hello World"``\n' '* "openai" will render the content as a list of dictionaries, ' "similar to OpenAI schema. " - 'Example: ``[{"type": "text", "text": "Hello world!"}]``' - ) - parser.add_argument( - "--response-role", - type=optional_type(str), - default="assistant", - help="The role name to return if " - "``request.add_generation_prompt=true``." - ) - parser.add_argument( - "--ssl-keyfile", - type=optional_type(str), - default=None, - help="The file path to the SSL key file." - ) - parser.add_argument( - "--ssl-certfile", - type=optional_type(str), - default=None, - help="The file path to the SSL cert file." - ) - parser.add_argument( - "--ssl-ca-certs", - type=optional_type(str), - default=None, - help="The CA certificates file." - ) + 'Example: ``[{"type": "text", "text": "Hello world!"}]``') + parser.add_argument("--response-role", + type=optional_type(str), + default="assistant", + help="The role name to return if " + "``request.add_generation_prompt=true``.") + parser.add_argument("--ssl-keyfile", + type=optional_type(str), + default=None, + help="The file path to the SSL key file.") + parser.add_argument("--ssl-certfile", + type=optional_type(str), + default=None, + help="The file path to the SSL cert file.") + parser.add_argument("--ssl-ca-certs", + type=optional_type(str), + default=None, + help="The CA certificates file.") parser.add_argument( "--enable-ssl-refresh", action='store_true', default=False, - help="Refresh SSL Context when SSL certificate files change" - ) + help="Refresh SSL Context when SSL certificate files change") parser.add_argument( "--ssl-cert-reqs", type=int, default=int(ssl.CERT_NONE), - help= - "Whether client certificate is required (see stdlib ssl module's)." + help="Whether client certificate is required (see stdlib ssl module's)." ) parser.add_argument( "--root-path", @@ -315,7 +296,8 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action='store_true', default=False, help= - "If set to True, enable tracking server_load_metrics in the app state.") + "If set to True, enable tracking server_load_metrics in the app state." 
+ ) parser.add_argument( "--enable-log-outputs", action='store_true', diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 939653df074..9bfae5f9586 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -69,13 +69,12 @@ def __init__( enable_force_include_usage: bool = False, enable_log_outputs: bool = False, ) -> None: - super().__init__( - engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage) + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + enable_force_include_usage=enable_force_include_usage) self.response_role = response_role self.chat_template = chat_template @@ -303,7 +302,6 @@ async def create_chat_completion( request_metadata, enable_force_include_usage=self.enable_force_include_usage) - try: return await self.chat_completion_full_generator( request, result_generator, request_id, model_name, @@ -1013,7 +1011,7 @@ async def chat_completion_full_generator( tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, ) -> Union[ErrorResponse, ChatCompletionResponse]: - + created_time = int(time.time()) final_res: Optional[RequestOutput] = None
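
Taken together, the tests in this series pin down the call shape that RequestLogger.log_outputs() is expected to make: a format string containing "Generated response %s%s", followed positionally by the request id, a stream-info suffix, the (possibly truncated) output text, the output token ids, and the finish reason. The sketch below is a minimal illustration of a method that would satisfy those assertions, including --max-log-len truncation and the streaming delta/complete suffixes. It is an assumption-laden reconstruction from the tests, not the patch's actual implementation in vllm/entrypoints/logger.py, whose exact format string and edge-case handling may differ.

    # Illustrative sketch only: mirrors the logger call shape asserted by the
    # tests in this series. The class name is hypothetical; the real method
    # lives on RequestLogger in vllm/entrypoints/logger.py.
    import logging
    from typing import Optional

    logger = logging.getLogger(__name__)


    class RequestLoggerSketch:
        """Minimal stand-in reproducing the asserted log_outputs behavior."""

        def __init__(self, *, max_log_len: Optional[int] = None) -> None:
            self.max_log_len = max_log_len

        def log_outputs(
            self,
            request_id: str,
            outputs: str,
            output_token_ids: Optional[list[int]],
            finish_reason: Optional[str] = None,
            is_streaming: bool = False,
            delta: bool = False,
        ) -> None:
            # Respect --max-log-len for outputs, mirroring the input-side
            # truncation behavior.
            if self.max_log_len is not None:
                if outputs is not None:
                    outputs = outputs[:self.max_log_len]
                if output_token_ids is not None:
                    output_token_ids = output_token_ids[:self.max_log_len]

            # " (streaming delta)" for per-token chunks,
            # " (streaming complete)" for the final accumulated text.
            stream_info = ""
            if is_streaming:
                stream_info = (" (streaming delta)"
                               if delta else " (streaming complete)")

            # Positional args in the order the tests index into call_args:
            # [1] request_id, [2] stream_info, [3] outputs,
            # [4] output_token_ids, [5] finish_reason.
            logger.info(
                "Generated response %s%s: output: %r, "
                "output_token_ids: %s, finish_reason: %s",
                request_id,
                stream_info,
                outputs,
                output_token_ids,
                finish_reason,
            )

For example, a non-streaming call with max_log_len=10 would log only the first 10 characters of the output and the first 10 token ids, which is exactly what test_request_logger_log_outputs_with_truncation checks.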