Skip to content

Commit 4ab3ac2

Browse files
authored
[Bugfix] Fix flaky failure when getting DP ports (#20151)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent d1c956d commit 4ab3ac2

File tree

1 file changed

+32
-9
lines changed

1 file changed

+32
-9
lines changed

vllm/config.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1878,18 +1878,41 @@ def get_next_dp_init_port(self) -> int:
18781878
return answer
18791879

18801880
def stateless_init_dp_group(self) -> "ProcessGroup":
1881+
# NOTE: In high-concurrency scenarios multiple processes
1882+
# can pick the same (currently free) port through a race
1883+
# condition when calling `get_open_port()`. When the first
1884+
# process binds the port the others will subsequently fail
1885+
# with `torch.distributed.DistNetworkError: EADDRINUSE`.
1886+
# To make the initialization more robust we retry a few times
1887+
# with a fresh port whenever this specific error is observed.
1888+
from torch.distributed import DistNetworkError
1889+
18811890
from vllm.distributed.utils import (
18821891
stateless_init_torch_distributed_process_group)
18831892

1884-
# use gloo since the engine process might not have cuda device
1885-
dp_group = stateless_init_torch_distributed_process_group(
1886-
self.data_parallel_master_ip,
1887-
self.get_next_dp_init_port(),
1888-
self.data_parallel_rank,
1889-
self.data_parallel_size,
1890-
backend="gloo")
1891-
1892-
return dp_group
1893+
max_retries = 5
1894+
last_exc: Optional[Exception] = None
1895+
for _ in range(max_retries):
1896+
try:
1897+
# use gloo since the engine process might not have cuda device
1898+
return stateless_init_torch_distributed_process_group(
1899+
self.data_parallel_master_ip,
1900+
self.get_next_dp_init_port(),
1901+
self.data_parallel_rank,
1902+
self.data_parallel_size,
1903+
backend="gloo")
1904+
except DistNetworkError as e:
1905+
# We only want to retry when the root cause is EADDRINUSE.
1906+
if "EADDRINUSE" in str(e):
1907+
logger.warning(
1908+
"Address already in use. Retrying with a new port.")
1909+
last_exc = e
1910+
continue # try again with a new port
1911+
raise e
1912+
1913+
# If we get here all retries have failed.
1914+
assert last_exc is not None
1915+
raise last_exc
18931916

18941917
@staticmethod
18951918
def has_unfinished_dp(dp_group: "ProcessGroup",

0 commit comments

Comments
 (0)