Skip to content

Commit 7719456

Browse files
[Perf][Frontend] eliminate api_key and x_request_id headers middleware overhead (#19946)
Signed-off-by: Yazan-Sharaya <yazan.sharaya.yes@gmail.com>
1 parent 6209d5d commit 7719456

File tree

4 files changed

+190
-33
lines changed

4 files changed

+190
-33
lines changed

docs/serving/openai_compatible_server.md

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,6 @@ completion = client.chat.completions.create(
146146
Only `X-Request-Id` HTTP request header is supported for now. It can be enabled
147147
with `--enable-request-id-headers`.
148148

149-
> Note that enablement of the headers can impact performance significantly at high QPS
150-
> rates. We recommend implementing HTTP headers at the router level (e.g. via Istio),
151-
> rather than within the vLLM layer for this reason.
152-
> See [this PR](https://github.com/vllm-project/vllm/pull/11529) for more details.
153-
154149
??? Code
155150

156151
```python
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
Tests for middleware that's off by default and can be toggled through
5+
server arguments, mainly --api-key and --enable-request-id-headers.
6+
"""
7+
8+
from http import HTTPStatus
9+
10+
import pytest
11+
import requests
12+
13+
from ...utils import RemoteOpenAIServer
14+
15+
# Use a small embeddings model for faster startup and smaller memory footprint.
16+
# Since we are not testing any chat functionality,
17+
# using a chat capable model is overkill.
18+
MODEL_NAME = "intfloat/multilingual-e5-small"
19+
20+
21+
@pytest.fixture(scope="module")
def server(request: pytest.FixtureRequest):
    """Start a RemoteOpenAIServer, appending any indirect params as CLI args.

    Indirect parametrization may pass either a single flag (str) or a
    list of CLI arguments; both are normalized to a list.
    """
    extra_args = getattr(request, "param", [])
    if isinstance(extra_args, str):
        extra_args = [extra_args]

    cli_args = [
        "--task", "embed",
        # half precision keeps startup fast and memory low in CI
        "--dtype", "float16",
        "--max-model-len", "512",
        "--enforce-eager",
        "--max-num-seqs", "2",
    ]
    cli_args.extend(extra_args)

    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server
46+
@pytest.mark.asyncio
async def test_no_api_token(server: RemoteOpenAIServer):
    """Without --api-key, /v1 endpoints are reachable unauthenticated."""
    resp = requests.get(server.url_for("v1/models"))
    assert resp.status_code == HTTPStatus.OK
52+
@pytest.mark.asyncio
async def test_no_request_id_header(server: RemoteOpenAIServer):
    """X-Request-Id is absent unless --enable-request-id-headers is set."""
    resp = requests.get(server.url_for("health"))
    assert "X-Request-Id" not in resp.headers
58+
@pytest.mark.parametrize("server", [["--api-key", "test"]], indirect=True)
@pytest.mark.asyncio
async def test_missing_api_token(server: RemoteOpenAIServer):
    """A /v1 request without an Authorization header is rejected with 401."""
    resp = requests.get(server.url_for("v1/models"))
    assert resp.status_code == HTTPStatus.UNAUTHORIZED
69+
@pytest.mark.parametrize("server", [["--api-key", "test"]], indirect=True)
@pytest.mark.asyncio
async def test_passed_api_token(server: RemoteOpenAIServer):
    """The correct Bearer token grants access to /v1 endpoints."""
    resp = requests.get(server.url_for("v1/models"),
                        headers={"Authorization": "Bearer test"})
    assert resp.status_code == HTTPStatus.OK
81+
@pytest.mark.parametrize(
    "server",
    [["--api-key", "test"]],
    indirect=True,
)
@pytest.mark.asyncio
async def test_not_v1_api_token(server: RemoteOpenAIServer):
    """Paths outside /v1 are served even when --api-key is configured."""
    # Authorization is only enforced for paths under /v1
    # (e.g. /v1/chat/completions); other paths such as /health skip it.
    # (The original comment gave /v1/chat/completions as an example of a
    # path that does NOT start with /v1, which was contradictory.)
    response = requests.get(server.url_for("health"))
    assert response.status_code == HTTPStatus.OK
94+
@pytest.mark.parametrize("server", ["--enable-request-id-headers"],
                         indirect=True)
@pytest.mark.asyncio
async def test_enable_request_id_header(server: RemoteOpenAIServer):
    """When enabled, responses carry a generated 32-char (uuid4 hex) id."""
    resp = requests.get(server.url_for("health"))
    assert "X-Request-Id" in resp.headers
    assert len(resp.headers.get("X-Request-Id", "")) == 32
106+
@pytest.mark.parametrize("server", ["--enable-request-id-headers"],
                         indirect=True)
@pytest.mark.asyncio
async def test_custom_request_id_header(server: RemoteOpenAIServer):
    """A caller-supplied X-Request-Id is echoed back unchanged."""
    resp = requests.get(server.url_for("health"),
                        headers={"X-Request-Id": "Custom"})
    assert "X-Request-Id" in resp.headers
    assert resp.headers.get("X-Request-Id") == "Custom"

vllm/entrypoints/openai/api_server.py

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import tempfile
1515
import uuid
1616
from argparse import Namespace
17-
from collections.abc import AsyncIterator
17+
from collections.abc import AsyncIterator, Awaitable
1818
from contextlib import asynccontextmanager
1919
from functools import partial
2020
from http import HTTPStatus
@@ -30,8 +30,9 @@
3030
from prometheus_client import make_asgi_app
3131
from prometheus_fastapi_instrumentator import Instrumentator
3232
from starlette.concurrency import iterate_in_threadpool
33-
from starlette.datastructures import State
33+
from starlette.datastructures import URL, Headers, MutableHeaders, State
3434
from starlette.routing import Mount
35+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
3536
from typing_extensions import assert_never
3637

3738
import vllm.envs as envs
@@ -1061,6 +1062,74 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
10611062
return None
10621063

10631064

1065+
class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    if the Authorization header exists and equals "Bearer {api_key}".

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS (CORS preflight).
        2. The request path doesn't start with /v1 (e.g. /health).
    """

    def __init__(self, app: ASGIApp, api_token: str) -> None:
        self.app = app
        self.api_token = api_token

    def __call__(self, scope: Scope, receive: Receive,
                 send: Send) -> Awaitable[None]:
        # scope["type"] can be "lifespan" or "startup" for example,
        # in which case we don't need to do anything.
        if scope["type"] not in ("http", "websocket"):
            return self.app(scope, receive, send)
        # Use .get() here: per the ASGI spec, websocket connection scopes
        # carry no "method" key, so scope["method"] would raise KeyError
        # for websocket requests.
        if scope.get("method") == "OPTIONS":
            return self.app(scope, receive, send)
        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        # Only /v1 endpoints are protected; health/metrics stay open.
        if url_path.startswith("/v1") and headers.get(
                "Authorization") != f"Bearer {self.api_token}":
            response = JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)
1099+
1100+
class XRequestIdMiddleware:
    """
    Middleware that sets the X-Request-Id header for each response
    to a random uuid4 (hex) value if the header isn't already
    present in the request, otherwise echoes the provided request id.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive,
                 send: Send) -> Awaitable[None]:
        # Non-request scopes (e.g. "lifespan") pass through untouched.
        if scope["type"] not in ("http", "websocket"):
            return self.app(scope, receive, send)

        # Extract the request headers once, up front; the wrapped send
        # callback below closes over them.
        request_headers = Headers(scope=scope)

        async def send_with_request_id(message: Message) -> None:
            """
            Custom send function to mutate the response headers
            and append X-Request-Id to it.
            """
            if message["type"] == "http.response.start":
                response_headers = MutableHeaders(raw=message["headers"])
                # Echo the client's id if provided, else mint a fresh one.
                request_id = request_headers.get("X-Request-Id",
                                                 uuid.uuid4().hex)
                response_headers.append("X-Request-Id", request_id)
            await send(message)

        return self.app(scope, receive, send_with_request_id)
10641133
def build_app(args: Namespace) -> FastAPI:
10651134
if args.disable_fastapi_docs:
10661135
app = FastAPI(openapi_url=None,
@@ -1108,33 +1177,10 @@ async def validation_exception_handler(_: Request,
11081177

11091178
# Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
11101179
if token := args.api_key or envs.VLLM_API_KEY:
1111-
1112-
@app.middleware("http")
1113-
async def authentication(request: Request, call_next):
1114-
if request.method == "OPTIONS":
1115-
return await call_next(request)
1116-
url_path = request.url.path
1117-
if app.root_path and url_path.startswith(app.root_path):
1118-
url_path = url_path[len(app.root_path):]
1119-
if not url_path.startswith("/v1"):
1120-
return await call_next(request)
1121-
if request.headers.get("Authorization") != "Bearer " + token:
1122-
return JSONResponse(content={"error": "Unauthorized"},
1123-
status_code=401)
1124-
return await call_next(request)
1180+
app.add_middleware(AuthenticationMiddleware, api_token=token)
11251181

11261182
if args.enable_request_id_headers:
1127-
logger.warning(
1128-
"CAUTION: Enabling X-Request-Id headers in the API Server. "
1129-
"This can harm performance at high QPS.")
1130-
1131-
@app.middleware("http")
1132-
async def add_request_id(request: Request, call_next):
1133-
request_id = request.headers.get(
1134-
"X-Request-Id") or uuid.uuid4().hex
1135-
response = await call_next(request)
1136-
response.headers["X-Request-Id"] = request_id
1137-
return response
1183+
app.add_middleware(XRequestIdMiddleware)
11381184

11391185
if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
11401186
logger.warning("CAUTION: Enabling log response in the API Server. "

vllm/entrypoints/openai/cli_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
216216
"--enable-request-id-headers",
217217
action="store_true",
218218
help="If specified, API server will add X-Request-Id header to "
219-
"responses. Caution: this hurts performance at high QPS.")
219+
"responses.")
220220
parser.add_argument(
221221
"--enable-auto-tool-choice",
222222
action="store_true",

0 commit comments

Comments
 (0)