@@ -488,8 +488,13 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
488
488
"Connection listener got unexpected message %s" , msg )
489
489
sock .send_multipart ((identity , b"" , encoded_data ))
490
490
491
- def _nixl_handshake (self , host : str , port : int ,
492
- remote_tp_size : int ) -> dict [int , str ]:
491
+ def _nixl_handshake (
492
+ self ,
493
+ host : str ,
494
+ port : int ,
495
+ remote_tp_size : int ,
496
+ expected_engine_id : str ,
497
+ ) -> dict [int , str ]:
493
498
"""Do a NIXL handshake with a remote instance."""
494
499
495
500
start_time = time .perf_counter ()
@@ -498,35 +503,39 @@ def _nixl_handshake(self, host: str, port: int,
498
503
# a hack to keep us moving. We will switch when moving to etcd
499
504
# or where we have a single ZMQ socket in the scheduler.
500
505
501
- def handshake (path : str , rank : int ) -> str :
502
- # Send query for the request.
503
- with zmq_ctx (zmq .REQ , path ) as sock :
504
- sock .send (GET_META_MSG )
505
- metadata_bytes = sock .recv ()
506
- decoder = msgspec .msgpack .Decoder (NixlAgentMetadata )
507
- metadata = decoder .decode (metadata_bytes )
508
- got_metadata_time = time .perf_counter ()
509
-
510
- # Register Remote agent.
511
- remote_agent_name = self .add_remote_agent (
512
- metadata , rank , remote_tp_size )
513
- setup_agent_time = time .perf_counter ()
514
-
515
- logger .debug ("NIXL handshake: get metadata took: %s" ,
516
- got_metadata_time - start_time )
517
- logger .debug ("NIXL handshake: add agent took: %s" ,
518
- setup_agent_time - got_metadata_time )
519
- return remote_agent_name
520
-
521
506
# Handshake only with the remote TP rank that current local rank will
522
507
# pull from. With homogeneous TP it happens to be the same rank_i.
523
508
tp_ratio = self ._tp_size [self .engine_id ] // remote_tp_size
524
509
p_remote_rank = self .tp_rank // tp_ratio
525
510
path = make_zmq_path ("tcp" , host , port + p_remote_rank )
526
511
logger .debug ("Querying metadata on path: %s at remote rank %s" , path ,
527
512
p_remote_rank )
513
+
514
+ # Send query for the request.
515
+ with zmq_ctx (zmq .REQ , path ) as sock :
516
+ sock .send (GET_META_MSG )
517
+ metadata_bytes = sock .recv ()
518
+ decoder = msgspec .msgpack .Decoder (NixlAgentMetadata )
519
+ metadata = decoder .decode (metadata_bytes )
520
+ got_metadata_time = time .perf_counter ()
521
+ logger .debug ("NIXL handshake: get metadata took: %s" ,
522
+ got_metadata_time - start_time )
523
+
524
+ # Ensure engine id matches.
525
+ if metadata .engine_id != expected_engine_id :
526
+ raise RuntimeError (f"Remote NIXL agent engine ID mismatch. "
527
+ f"Expected { expected_engine_id } ,"
528
+ f"received { metadata .engine_id } ." )
529
+
530
+ # Register Remote agent.
531
+ remote_agent_name = self .add_remote_agent (metadata , p_remote_rank ,
532
+ remote_tp_size )
533
+ setup_agent_time = time .perf_counter ()
534
+ logger .debug ("NIXL handshake: add agent took: %s" ,
535
+ setup_agent_time - got_metadata_time )
536
+
528
537
# Remote rank -> agent name.
529
- return {p_remote_rank : handshake ( path , p_remote_rank ) }
538
+ return {p_remote_rank : remote_agent_name }
530
539
531
540
def _background_nixl_handshake (self , req_id : str ,
532
541
remote_engine_id : EngineId , meta : ReqMeta ):
@@ -535,7 +544,7 @@ def _background_nixl_handshake(self, req_id: str,
535
544
if fut is None :
536
545
fut = self ._handshake_initiation_executor .submit (
537
546
self ._nixl_handshake , meta .remote_host , meta .remote_port ,
538
- meta .tp_size )
547
+ meta .tp_size , remote_engine_id )
539
548
self ._handshake_futures [remote_engine_id ] = fut
540
549
541
550
def done_callback (f : Future [dict [int , str ]], eid = remote_engine_id ):
@@ -738,10 +747,10 @@ def add_remote_agent(self,
738
747
if remote_tp_rank in self ._remote_agents .get (engine_id , {}):
739
748
return self ._remote_agents [engine_id ][remote_tp_rank ]
740
749
741
- if engine_id in self ._tp_size :
742
- assert self ._tp_size [engine_id ] == remote_tp_size
743
- else :
750
+ if engine_id not in self ._tp_size :
744
751
self ._tp_size [engine_id ] = remote_tp_size
752
+ else :
753
+ assert self ._tp_size [engine_id ] == remote_tp_size
745
754
# We may eventually enable this after asserting equality in cache
746
755
# layout and close outputs.
747
756
assert nixl_agent_meta .attn_backend_name == self .backend_name
0 commit comments