16
16
from vllm .config import KVTransferConfig
17
17
from vllm .distributed .kv_transfer .kv_pipe .base import KVPipeBase
18
18
from vllm .logger import init_logger
19
+ from vllm .utils import join_host_port , make_zmq_path , split_host_port
19
20
20
21
logger = init_logger (__name__ )
21
22
NONE_INT = - 150886311
@@ -79,18 +80,19 @@ def __init__(self, kv_rank: int, local_rank: int):
79
80
logger .error (
80
81
"An error occurred while loading the configuration: %s" , exc )
81
82
raise
82
- prefill_host , base_prefill_port = self .config .prefill_url .split (':' )
83
- decode_host , base_decode_port = self .config .decode_url .split (':' )
83
+ prefill_host , base_prefill_port = split_host_port (
84
+ self .config .prefill_url )
85
+ decode_host , base_decode_port = split_host_port (self .config .decode_url )
84
86
85
87
# Avoid port conflicts when running prefill and decode on the same node
86
88
if prefill_host == decode_host and \
87
89
base_prefill_port == base_decode_port :
88
- base_decode_port = str ( int ( base_decode_port ) + 100 )
90
+ base_decode_port = base_decode_port + 100
89
91
90
- prefill_port = int ( base_prefill_port ) + self .local_rank
91
- decode_port = int ( base_decode_port ) + self .local_rank
92
- self .prefill_url = ':' . join ([ prefill_host , str ( prefill_port )] )
93
- self .decode_url = ':' . join ([ decode_host , str ( decode_port )] )
92
+ prefill_port = base_prefill_port + self .local_rank
93
+ decode_port = base_decode_port + self .local_rank
94
+ self .prefill_url = join_host_port ( prefill_host , prefill_port )
95
+ self .decode_url = join_host_port ( decode_host , decode_port )
94
96
95
97
self .initialize (self .prefill_url if kv_rank == 0 else self .decode_url ,
96
98
self .config .metadata_server , self .config .protocol ,
@@ -110,22 +112,30 @@ def __init__(self, kv_rank: int, local_rank: int):
110
112
self ._setup_metadata_sockets (kv_rank , prefill_host , base_prefill_port ,
111
113
decode_host , base_decode_port )
112
114
113
- def _setup_metadata_sockets (self , kv_rank : int , p_host : str , p_port : str ,
114
- d_host : str , d_port : str ) -> None :
115
+ def _setup_metadata_sockets (self , kv_rank : int , p_host : str , p_port : int ,
116
+ d_host : str , d_port : int ) -> None :
115
117
"""Set up ZeroMQ sockets for sending and receiving data."""
116
118
# Offsets < 8 are left for initialization in case tp and pp are enabled
117
- p_rank_offset = int ( p_port ) + 8 + self .local_rank * 2
118
- d_rank_offset = int ( d_port ) + 8 + self .local_rank * 2
119
+ p_rank_offset = p_port + 8 + self .local_rank * 2
120
+ d_rank_offset = d_port + 8 + self .local_rank * 2
119
121
if kv_rank == 0 :
120
- self .sender_socket .bind (f"tcp://{ p_host } :{ p_rank_offset + 1 } " )
121
- self .receiver_socket .connect (f"tcp://{ d_host } :{ d_rank_offset + 1 } " )
122
- self .sender_ack .connect (f"tcp://{ d_host } :{ d_rank_offset + 2 } " )
123
- self .receiver_ack .bind (f"tcp://{ p_host } :{ p_rank_offset + 2 } " )
122
+ self .sender_socket .bind (
123
+ make_zmq_path ("tcp" , p_host , p_rank_offset + 1 ))
124
+ self .receiver_socket .connect (
125
+ make_zmq_path ("tcp" , d_host , d_rank_offset + 1 ))
126
+ self .sender_ack .connect (
127
+ make_zmq_path ("tcp" , d_host , d_rank_offset + 2 ))
128
+ self .receiver_ack .bind (
129
+ make_zmq_path ("tcp" , p_host , p_rank_offset + 2 ))
124
130
else :
125
- self .receiver_socket .connect (f"tcp://{ p_host } :{ p_rank_offset + 1 } " )
126
- self .sender_socket .bind (f"tcp://{ d_host } :{ d_rank_offset + 1 } " )
127
- self .receiver_ack .bind (f"tcp://{ d_host } :{ d_rank_offset + 2 } " )
128
- self .sender_ack .connect (f"tcp://{ p_host } :{ p_rank_offset + 2 } " )
131
+ self .receiver_socket .connect (
132
+ make_zmq_path ("tcp" , p_host , p_rank_offset + 1 ))
133
+ self .sender_socket .bind (
134
+ make_zmq_path ("tcp" , d_host , d_rank_offset + 1 ))
135
+ self .receiver_ack .bind (
136
+ make_zmq_path ("tcp" , d_host , d_rank_offset + 2 ))
137
+ self .sender_ack .connect (
138
+ make_zmq_path ("tcp" , p_host , p_rank_offset + 2 ))
129
139
130
140
def initialize (self , local_hostname : str , metadata_server : str ,
131
141
protocol : str , device_name : str ,
0 commit comments