Skip to content

Commit eccdc83

Browse files
authored
[V1][P/D] A native implementation of xPyD based on P2P NCCL (#18242)
Signed-off-by: Abatom <abzhonghua@gmail.com>
1 parent 5f52a84 commit eccdc83

File tree

8 files changed

+1780
-0
lines changed

8 files changed

+1780
-0
lines changed

docs/design/v1/p2p_nccl_connector.md

Lines changed: 337 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import os
4+
import socket
5+
import threading
6+
import uuid
7+
8+
import aiohttp
9+
import msgpack
10+
import zmq
11+
from quart import Quart, make_response, request
12+
13+
# Number of requests handled so far; indexes the instance lists round-robin.
count: int = 0

# Registered instances, keyed by HTTP address with the ZMQ address as value.
prefill_instances: dict[str, str] = {}
decode_instances: dict[str, str] = {}

# One lock per registry, held while the matching dict is read or written.
prefill_cv, decode_cv = threading.Condition(), threading.Condition()
19+
20+
21+
def _listen_for_register(poller, router_socket):
    """Receive instance registrations forever and record them.

    Every message is expected to be a two-frame multipart message: the
    sender's address followed by a msgpack payload shaped like
    ``{"type": "P", "http_address": "ip:port", "zmq_address": "ip:port"}``.
    Type ``"P"`` entries are stored in ``prefill_instances`` and type ``"D"``
    entries in ``decode_instances``, each while holding the matching lock.

    Args:
        poller: Poller on which ``router_socket`` is registered.
        router_socket: Socket the registration messages arrive on.
    """
    registries = {
        "P": (prefill_cv, prefill_instances),
        "D": (decode_cv, decode_instances),
    }

    while True:
        socks = dict(poller.poll())
        if router_socket not in socks:
            continue

        frames = router_socket.recv_multipart()
        try:
            remote_address, message = frames
            data = msgpack.loads(message)
            target = registries.get(data["type"])
        except Exception as e:  # noqa: BLE001 - thread boundary
            # One bad message must not stop discovery for everyone else.
            print(f"Ignoring malformed registration message {frames!r}: {e!r}")
            continue

        if target is None:
            # print() does not interpolate %-style arguments, so format here.
            print(f"Unexpected, Received message from {remote_address}, data: {data}")
            continue

        try:
            http_address = data["http_address"]
            zmq_address = data["zmq_address"]
        except KeyError as e:
            print(f"Ignoring registration from {remote_address}, missing key {e}")
            continue

        lock, instances = target
        with lock:
            instances[http_address] = zmq_address
45+
46+
47+
def start_service_discovery(hostname, port):
    """Bind the registration socket and start the listener thread.

    Args:
        hostname: Host to bind to; a falsy value falls back to
            ``socket.gethostname()``.
        port: TCP port to bind to. Must not be 0.

    Returns:
        The started daemon thread running ``_listen_for_register``.

    Raises:
        ValueError: If ``port`` is 0.
    """
    if not hostname:
        hostname = socket.gethostname()
    if port == 0:
        raise ValueError("Port cannot be 0")

    zmq_ctx = zmq.Context()
    registration_socket = zmq_ctx.socket(zmq.ROUTER)
    registration_socket.bind(f"tcp://{hostname}:{port}")

    registration_poller = zmq.Poller()
    registration_poller.register(registration_socket, zmq.POLLIN)

    listener = threading.Thread(
        target=_listen_for_register,
        args=(registration_poller, registration_socket),
        daemon=True,
    )
    listener.start()
    return listener
65+
66+
67+
# Total timeout for each forwarded HTTP request: 6 hours, expressed in seconds.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
68+
69+
# Quart application on which the proxy's HTTP routes are registered.
app = Quart(__name__)
70+
71+
72+
def random_uuid() -> str:
    """Return a new random UUID4 as a 32-character hexadecimal string."""
    fresh_id = uuid.uuid4()
    return fresh_id.hex
74+
75+
76+
async def forward_request(url, data, request_id):
    """POST ``data`` as JSON to ``url`` and stream back the response body.

    Args:
        url: Full URL of the upstream endpoint.
        data: JSON-serializable request payload.
        request_id: Value sent in the ``X-Request-Id`` header.

    Yields:
        Chunks of the response body, up to 1024 bytes each, when the upstream
        answers with HTTP 200. Nothing is yielded for any other status.
    """
    # NOTE: when OPENAI_API_KEY is unset this sends the literal "Bearer None".
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=url, json=data, headers=headers) as response:
            if response.status != 200:
                # Report the failure instead of ending the stream silently.
                print(
                    f"forward_request to {url} failed with HTTP status "
                    f"{response.status}"
                )
                return
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes
90+
91+
92+
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    """Run prefill on one instance, then stream the decode from another.

    The incoming JSON body is copied with ``max_tokens`` forced to 1 and sent
    to a prefill instance; its output is discarded. The unmodified body is
    then sent to a decode instance and that response is streamed back. Both
    instances are chosen round-robin using the global ``count``, and both
    requests share a request id that embeds the two ZMQ addresses.

    Returns:
        The streamed decode response, a 503 JSON error when no prefill or
        decode instance is registered, or a 500 JSON error on any other
        failure.
    """
    global count

    try:
        original_request_data = await request.get_json()

        prefill_request = original_request_data.copy()
        # change max_tokens = 1 to let it only do prefill
        prefill_request["max_tokens"] = 1

        # Snapshot both registries under their locks.
        with prefill_cv:
            prefill_list = list(prefill_instances.items())
        with decode_cv:
            decode_list = list(decode_instances.items())

        # Without this guard the modulo below raises ZeroDivisionError.
        if not prefill_list or not decode_list:
            print(
                "No instances available: "
                f"{len(prefill_list)} prefill, {len(decode_list)} decode"
            )
            return {"error": "No prefill or decode instance is registered"}, 503

        prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
        decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]

        print(
            f"handle_request count: {count}, [HTTP:{prefill_addr}, "
            f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, "
            f"ZMQ:{decode_zmq_addr}]"
        )
        count += 1

        request_id = (
            f"___prefill_addr_{prefill_zmq_addr}___decode_addr_"
            f"{decode_zmq_addr}_{random_uuid()}"
        )

        # finish prefill: drain the response, its content is not needed
        async for _ in forward_request(
            f"http://{prefill_addr}/v1/completions", prefill_request, request_id
        ):
            continue

        # return decode
        generator = forward_request(
            f"http://{decode_addr}/v1/completions", original_request_data, request_id
        )
        response = await make_response(generator)
        response.timeout = None

        return response

    except Exception as e:
        import traceback

        print("Error occurred in disagg prefill proxy server")
        print(e)
        print(traceback.format_exc())
        # A view must return a response; falling through would return None.
        return {"error": "Internal error in disagg prefill proxy server"}, 500
149+
150+
151+
if __name__ == "__main__":
    # Registration listener on port 30001, HTTP proxy on port 10001.
    discovery_thread = start_service_discovery("0.0.0.0", 30001)
    app.run(host="0.0.0.0", port=10001)
    # Reached only after app.run() returns.
    discovery_thread.join()

vllm/distributed/device_communicators/pynccl_wrapper.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,14 @@ def ncclGetUniqueId(self) -> ncclUniqueId:
272272
ctypes.byref(unique_id)))
273273
return unique_id
274274

275+
def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId:
    """Rebuild an ``ncclUniqueId`` from its raw 128-byte representation.

    Args:
        data: Exactly 128 bytes, copied into the id's ``internal`` buffer.

    Returns:
        A new ``ncclUniqueId`` whose ``internal`` field holds ``data``.

    Raises:
        ValueError: If ``data`` is not exactly 128 bytes long.
    """
    received = len(data)
    if received != 128:
        raise ValueError(
            f"Expected 128 bytes for ncclUniqueId, got {received} bytes")
    restored = ncclUniqueId()
    destination = ctypes.addressof(restored.internal)
    ctypes.memmove(destination, data, 128)
    return restored
282+
275283
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
276284
rank: int) -> ncclComm_t:
277285
comm = ncclComm_t()

vllm/distributed/kv_transfer/kv_connector/factory.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ def create_connector_v1(
112112
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
113113
"SharedStorageConnector")
114114

115+
# Register the P2P NCCL connector: name, module path, class name.
KVConnectorFactory.register_connector(
    "P2pNcclConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
    "P2pNcclConnector",
)
119+
115120
KVConnectorFactory.register_connector(
116121
"LMCacheConnectorV1",
117122
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",

vllm/distributed/kv_transfer/kv_connector/v1/p2p/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)