Skip to content

Commit 1dba2c4

Browse files
authored
[Misc] Adjust for IPv6 in Mooncake URL parsing (#20107)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
1 parent 71d6de3 commit 1dba2c4

File tree

3 files changed

+99
-27
lines changed

3 files changed

+99
-27
lines changed

tests/test_utils.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
2121
MemorySnapshot, PlaceholderModule, StoreBoolean,
2222
bind_kv_cache, common_broadcastable_dtype,
23-
deprecate_kwargs, get_open_port, is_lossless_cast,
24-
make_zmq_path, make_zmq_socket, memory_profiling,
25-
merge_async_iterators, sha256, split_zmq_path,
26-
supports_kw, swap_dict_values)
23+
deprecate_kwargs, get_open_port, get_tcp_uri,
24+
is_lossless_cast, join_host_port, make_zmq_path,
25+
make_zmq_socket, memory_profiling,
26+
merge_async_iterators, sha256, split_host_port,
27+
split_zmq_path, supports_kw, swap_dict_values)
2728

2829
from .utils import create_new_process_for_each_test, error_on_warning
2930

@@ -876,3 +877,44 @@ def test_make_zmq_socket_ipv6():
876877
def test_make_zmq_path():
    """make_zmq_path joins scheme, host and port, bracketing an IPv6 host."""
    ipv4_path = make_zmq_path("tcp", "127.0.0.1", "5555")
    assert ipv4_path == "tcp://127.0.0.1:5555"

    ipv6_path = make_zmq_path("tcp", "::1", "5555")
    assert ipv6_path == "tcp://[::1]:5555"
880+
881+
882+
def test_get_tcp_uri():
    """get_tcp_uri yields a tcp:// URI and brackets an IPv6 host."""
    ipv4_uri = get_tcp_uri("127.0.0.1", 5555)
    assert ipv4_uri == "tcp://127.0.0.1:5555"

    ipv6_uri = get_tcp_uri("::1", 5555)
    assert ipv6_uri == "tcp://[::1]:5555"
885+
886+
887+
def test_split_host_port():
    """Check split_host_port on well-formed and malformed inputs."""
    # Well-formed inputs: the port comes back as an int and the brackets
    # around an IPv6 host are stripped.
    assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
    assert split_host_port("[::1]:5555") == ("::1", 5555)

    # Malformed inputs, each paired with the exception it must raise.
    malformed_cases = [
        ("127.0.0.1::5555", ValueError),  # ipv4: multiple colons
        ("127.0.0.1:5555:", ValueError),  # ipv4: trailing colon
        ("127.0.0.15555", ValueError),  # ipv4: no colon
        ("127.0.0.1:5555a", ValueError),  # ipv4: non-integer port
        ("[::1]::5555", ValueError),  # ipv6: multiple colons
        ("[::1]5555", IndexError),  # ipv6: no colon after the bracket
        ("[::1]:5555a", ValueError),  # ipv6: non-integer port
    ]
    for host_port, expected_error in malformed_cases:
        with pytest.raises(expected_error):
            split_host_port(host_port)
916+
917+
918+
def test_join_host_port():
    """join_host_port yields host:port and brackets an IPv6 host."""
    ipv4_joined = join_host_port("127.0.0.1", 5555)
    assert ipv4_joined == "127.0.0.1:5555"

    ipv6_joined = join_host_port("::1", 5555)
    assert ipv6_joined == "[::1]:5555"

vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from vllm.config import KVTransferConfig
1717
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
1818
from vllm.logger import init_logger
19+
from vllm.utils import join_host_port, make_zmq_path, split_host_port
1920

2021
logger = init_logger(__name__)
2122
NONE_INT = -150886311
@@ -79,18 +80,19 @@ def __init__(self, kv_rank: int, local_rank: int):
7980
logger.error(
8081
"An error occurred while loading the configuration: %s", exc)
8182
raise
82-
prefill_host, base_prefill_port = self.config.prefill_url.split(':')
83-
decode_host, base_decode_port = self.config.decode_url.split(':')
83+
prefill_host, base_prefill_port = split_host_port(
84+
self.config.prefill_url)
85+
decode_host, base_decode_port = split_host_port(self.config.decode_url)
8486

8587
# Avoid ports conflict when running prefill and decode on the same node
8688
if prefill_host == decode_host and \
8789
base_prefill_port == base_decode_port:
88-
base_decode_port = str(int(base_decode_port) + 100)
90+
base_decode_port = base_decode_port + 100
8991

90-
prefill_port = int(base_prefill_port) + self.local_rank
91-
decode_port = int(base_decode_port) + self.local_rank
92-
self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
93-
self.decode_url = ':'.join([decode_host, str(decode_port)])
92+
prefill_port = base_prefill_port + self.local_rank
93+
decode_port = base_decode_port + self.local_rank
94+
self.prefill_url = join_host_port(prefill_host, prefill_port)
95+
self.decode_url = join_host_port(decode_host, decode_port)
9496

9597
self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
9698
self.config.metadata_server, self.config.protocol,
@@ -110,22 +112,30 @@ def __init__(self, kv_rank: int, local_rank: int):
110112
self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
111113
decode_host, base_decode_port)
112114

113-
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int,
                            d_host: str, d_port: int) -> None:
    """Bind and connect the ZeroMQ data and acknowledgement sockets.

    Every endpoint is a TCP path built by ``make_zmq_path`` from either the
    ``p_*`` or the ``d_*`` host and a port derived from the matching base
    port plus a rank-dependent offset.

    When ``kv_rank`` is 0, ``sender_socket`` and ``receiver_ack`` are bound
    on ``p_host`` while ``receiver_socket`` and ``sender_ack`` connect to
    ``d_host``. For any other ``kv_rank`` the bound sockets live on
    ``d_host`` and the connecting sockets point at ``p_host``.
    """

    def endpoint(host: str, port: int) -> str:
        # Built lazily, right before each bind/connect call.
        return make_zmq_path("tcp", host, port)

    # Offsets < 8 are left for initialization in case tp and pp are enabled
    rank_shift = 8 + self.local_rank * 2
    p_base = p_port + rank_shift
    d_base = d_port + rank_shift

    if kv_rank == 0:
        self.sender_socket.bind(endpoint(p_host, p_base + 1))
        self.receiver_socket.connect(endpoint(d_host, d_base + 1))
        self.sender_ack.connect(endpoint(d_host, d_base + 2))
        self.receiver_ack.bind(endpoint(p_host, p_base + 2))
        return

    self.receiver_socket.connect(endpoint(p_host, p_base + 1))
    self.sender_socket.bind(endpoint(d_host, d_base + 1))
    self.receiver_ack.bind(endpoint(d_host, d_base + 2))
    self.sender_ack.connect(endpoint(p_host, p_base + 2))
129139

130140
def initialize(self, local_hostname: str, metadata_server: str,
131141
protocol: str, device_name: str,

vllm/utils/__init__.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from functools import cache, lru_cache, partial, wraps
4747
from types import MappingProxyType
4848
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
49-
Optional, TypeVar, Union, cast, overload)
49+
Optional, Tuple, TypeVar, Union, cast, overload)
5050
from urllib.parse import urlparse
5151
from uuid import uuid4
5252

@@ -628,14 +628,34 @@ def is_valid_ipv6_address(address: str) -> bool:
628628
return False
629629

630630

631+
def split_host_port(host_port: str) -> Tuple[str, int]:
    """Split a ``host:port`` string into a ``(host, port)`` tuple.

    A host wrapped in square brackets (the IPv6 form, e.g. ``[::1]:5555``)
    is returned without the brackets. Any other input must contain exactly
    one ``:`` separating the host from the port.

    Args:
        host_port: The string to split, e.g. ``"127.0.0.1:5555"`` or
            ``"[::1]:5555"``.

    Returns:
        The host as a string and the port converted with ``int()``.

    Raises:
        ValueError: If the closing ``]`` is missing, if anything other than
            a single ``:`` sits between ``]`` and the port, if an
            unbracketed input does not contain exactly one ``:``, or if the
            port is not an integer.
        IndexError: If a bracketed host is not followed by any ``:``.
    """
    if host_port.startswith('['):
        # Split on the last ']' so the colons inside the address are never
        # mistaken for the host/port separator.
        host, bracket, remainder = host_port[1:].rpartition(']')
        if not bracket:
            raise ValueError(
                f"missing closing ']' in host:port string {host_port!r}")
        if ':' not in remainder:
            # Kept as IndexError for backward compatibility: this is what
            # the previous implementation raised for this input.
            raise IndexError(
                f"missing ':' before port in host:port string {host_port!r}")
        if not remainder.startswith(':'):
            # Reject junk such as "[::1]junk:5555" instead of ignoring it.
            raise ValueError(
                f"unexpected text after ']' in host:port string "
                f"{host_port!r}")
        # Anything after the first ':' must be the port; extra colons make
        # the int() conversion below fail with ValueError.
        port = remainder[1:]
    else:
        fields = host_port.split(':')
        if len(fields) != 2:
            raise ValueError(
                f"expected exactly one ':' in host:port string "
                f"{host_port!r}")
        host, port = fields
    return host, int(port)
641+
642+
643+
def join_host_port(host: str, port: int) -> str:
    """Format ``host`` and ``port`` as a single ``host:port`` string.

    A host accepted by ``is_valid_ipv6_address`` is wrapped in square
    brackets; any other host is used unchanged.
    """
    shown_host = f"[{host}]" if is_valid_ipv6_address(host) else host
    return f"{shown_host}:{port}"
648+
649+
631650
def get_distributed_init_method(ip: str, port: int) -> str:
    """Return the URI that ``get_tcp_uri`` builds for ``ip`` and ``port``."""
    init_method = get_tcp_uri(ip, port)
    return init_method
633652

634653

635654
def get_tcp_uri(ip: str, port: int) -> str:
    """Build a ``tcp://host:port`` URI for ``ip`` and ``port``.

    A host accepted by ``is_valid_ipv6_address`` is wrapped in square
    brackets; any other host is used unchanged.
    """
    # Brackets are not permitted in ipv4 addresses,
    # see https://github.com/python/cpython/issues/103848
    shown_host = f"[{ip}]" if is_valid_ipv6_address(ip) else ip
    return f"tcp://{shown_host}:{port}"
639659

640660

641661
def get_open_zmq_ipc_path() -> str:

0 commit comments

Comments
 (0)