Skip to content

Commit 83ec83a

Browse files
committed
move nixl sidechannel to own entrypoint
Signed-off-by: Will Eaton <weaton@redhat.com> push missing side channel server Signed-off-by: Will Eaton <weaton@redhat.com>
1 parent 91af1cd commit 83ec83a

File tree

2 files changed

+125
-26
lines changed

2 files changed

+125
-26
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import asyncio
5+
import threading
6+
from typing import Any, Optional
7+
8+
import uvicorn
9+
from fastapi import FastAPI
10+
11+
from vllm import envs
12+
from vllm.config import VllmConfig
13+
from vllm.entrypoints.launcher import serve_http
14+
from vllm.logger import init_logger
15+
16+
logger = init_logger(__name__)
17+
18+
19+
class NixlSideChannelServer:
20+
21+
def __init__(self, vllm_config: VllmConfig, host: str, port: int):
22+
self.vllm_config = vllm_config
23+
self.host = host
24+
self.port = port
25+
self.app = FastAPI(title="vLLM NIXL Side Channel Server")
26+
self.server = None
27+
self.server_thread = None
28+
self._setup_routes()
29+
30+
def _setup_routes(self):
31+
32+
@self.app.get("/get_kv_connector_metadata")
33+
@self.app.get("/get_kv_connector_metadata/{dp_rank}")
34+
@self.app.get("/get_kv_connector_metadata/{dp_rank}/{tp_rank}")
35+
async def get_kv_connector_metadata(dp_rank: Optional[int] = None,
36+
tp_rank: Optional[int] = None):
37+
kv_meta: Optional[dict[str, dict[str, dict[str, Any]]]] = (
38+
self.vllm_config.cache_config.transfer_handshake_metadata)
39+
40+
if kv_meta is None:
41+
return None
42+
43+
if dp_rank is not None:
44+
if dp_rank not in kv_meta:
45+
return {}
46+
dp_data = kv_meta[dp_rank]
47+
48+
if tp_rank is not None:
49+
if tp_rank not in dp_data:
50+
return {}
51+
return {dp_rank: {tp_rank: dp_data[tp_rank]}}
52+
else:
53+
return {dp_rank: dp_data}
54+
55+
return kv_meta
56+
57+
async def start_async(self):
58+
if self.server is not None:
59+
logger.warning("Side channel server is already running")
60+
return
61+
62+
logger.info("Starting NIXL side channel server on %s:%s",
63+
self.host, self.port)
64+
65+
# use uvicorn directly to avoid dependency on engine_client
66+
config = uvicorn.Config(
67+
app=self.app,
68+
host=self.host,
69+
port=self.port,
70+
log_level="info",
71+
access_log=True,
72+
)
73+
self.server = uvicorn.Server(config)
74+
75+
# start the server in a background task
76+
asyncio.create_task(self.server.serve())
77+
logger.info("NIXL side channel server started successfully")
78+
79+
async def stop_async(self):
80+
if self.server is not None:
81+
logger.info("Stopping NIXL side channel server")
82+
try:
83+
self.server.should_exit = True
84+
await asyncio.sleep(1) # give it time to shutdown
85+
except Exception as e:
86+
logger.warning("Error during side channel server shutdown: %s", e)
87+
self.server = None
88+
logger.info("NIXL side channel server stopped")
89+
90+
91+
def should_start_nixl_side_channel_server(vllm_config: VllmConfig) -> bool:
92+
if vllm_config.kv_transfer_config is None:
93+
return False
94+
95+
return vllm_config.kv_transfer_config.kv_connector == "NixlConnector"
96+
97+
98+
async def start_nixl_side_channel_server_if_needed(
99+
vllm_config: VllmConfig) -> Optional[NixlSideChannelServer]:
100+
if not should_start_nixl_side_channel_server(vllm_config):
101+
return None
102+
103+
side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
104+
side_channel_port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT
105+
106+
logger.info("Starting NIXL side channel metadata server on %s:%d",
107+
side_channel_host, side_channel_port)
108+
109+
server = NixlSideChannelServer(
110+
vllm_config, side_channel_host, side_channel_port)
111+
await server.start_async()
112+
return server

vllm/entrypoints/openai/api_server.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@
9393
from vllm.entrypoints.openai.serving_transcription import (
9494
OpenAIServingTranscription, OpenAIServingTranslation)
9595
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
96+
from vllm.entrypoints.nixl_side_channel_server import (
97+
start_nixl_side_channel_server_if_needed)
9698
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
9799
with_cancellation)
98100
from vllm.logger import init_logger
@@ -916,32 +918,6 @@ async def show_server_info(raw_request: Request):
916918
server_info = {"vllm_config": str(raw_request.app.state.vllm_config)}
917919
return JSONResponse(content=server_info)
918920

919-
@router.get("/get_kv_connector_metadata")
920-
@router.get("/get_kv_connector_metadata/{dp_rank}")
921-
@router.get("/get_kv_connector_metadata/{dp_rank}/{tp_rank}")
922-
async def get_kv_connector_metadata(raw_request: Request,
923-
dp_rank: Optional[int] = None,
924-
tp_rank: Optional[int] = None):
925-
kv_meta: Optional[dict[str, dict[str, dict[str, Any]]]] = (
926-
raw_request.app.state.vllm_config.cache_config.
927-
transfer_handshake_metadata)
928-
929-
if kv_meta is None:
930-
return None
931-
932-
if dp_rank is not None:
933-
if dp_rank not in kv_meta:
934-
return {}
935-
dp_data = kv_meta[dp_rank]
936-
937-
if tp_rank is not None:
938-
if tp_rank not in dp_data:
939-
return {}
940-
return {dp_rank: {tp_rank: dp_data[tp_rank]}}
941-
else:
942-
return {dp_rank: dp_data}
943-
944-
return kv_meta
945921

946922
@router.post("/reset_prefix_cache")
947923
async def reset_prefix_cache(raw_request: Request):
@@ -1474,6 +1450,12 @@ async def run_server_worker(listen_address,
14741450
vllm_config = await engine_client.get_vllm_config()
14751451
await init_app_state(engine_client, vllm_config, app.state, args)
14761452

1453+
nixl_side_channel_server = None
1454+
try:
1455+
nixl_side_channel_server = await start_nixl_side_channel_server_if_needed(vllm_config)
1456+
except Exception as e:
1457+
logger.warning("Failed to start NIXL side channel server: %s", e)
1458+
14771459
logger.info("Starting vLLM API server %d on %s", server_index,
14781460
listen_address)
14791461
shutdown_task = await serve_http(
@@ -1498,6 +1480,11 @@ async def run_server_worker(listen_address,
14981480
try:
14991481
await shutdown_task
15001482
finally:
1483+
if nixl_side_channel_server is not None:
1484+
try:
1485+
await nixl_side_channel_server.stop_async()
1486+
except Exception as e:
1487+
logger.warning("Error stopping NIXL side channel server: %s", e)
15011488
sock.close()
15021489

15031490

0 commit comments

Comments
 (0)