Skip to content

Commit a53e99a

Browse files
author
jax authors
committed
* Support automating single-slice GKE TPU clusters via the Job API.
* Refactor the GCE and GKE clusters to inherit from a shared base, since they supply the same information.
* This shared base supports both multislice and single-slice topologies.

PiperOrigin-RevId: 616440923
1 parent 8f4658e commit a53e99a

File tree

2 files changed

+101
-66
lines changed

2 files changed

+101
-66
lines changed

jax/_src/clusters/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,4 @@
2323
from .ompi_cluster import OmpiCluster
2424
from .slurm_cluster import SlurmCluster
2525
from .cloud_tpu_cluster import GkeTpuCluster
26-
from .cloud_tpu_cluster import MultisliceGceTpuCluster
27-
from .cloud_tpu_cluster import SingleSliceGceTpuCluster
26+
from .cloud_tpu_cluster import GceTpuCluster

jax/_src/clusters/cloud_tpu_cluster.py

Lines changed: 100 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,22 @@
1414

1515
from __future__ import annotations
1616

17+
import logging
1718
import os
1819
import re
1920
import socket
2021
import time
2122
from jax._src import clusters
2223
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
2324

25+
logger = logging.getLogger(__name__)
26+
2427
# We use an arbitrarily chosen port for the coordinator since we cannot
2528
# rely on communication to choose one in real time.
2629
coordinator_port = '8476'
2730

31+
metadata_response_code_success = 200
32+
2833
def get_metadata(key):
2934
import requests # pytype: disable=import-error
3035
import time # pytype: disable=import-error
@@ -47,11 +52,11 @@ def get_metadata(key):
4752

4853
if api_resp is None:
4954
raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
50-
return api_resp.text
55+
return api_resp.text, api_resp.status_code
5156

5257
def get_tpu_env_value(key):
5358
def get_tpu_env_value_from_metadata(key):
54-
tpu_env_data = get_metadata('tpu-env')
59+
tpu_env_data = get_metadata('tpu-env')[0]
5560
key_value_pairs = tpu_env_data.split('\n')
5661
for key_value_pair in key_value_pairs:
5762
# Typical line is MEGASCALE_NUM_SLICES: '2'
@@ -65,54 +70,44 @@ def get_tpu_env_value_from_metadata(key):
6570
value = os.environ.get(key, None)
6671
return value if value is not None else get_tpu_env_value_from_metadata(key)
6772

68-
def is_gce_env():
69-
worker_number_string = get_metadata('agent-worker-number')
70-
try:
71-
worker_number = int(worker_number_string)
72-
return True
73-
except:
74-
return False
75-
76-
def is_multislice_gce_env():
77-
return is_gce_env() and get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None
78-
79-
def is_gke_env():
80-
return os.environ.get("TPU_WORKER_HOSTNAMES", None) is not None
73+
def has_megascale_address():
  """Report whether a Megascale coordinator address is configured.

  The value of MEGASCALE_COORDINATOR_ADDRESS is looked up through
  `get_tpu_env_value`; any result other than None counts as configured.

  Returns:
    True if a Megascale coordinator address was found, False otherwise.
  """
  megascale_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS')
  return megascale_address is not None
8175

82-
def get_gce_worker_endpoints() -> str:
83-
return get_metadata('worker-network-endpoints').split(',')
76+
class BaseTpuCluster(clusters.ClusterEnv):
77+
"""Abstract cluster supports both single and multislice TPU environments.
8478
85-
class SingleSliceGceTpuCluster(clusters.ClusterEnv):
86-
@classmethod
87-
def is_env_present(cls) -> bool:
88-
return running_in_cloud_tpu_vm and is_gce_env() and not is_multislice_gce_env()
89-
90-
@classmethod
91-
def get_coordinator_address(cls) -> str:
92-
return f"{get_gce_worker_endpoints()[0].split(':')[2]}:{coordinator_port}"
93-
94-
@classmethod
95-
def get_process_count(cls) -> int:
96-
return len(get_gce_worker_endpoints())
97-
98-
@classmethod
99-
def get_process_id(cls) -> int:
100-
return int(get_metadata('agent-worker-number'))
101-
102-
@classmethod
103-
def get_local_process_id(cls) -> int | None:
104-
return None
79+
If MEGASCALE_COORDINATOR_ADDRESS is not set, we assume single slice topology.
80+
Concrete extensions of this class must implement methods for generating a list
81+
of within-slice workers and a within-slice process ID.
82+
`get_coordinator_address` must return the address of the host with
83+
process ID 0 (as returned by `get_process_id`), since the coordinator service
84+
is started on the host with process ID = 0.
85+
"""
10586

106-
class MultisliceGceTpuCluster(clusters.ClusterEnv):
10787
@classmethod
10888
def is_env_present(cls) -> bool:
109-
return running_in_cloud_tpu_vm and is_multislice_gce_env()
89+
"""Override this method to return True if the environment is present."""
90+
return False
11091

11192
@classmethod
11293
def get_coordinator_address(cls) -> str:
113-
coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS')
94+
if has_megascale_address():
95+
# For both GCE via QueuedResources and GKE via JobSet, the
96+
# Megascale coordinator address is set as the host with process id = 0,
97+
# so can be used as the jax distributed system coordinator.
98+
coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS')
99+
else:
100+
# For both GCE (QueuedResources and TPUVM create) and GKE via Job API,
101+
# the workers lists are sorted by process ID so the first one can
102+
# be used as the jax distributed system coordinator.
103+
coordinator_address = cls._get_worker_list_in_slice()[0]
114104
coordinator_address = coordinator_address.split(':')[0]
105+
logger.debug("TPU Cluster using coordinator address: %s", coordinator_address)
106+
cls.wait_for_coordinator(coordinator_address)
107+
return f'{coordinator_address}:{coordinator_port}'
115108

109+
@classmethod
110+
def wait_for_coordinator(cls, coordinator_address):
116111
# The coordinator may not be up before the other hosts try to
117112
# communicate with it. We check for its existence with retries.
118113
coordinator_found = False
@@ -126,51 +121,92 @@ def get_coordinator_address(cls) -> str:
126121
print(f"Failed to recognize coordinator address {coordinator_address} on attempt {lookup_attempt}, retrying...")
127122
lookup_attempt += 1
128123
time.sleep(5)
129-
130124
if not coordinator_found:
131125
raise RuntimeError(f"Failed to recognize coordinator address {coordinator_address}")
132126

133-
# Use a different port for the jax coordinator than the MXLA coordinator,
134-
# which is set to 8080 in multislice GCE.
135-
return f'{coordinator_address}:{coordinator_port}'
136-
137127
@classmethod
138128
def get_process_count(cls) -> int:
139-
processes_per_slice = cls._get_process_count_per_slice()
140-
num_slices = int(get_tpu_env_value('MEGASCALE_NUM_SLICES'))
141-
return processes_per_slice * num_slices
129+
processes_per_slice = len(cls._get_worker_list_in_slice())
130+
num_slices = cls._get_num_slices()
131+
total_process_count = processes_per_slice * num_slices
132+
logger.debug("Total process count of %s = %s processes per slice and %s slices", total_process_count, processes_per_slice, num_slices)
133+
return total_process_count
142134

143135
@classmethod
144136
def get_process_id(cls) -> int:
145137
process_id_in_slice = cls._get_process_id_in_slice()
146-
slice_id = int(get_tpu_env_value('MEGASCALE_SLICE_ID'))
147-
processes_per_slice = cls._get_process_count_per_slice()
148-
return process_id_in_slice + slice_id * processes_per_slice
138+
slice_id = cls._get_slice_id()
139+
processes_per_slice = len(cls._get_worker_list_in_slice())
140+
process_id = process_id_in_slice + slice_id * processes_per_slice
141+
logger.debug("Process ID of %s generated by within-slice id %s and slice id %s", process_id, process_id_in_slice, slice_id)
142+
return process_id
149143

150-
@classmethod
151-
def get_local_process_id(cls) -> int | None:
152-
return None
144+
@staticmethod
def _get_num_slices() -> int:
  """Return the number of slices that make up the cluster.

  When no Megascale coordinator address is configured the topology is treated
  as a single slice; otherwise the count is read from MEGASCALE_NUM_SLICES.
  """
  if not has_megascale_address():
    return 1
  num_slices = get_tpu_env_value('MEGASCALE_NUM_SLICES')
  return int(num_slices)
153150

154151
@staticmethod
155-
def _get_process_count_per_slice() -> int:
156-
return len(get_gce_worker_endpoints())
152+
def _get_slice_id() -> int:
153+
if has_megascale_address():
154+
return int(get_tpu_env_value('MEGASCALE_SLICE_ID'))
155+
else:
156+
return 0
157157

158158
@staticmethod
159159
def _get_process_id_in_slice() -> int:
160-
return int(get_metadata('agent-worker-number'))
160+
"""Returns a process ID that is unique within slice."""
161+
raise NotImplementedError()
161162

162-
class GkeTpuCluster(MultisliceGceTpuCluster):
163-
# This class handles both single and multislice GKE as the environment
164-
# variables are set the same in both cases.
163+
@staticmethod
164+
def _get_worker_list_in_slice() -> list[str]:
165+
"""Returns a list of worker endpoints/hostnames within slice."""
166+
raise NotImplementedError()
167+
168+
class GceTpuCluster(BaseTpuCluster):
  """TPU cluster whose within-slice layout is read via `get_metadata`."""

  @classmethod
  def is_env_present(cls) -> bool:
    """Return True when running in a cloud TPU VM with a worker number set.

    Detection requires `running_in_cloud_tpu_vm` to be truthy and the
    'agent-worker-number' metadata lookup to come back with the success
    status code.
    """
    if not running_in_cloud_tpu_vm:
      logger.debug("Did not detect cloud TPU VM")
      return False

    body, status = get_metadata('agent-worker-number')
    if status != metadata_response_code_success:
      logger.debug("Did not detect Gce Tpu Cluster since agent-worker-number is not set in metadata")
      logger.debug("Metadata code: %s", status)
      logger.debug("Metadata response: %s", body)
      return False

    logger.debug("Gce Tpu Cluster detected for Jax Distributed System")
    return True

  @staticmethod
  def _get_process_id_in_slice() -> int:
    """Return the 'agent-worker-number' metadata value as an int."""
    worker_number, _ = get_metadata('agent-worker-number')
    return int(worker_number)

  @staticmethod
  def _get_worker_list_in_slice() -> list[str]:
    """Return one address per worker listed in 'worker-network-endpoints'.

    The metadata value is a comma-separated list of colon-separated records;
    the third field of each record is kept (presumably the worker's IP
    address -- confirm against the metadata format).
    """
    endpoints, _ = get_metadata('worker-network-endpoints')
    return [endpoint.split(':')[2] for endpoint in endpoints.split(',')]
192+
193+
class GkeTpuCluster(BaseTpuCluster):
  """TPU cluster whose within-slice layout comes from TPU_WORKER_* env vars."""

  @classmethod
  def is_env_present(cls) -> bool:
    """Return True when in a cloud TPU VM with TPU_WORKER_HOSTNAMES set."""
    if not running_in_cloud_tpu_vm:
      logger.debug("Did not detect cloud TPU VM")
      return False
    if os.environ.get("TPU_WORKER_HOSTNAMES") is None:
      logger.debug("Did not detect TPU GKE cluster since TPU_WORKER_HOSTNAMES is not set")
      return False
    logger.debug("Gke Tpu Cluster detected for Jax Distributed System")
    return True

  @staticmethod
  def _get_process_id_in_slice() -> int:
    """Return TPU_WORKER_ID as an int.

    Raises:
      ValueError: if TPU_WORKER_ID is unset or is not an integer literal.
    """
    worker_id = os.environ.get('TPU_WORKER_ID')
    if worker_id is None:
      # Previously this surfaced as "invalid literal for int(): 'None'";
      # keep the exception type but name the actual problem.
      raise ValueError(
          "TPU_WORKER_ID is not set; cannot determine the within-slice "
          "process ID")
    return int(worker_id)

  @staticmethod
  def _get_worker_list_in_slice() -> list[str]:
    """Return the hostnames listed in TPU_WORKER_HOSTNAMES.

    Raises:
      RuntimeError: if TPU_WORKER_HOSTNAMES is unset or empty.
    """
    hostnames = os.environ.get('TPU_WORKER_HOSTNAMES')
    if not hostnames:
      # Stringifying a missing value used to yield the bogus list ['None'],
      # i.e. one phantom worker; fail loudly instead.
      raise RuntimeError(
          "TPU_WORKER_HOSTNAMES is not set or empty; cannot determine the "
          "workers in this slice")
    return hostnames.split(',')

0 commit comments

Comments
 (0)