Skip to content

Commit 701df75

Browse files
committed
pre-commit fixes
Signed-off-by: Will Eaton <weaton@redhat.com> Signed-off-by: Will Eaton <weaton@redhat.com>
1 parent ed66835 commit 701df75

File tree

9 files changed

+285
-268
lines changed

9 files changed

+285
-268
lines changed

tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py

Lines changed: 143 additions & 117 deletions
Large diffs are not rendered by default.

vllm/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,7 +1542,7 @@ class CacheConfig:
15421542

15431543
transfer_handshake_metadata: Optional[dict[int, dict[
15441544
int, KVConnectorHandshakeMetadata]]] = field(default=None, init=False)
1545-
"""Metadata for the KV connector handshake. Structure: dp_rank -> tp_rank -> metadata"""
1545+
"""Metadata for KV connector handshake. Structure: dp_rank -> tp_rank"""
15461546

15471547
def compute_hash(self) -> str:
15481548
"""
@@ -4633,8 +4633,8 @@ def __post_init__(self):
46334633
if self.kv_events_config is not None:
46344634
# Hybrid KV cache manager is not compatible with KV events.
46354635
self.scheduler_config.disable_hybrid_kv_cache_manager = True
4636-
4637-
if (self.kv_transfer_config is not None
4636+
4637+
if (self.kv_transfer_config is not None
46384638
and self.kv_transfer_config.is_kv_transfer_instance):
46394639
from collections import defaultdict
46404640
self.cache_config.transfer_handshake_metadata = defaultdict(dict)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 87 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from collections.abc import Iterator
1414
from concurrent.futures import Future, ThreadPoolExecutor
1515
from dataclasses import dataclass
16-
from typing import TYPE_CHECKING, Any, Dict, Optional
16+
from typing import TYPE_CHECKING, Any, Optional
1717
from urllib.error import HTTPError, URLError
1818
from urllib.request import Request as URLRequest
1919
from urllib.request import urlopen
@@ -81,81 +81,83 @@ class ReqMeta:
8181

8282

8383
class HandshakeStrategy(ABC):
84-
85-
def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int,
84+
85+
def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int,
8686
side_channel_port: int, engine_id: str):
8787
self.nixl_wrapper = nixl_wrapper
8888
self.tp_rank = tp_rank
8989
self.tp_size = tp_size
9090
self.side_channel_port = side_channel_port
9191
self.engine_id = engine_id
92-
92+
9393
@abstractmethod
94-
def initiate_handshake(self, host: str, port: int,
95-
remote_tp_size: int) -> Dict[int, str]:
94+
def initiate_handshake(self, host: str, port: int,
95+
remote_tp_size: int) -> dict[int, str]:
9696
pass
97-
97+
9898
@abstractmethod
9999
def setup_listener(self, metadata: NixlAgentMetadata) -> None:
100100
pass
101-
101+
102102
@abstractmethod
103103
def cleanup(self) -> None:
104104
pass
105105

106106

107107
class ZmqHandshakeStrategy(HandshakeStrategy):
108-
108+
109109
def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int,
110-
side_channel_port: int, engine_id: str,
110+
side_channel_port: int, engine_id: str,
111111
add_remote_agent_func):
112-
super().__init__(nixl_wrapper, tp_rank, tp_size, side_channel_port, engine_id)
112+
super().__init__(nixl_wrapper, tp_rank, tp_size, side_channel_port,
113+
engine_id)
113114
self.add_remote_agent_func = add_remote_agent_func
114115
self._listener_thread: Optional[threading.Thread] = None
115-
self._tp_size_mapping: Dict[str, int] = {engine_id: tp_size}
116-
117-
def initiate_handshake(self, host: str, port: int,
118-
remote_tp_size: int) -> Dict[int, str]:
116+
self._tp_size_mapping: dict[str, int] = {engine_id: tp_size}
117+
118+
def initiate_handshake(self, host: str, port: int,
119+
remote_tp_size: int) -> dict[int, str]:
119120
start_time = time.perf_counter()
120-
121+
121122
def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
122123
with self._zmq_ctx(zmq.REQ, path) as sock:
123124
sock.send(GET_META_MSG)
124125
metadata_bytes = sock.recv()
125126
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
126127
metadata = decoder.decode(metadata_bytes)
127128
got_metadata_time = time.perf_counter()
128-
129+
129130
# Register Remote agent
130-
agent_name = self.add_remote_agent_func(metadata, rank, remote_tp_size)
131+
agent_name = self.add_remote_agent_func(
132+
metadata, rank, remote_tp_size)
131133
setup_agent_time = time.perf_counter()
132-
134+
133135
logger.debug("NIXL handshake: get metadata took: %s",
134-
got_metadata_time - start_time)
135-
logger.debug("NIXL handshake: add agent took: %s",
136-
setup_agent_time - got_metadata_time)
136+
got_metadata_time - start_time)
137+
logger.debug("NIXL handshake: add agent took: %s",
138+
setup_agent_time - got_metadata_time)
137139
return metadata, agent_name
138-
140+
139141
# Handshake with remote agent-rank0 first to get the tp_size of remote
140142
path = make_zmq_path("tcp", host, port)
141143
logger.debug("Querying master rank metadata on path: %s", path)
142144
metadata, agent_name_0 = handshake(path, 0)
143-
145+
144146
agents = {0: agent_name_0}
145-
147+
146148
# Handshake only with the other TP remote the current local rank will
147149
# pull from. With homogeneous TP it happens to be the same rank_i.
148150
tp_ratio = self._tp_size_mapping[self.engine_id] // remote_tp_size
149151
p_remote_rank = self.tp_rank // tp_ratio
150152
if p_remote_rank > 0:
151153
path = make_zmq_path("tcp", host, port + p_remote_rank)
152154
logger.debug("Querying metadata on path: %s at remote rank %s",
153-
path, p_remote_rank)
155+
path, p_remote_rank)
154156
_, agent_name = handshake(path, p_remote_rank)
155157
agents[p_remote_rank] = agent_name
156-
158+
157159
return agents
158-
160+
159161
def setup_listener(self, metadata: NixlAgentMetadata) -> None:
160162
ready_event = threading.Event()
161163
self._listener_thread = threading.Thread(
@@ -165,20 +167,21 @@ def setup_listener(self, metadata: NixlAgentMetadata) -> None:
165167
name="nixl_handshake_listener")
166168
self._listener_thread.start()
167169
ready_event.wait()
168-
170+
169171
def cleanup(self) -> None:
170172
if self._listener_thread:
171173
self._listener_thread.join(timeout=0)
172-
174+
173175
@staticmethod
174176
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
175-
ready_event: threading.Event, base_port: int,
176-
tp_rank: int):
177+
ready_event: threading.Event, base_port: int,
178+
tp_rank: int):
177179
encoder = msgspec.msgpack.Encoder()
178180
encoded_data = encoder.encode(metadata)
179181
size_in_bytes = len(encoded_data)
180-
logger.debug("Size of encoded NixlAgentMetadata: %s bytes", size_in_bytes)
181-
182+
logger.debug("Size of encoded NixlAgentMetadata: %s bytes",
183+
size_in_bytes)
184+
182185
# Listen for new requests for metadata
183186
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
184187
path = make_zmq_path("tcp", host, base_port + tp_rank)
@@ -188,97 +191,109 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
188191
while True:
189192
identity, _, msg = sock.recv_multipart()
190193
if msg != GET_META_MSG:
191-
logger.warning("Connection listener got unexpected message %s", msg)
194+
logger.warning(
195+
"Connection listener got unexpected message %s", msg)
192196
sock.send_multipart((identity, b"", encoded_data))
193-
197+
194198
@staticmethod
195199
@contextlib.contextmanager
196200
def _zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
197201
if socket_type not in (zmq.ROUTER, zmq.REQ):
198202
raise ValueError(f"Unexpected socket type: {socket_type}")
199-
203+
200204
ctx: Optional[zmq.Context] = None
201205
try:
202206
ctx = zmq.Context()
203-
yield make_zmq_socket(ctx=ctx, path=addr, socket_type=socket_type,
204-
bind=socket_type == zmq.ROUTER)
207+
yield make_zmq_socket(ctx=ctx,
208+
path=addr,
209+
socket_type=socket_type,
210+
bind=socket_type == zmq.ROUTER)
205211
finally:
206212
if ctx is not None:
207213
ctx.destroy(linger=0)
208214

209215

210216
class HttpHandshakeStrategy(HandshakeStrategy):
211-
217+
212218
def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int,
213219
side_channel_port: int, engine_id: str,
214220
add_remote_agent_func):
215-
super().__init__(nixl_wrapper, tp_rank, tp_size, side_channel_port, engine_id)
221+
super().__init__(nixl_wrapper, tp_rank, tp_size, side_channel_port,
222+
engine_id)
216223
self.add_remote_agent_func = add_remote_agent_func
217-
self._tp_size_mapping: Dict[str, int] = {engine_id: tp_size}
218-
219-
def initiate_handshake(self, host: str, port: int,
220-
remote_tp_size: int) -> Dict[int, str]:
224+
self._tp_size_mapping: dict[str, int] = {engine_id: tp_size}
225+
226+
def initiate_handshake(self, host: str, port: int,
227+
remote_tp_size: int) -> dict[int, str]:
221228
start_time = time.perf_counter()
222229
logger.debug("Starting NIXL handshake with %s:%s", host, port)
223-
230+
224231
url = build_uri("http", host, port, path="get_kv_connector_metadata")
225-
232+
226233
try:
227234
req = URLRequest(url)
228-
with urlopen(req, timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response:
235+
with urlopen(req,
236+
timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response:
229237
response_data = response.read().decode('utf-8')
230238
res = json.loads(response_data)
231239
except (URLError, HTTPError) as e:
232240
logger.error("Failed to fetch metadata from %s: %s", url, e)
233241
raise
234-
242+
235243
if res is None:
236-
logger.warning("Remote server returned None metadata, skipping handshake")
244+
logger.warning(
245+
"Remote server returned None metadata, skipping handshake")
237246
raise RuntimeError("Remote server returned None metadata")
238-
247+
239248
# Get dp_rank 0 data (standard for disaggregated prefill-decode)
240249
dp_data = res.get("0", {})
241250
if not dp_data:
242251
raise RuntimeError("No metadata found for dp_rank 0")
243-
252+
244253
remote_tp_size = len(dp_data.keys())
245-
254+
246255
# Handshake only with the remote TP rank that current local rank will
247256
# pull from. With homogeneous TP it happens to be the same rank_i.
248257
tp_ratio = self._tp_size_mapping[self.engine_id] // remote_tp_size
249258
p_remote_rank = self.tp_rank // tp_ratio
250-
259+
251260
# Get data for the specific rank we need to connect to
252261
rank_data = dp_data.get(str(p_remote_rank), {})
253262
if not rank_data:
254-
raise RuntimeError(f"No metadata found for remote rank {p_remote_rank}")
255-
263+
raise RuntimeError(
264+
f"No metadata found for remote rank {p_remote_rank}")
265+
256266
metadata_bytes = rank_data.get("agent_metadata", None)
257267
if metadata_bytes is None:
258-
raise RuntimeError(f"No agent metadata found for remote rank {p_remote_rank}")
259-
268+
raise RuntimeError(
269+
f"No agent metadata found for remote rank {p_remote_rank}")
270+
260271
rank_data_copy = rank_data.copy()
261272
rank_data_copy.pop("agent_metadata", None)
262273
metadata = NixlAgentMetadata(
263274
agent_metadata=base64.b64decode(metadata_bytes), **rank_data_copy)
264-
275+
265276
pre_register = time.perf_counter()
266277
# Register Remote agent
267-
remote_agent_name = self.add_remote_agent_func(metadata, p_remote_rank, remote_tp_size)
278+
remote_agent_name = self.add_remote_agent_func(metadata, p_remote_rank,
279+
remote_tp_size)
268280
agent_time = time.perf_counter()
269-
270-
logger.debug("Finished registering remote agent for engine %s", metadata.engine_id)
271-
logger.debug("NIXL handshake: get metadata took: %s", pre_register - start_time)
272-
logger.debug("NIXL handshake: add agent took: %s", agent_time - pre_register)
273-
281+
282+
logger.debug("Finished registering remote agent for engine %s",
283+
metadata.engine_id)
284+
logger.debug("NIXL handshake: get metadata took: %s",
285+
pre_register - start_time)
286+
logger.debug("NIXL handshake: add agent took: %s",
287+
agent_time - pre_register)
288+
274289
logger.debug("NIXL handshake method completed for %s:%s", host, port)
275-
290+
276291
# Return remote rank -> agent name mapping
277292
return {p_remote_rank: remote_agent_name}
278-
293+
279294
def setup_listener(self, metadata: NixlAgentMetadata) -> None:
280295
pass
281-
296+
282297
def cleanup(self) -> None:
283298
pass
284299

@@ -680,8 +695,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
680695
self.side_channel_port, self.engine_id, self.add_remote_agent)
681696
else:
682697
raise ValueError(f"Unknown handshake method: {handshake_method}. "
683-
"Supported methods: 'zmq', 'http'")
684-
698+
"Supported methods: 'zmq', 'http'")
699+
685700
logger.info("Using %s handshake strategy", handshake_method)
686701

687702
def __del__(self):
@@ -693,7 +708,8 @@ def __del__(self):
693708

694709
def _nixl_handshake(self, host: str, port: int,
695710
remote_tp_size: int) -> dict[int, str]:
696-
return self._handshake_strategy.initiate_handshake(host, port, remote_tp_size)
711+
return self._handshake_strategy.initiate_handshake(
712+
host, port, remote_tp_size)
697713

698714
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
699715
"""Register the KV Cache data in nixl."""

0 commit comments

Comments
 (0)