|
14 | 14 | import tempfile
|
15 | 15 | import uuid
|
16 | 16 | from argparse import Namespace
|
17 |
| -from collections.abc import AsyncIterator |
| 17 | +from collections.abc import AsyncIterator, Awaitable |
18 | 18 | from contextlib import asynccontextmanager
|
19 | 19 | from functools import partial
|
20 | 20 | from http import HTTPStatus
|
|
30 | 30 | from prometheus_client import make_asgi_app
|
31 | 31 | from prometheus_fastapi_instrumentator import Instrumentator
|
32 | 32 | from starlette.concurrency import iterate_in_threadpool
|
33 |
| -from starlette.datastructures import State |
| 33 | +from starlette.datastructures import URL, Headers, MutableHeaders, State |
34 | 34 | from starlette.routing import Mount
|
| 35 | +from starlette.types import ASGIApp, Message, Receive, Scope, Send |
35 | 36 | from typing_extensions import assert_never
|
36 | 37 |
|
37 | 38 | import vllm.envs as envs
|
@@ -1061,6 +1062,74 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
|
1061 | 1062 | return None
|
1062 | 1063 |
|
1063 | 1064 |
|
| 1065 | +class AuthenticationMiddleware: |
| 1066 | + """ |
| 1067 | + Pure ASGI middleware that authenticates each request by checking |
| 1068 | + if the Authorization header exists and equals "Bearer {api_key}". |
| 1069 | +
|
| 1070 | + Notes |
| 1071 | + ----- |
| 1072 | + There are two cases in which authentication is skipped: |
| 1073 | + 1. The HTTP method is OPTIONS. |
| 1074 | + 2. The request path doesn't start with /v1 (e.g. /health). |
| 1075 | + """ |
| 1076 | + |
| 1077 | + def __init__(self, app: ASGIApp, api_token: str) -> None: |
| 1078 | + self.app = app |
| 1079 | + self.api_token = api_token |
| 1080 | + |
| 1081 | + def __call__(self, scope: Scope, receive: Receive, |
| 1082 | + send: Send) -> Awaitable[None]: |
| 1083 | + if scope["type"] not in ("http", |
| 1084 | + "websocket") or scope["method"] == "OPTIONS": |
| 1085 | + # scope["type"] can be "lifespan" or "startup" for example, |
| 1086 | + # in which case we don't need to do anything |
| 1087 | + return self.app(scope, receive, send) |
| 1088 | + root_path = scope.get("root_path", "") |
| 1089 | + url_path = URL(scope=scope).path.removeprefix(root_path) |
| 1090 | + headers = Headers(scope=scope) |
| 1091 | + # Type narrow to satisfy mypy. |
| 1092 | + if url_path.startswith("/v1") and headers.get( |
| 1093 | + "Authorization") != f"Bearer {self.api_token}": |
| 1094 | + response = JSONResponse(content={"error": "Unauthorized"}, |
| 1095 | + status_code=401) |
| 1096 | + return response(scope, receive, send) |
| 1097 | + return self.app(scope, receive, send) |
| 1098 | + |
| 1099 | + |
| 1100 | +class XRequestIdMiddleware: |
| 1101 | + """ |
| 1102 | + Middleware the set's the X-Request-Id header for each response |
| 1103 | + to a random uuid4 (hex) value if the header isn't already |
| 1104 | + present in the request, otherwise use the provided request id. |
| 1105 | + """ |
| 1106 | + |
| 1107 | + def __init__(self, app: ASGIApp) -> None: |
| 1108 | + self.app = app |
| 1109 | + |
| 1110 | + def __call__(self, scope: Scope, receive: Receive, |
| 1111 | + send: Send) -> Awaitable[None]: |
| 1112 | + if scope["type"] not in ("http", "websocket"): |
| 1113 | + return self.app(scope, receive, send) |
| 1114 | + |
| 1115 | + # Extract the request headers. |
| 1116 | + request_headers = Headers(scope=scope) |
| 1117 | + |
| 1118 | + async def send_with_request_id(message: Message) -> None: |
| 1119 | + """ |
| 1120 | + Custom send function to mutate the response headers |
| 1121 | + and append X-Request-Id to it. |
| 1122 | + """ |
| 1123 | + if message["type"] == "http.response.start": |
| 1124 | + response_headers = MutableHeaders(raw=message["headers"]) |
| 1125 | + request_id = request_headers.get("X-Request-Id", |
| 1126 | + uuid.uuid4().hex) |
| 1127 | + response_headers.append("X-Request-Id", request_id) |
| 1128 | + await send(message) |
| 1129 | + |
| 1130 | + return self.app(scope, receive, send_with_request_id) |
| 1131 | + |
| 1132 | + |
1064 | 1133 | def build_app(args: Namespace) -> FastAPI:
|
1065 | 1134 | if args.disable_fastapi_docs:
|
1066 | 1135 | app = FastAPI(openapi_url=None,
|
@@ -1108,33 +1177,10 @@ async def validation_exception_handler(_: Request,
|
1108 | 1177 |
|
1109 | 1178 | # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
|
1110 | 1179 | if token := args.api_key or envs.VLLM_API_KEY:
|
1111 |
| - |
1112 |
| - @app.middleware("http") |
1113 |
| - async def authentication(request: Request, call_next): |
1114 |
| - if request.method == "OPTIONS": |
1115 |
| - return await call_next(request) |
1116 |
| - url_path = request.url.path |
1117 |
| - if app.root_path and url_path.startswith(app.root_path): |
1118 |
| - url_path = url_path[len(app.root_path):] |
1119 |
| - if not url_path.startswith("/v1"): |
1120 |
| - return await call_next(request) |
1121 |
| - if request.headers.get("Authorization") != "Bearer " + token: |
1122 |
| - return JSONResponse(content={"error": "Unauthorized"}, |
1123 |
| - status_code=401) |
1124 |
| - return await call_next(request) |
| 1180 | + app.add_middleware(AuthenticationMiddleware, api_token=token) |
1125 | 1181 |
|
1126 | 1182 | if args.enable_request_id_headers:
|
1127 |
| - logger.warning( |
1128 |
| - "CAUTION: Enabling X-Request-Id headers in the API Server. " |
1129 |
| - "This can harm performance at high QPS.") |
1130 |
| - |
1131 |
| - @app.middleware("http") |
1132 |
| - async def add_request_id(request: Request, call_next): |
1133 |
| - request_id = request.headers.get( |
1134 |
| - "X-Request-Id") or uuid.uuid4().hex |
1135 |
| - response = await call_next(request) |
1136 |
| - response.headers["X-Request-Id"] = request_id |
1137 |
| - return response |
| 1183 | + app.add_middleware(XRequestIdMiddleware) |
1138 | 1184 |
|
1139 | 1185 | if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
|
1140 | 1186 | logger.warning("CAUTION: Enabling log response in the API Server. "
|
|
0 commit comments