Skip to content

Commit c5869fe

Browse files
author
jax authors
committed
Add option to set coordinator lookup timeout for TPU clusters
PiperOrigin-RevId: 617383458
1 parent 33e1a96 commit c5869fe

File tree

5 files changed

+23
-15
lines changed

5 files changed

+23
-15
lines changed

jax/_src/clusters/cloud_tpu_cluster.py

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,7 @@ def is_env_present(cls) -> bool:
9090
return False
9191

9292
@classmethod
93-
def get_coordinator_address(cls) -> str:
93+
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
9494
if has_megascale_address():
9595
# For both GCE via QueuedResources and GKE via JobSet, the
9696
# Megascale coordinator address is set as the host with process id = 0,
@@ -103,24 +103,27 @@ def get_coordinator_address(cls) -> str:
103103
coordinator_address = cls._get_worker_list_in_slice()[0]
104104
coordinator_address = coordinator_address.split(':')[0]
105105
logger.debug("TPU Cluster using coordinator address: %s", coordinator_address)
106-
cls.wait_for_coordinator(coordinator_address)
106+
cls.wait_for_coordinator(coordinator_address, timeout_secs)
107107
return f'{coordinator_address}:{coordinator_port}'
108108

109109
@classmethod
110-
def wait_for_coordinator(cls, coordinator_address):
110+
def wait_for_coordinator(cls, coordinator_address, timeout_secs):
111111
# The coordinator may not be up before the other hosts try to
112112
# communicate with it. We check for its existence with retries.
113113
coordinator_found = False
114-
lookup_attempt = 1
115-
max_coordinator_lookups = 50
116-
while not coordinator_found and lookup_attempt <= max_coordinator_lookups:
114+
max_time = time.time() + timeout_secs
115+
coordinator_retry_secs = 5
116+
while not coordinator_found and time.time() < max_time:
117117
try:
118118
ip_address = socket.gethostbyname(coordinator_address)
119119
coordinator_found = True
120+
logger.debug("Found coordinator with address %s", coordinator_address)
120121
except socket.gaierror:
121-
print(f"Failed to recognize coordinator address {coordinator_address} on attempt {lookup_attempt}, retrying...")
122-
lookup_attempt += 1
123-
time.sleep(5)
122+
logger.debug(
123+
"Failed to recognize coordinator address %s"
124+
" retrying...", coordinator_address
125+
)
126+
time.sleep(coordinator_retry_secs)
124127
if not coordinator_found:
125128
raise RuntimeError(f"Failed to recognize coordinator address {coordinator_address}")
126129

jax/_src/clusters/cluster.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,8 @@ def auto_detect_unset_distributed_params(cls,
4242
coordinator_address: str | None,
4343
num_processes: int | None,
4444
process_id: int | None,
45-
local_device_ids: Sequence[int] | None
45+
local_device_ids: Sequence[int] | None,
46+
initialization_timeout: int | None,
4647
) -> tuple[str | None, int | None, int | None,
4748
Sequence[int] | None]:
4849
if all(p is not None for p in (coordinator_address, num_processes,
@@ -53,7 +54,7 @@ def auto_detect_unset_distributed_params(cls,
5354
if env:
5455
logger.debug('Initializing distributed JAX environment via %s', env.__name__)
5556
if coordinator_address is None:
56-
coordinator_address = env.get_coordinator_address()
57+
coordinator_address = env.get_coordinator_address(timeout_secs=initialization_timeout)
5758
if num_processes is None:
5859
num_processes = env.get_process_count()
5960
if process_id is None:
@@ -79,7 +80,7 @@ def is_env_present(cls) -> bool:
7980
raise NotImplementedError("ClusterEnv subclasses must implement is_env_present")
8081

8182
@classmethod
82-
def get_coordinator_address(cls) -> str:
83+
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
8384
"""Returns address and port used by JAX to bootstrap.
8485
8586
Process id 0 will open a tcp socket at "hostname:port" where

jax/_src/clusters/ompi_cluster.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def is_env_present(cls) -> bool:
3030
return _ORTE_URI in os.environ
3131

3232
@classmethod
33-
def get_coordinator_address(cls) -> str:
33+
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
3434
# Examples of orte_uri:
3535
# 1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911
3636
# 1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370

jax/_src/clusters/slurm_cluster.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def is_env_present(cls) -> bool:
3030
return _JOBID_PARAM in os.environ
3131

3232
@classmethod
33-
def get_coordinator_address(cls) -> str:
33+
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
3434
# Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
3535
port = int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1)
3636

jax/_src/distributed.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,11 @@ def initialize(self,
4949

5050
(coordinator_address, num_processes, process_id, local_device_ids) = (
5151
clusters.ClusterEnv.auto_detect_unset_distributed_params(
52-
coordinator_address, num_processes, process_id, local_device_ids
52+
coordinator_address,
53+
num_processes,
54+
process_id,
55+
local_device_ids,
56+
initialization_timeout,
5357
)
5458
)
5559

0 commit comments

Comments
 (0)