14
14
15
15
from __future__ import annotations
16
16
17
+ import logging
17
18
import os
18
19
import re
19
20
import socket
20
21
import time
21
22
from jax ._src import clusters
22
23
from jax ._src .cloud_tpu_init import running_in_cloud_tpu_vm
23
24
25
+ logger = logging .getLogger (__name__ )
26
+
24
27
# We use an arbitrarily chosen port for the coordinator since we cannot
25
28
# rely on communication to choose one in real time.
26
29
coordinator_port = '8476'
27
30
31
+ metadata_response_code_success = 200
32
+
28
33
def get_metadata (key ):
29
34
import requests # pytype: disable=import-error
30
35
import time # pytype: disable=import-error
@@ -47,11 +52,11 @@ def get_metadata(key):
47
52
48
53
if api_resp is None :
49
54
raise RuntimeError (f"Getting metadata['{ key } '] failed for 6 tries" )
50
- return api_resp .text
55
+ return api_resp .text , api_resp . status_code
51
56
52
57
def get_tpu_env_value (key ):
53
58
def get_tpu_env_value_from_metadata (key ):
54
- tpu_env_data = get_metadata ('tpu-env' )
59
+ tpu_env_data = get_metadata ('tpu-env' )[ 0 ]
55
60
key_value_pairs = tpu_env_data .split ('\n ' )
56
61
for key_value_pair in key_value_pairs :
57
62
# Typical line is MEGASCALE_NUM_SLICES: '2'
@@ -65,54 +70,44 @@ def get_tpu_env_value_from_metadata(key):
65
70
value = os .environ .get (key , None )
66
71
return value if value is not None else get_tpu_env_value_from_metadata (key )
67
72
68
- def is_gce_env ():
69
- worker_number_string = get_metadata ('agent-worker-number' )
70
- try :
71
- worker_number = int (worker_number_string )
72
- return True
73
- except :
74
- return False
75
-
76
- def is_multislice_gce_env ():
77
- return is_gce_env () and get_tpu_env_value ('MEGASCALE_COORDINATOR_ADDRESS' ) is not None
78
-
79
- def is_gke_env ():
80
- return os .environ .get ("TPU_WORKER_HOSTNAMES" , None ) is not None
73
+ def has_megascale_address ():
74
+ return get_tpu_env_value ('MEGASCALE_COORDINATOR_ADDRESS' ) is not None
81
75
82
- def get_gce_worker_endpoints () -> str :
83
- return get_metadata ( 'worker-network-endpoints' ). split ( ',' )
76
+ class BaseTpuCluster ( clusters . ClusterEnv ) :
77
+ """Abstract cluster supports both single and multislice TPU environments.
84
78
85
- class SingleSliceGceTpuCluster (clusters .ClusterEnv ):
86
- @classmethod
87
- def is_env_present (cls ) -> bool :
88
- return running_in_cloud_tpu_vm and is_gce_env () and not is_multislice_gce_env ()
89
-
90
- @classmethod
91
- def get_coordinator_address (cls ) -> str :
92
- return f"{ get_gce_worker_endpoints ()[0 ].split (':' )[2 ]} :{ coordinator_port } "
93
-
94
- @classmethod
95
- def get_process_count (cls ) -> int :
96
- return len (get_gce_worker_endpoints ())
97
-
98
- @classmethod
99
- def get_process_id (cls ) -> int :
100
- return int (get_metadata ('agent-worker-number' ))
101
-
102
- @classmethod
103
- def get_local_process_id (cls ) -> int | None :
104
- return None
79
+ If MEGASCALE_COORDINATOR_ADDRESS is not set, we assume single slice topology.
80
+ Concrete extensions of this class must implement methods for generating a list
81
+ of within-slice workers and a within-slice process ID.
82
+ `get_coordinator_address` must return the address of the host with
83
+ process ID 0 (as returned by `get_process_id`), since the coordinator service
84
+ is started on the host with process ID = 0.
85
+ """
105
86
106
- class MultisliceGceTpuCluster (clusters .ClusterEnv ):
107
87
@classmethod
108
88
def is_env_present (cls ) -> bool :
109
- return running_in_cloud_tpu_vm and is_multislice_gce_env ()
89
+ """Override this method to return True if the environment is present."""
90
+ return False
110
91
111
92
@classmethod
112
93
def get_coordinator_address (cls ) -> str :
113
- coordinator_address = get_tpu_env_value ('MEGASCALE_COORDINATOR_ADDRESS' )
94
+ if has_megascale_address ():
95
+ # For both GCE via QueuedResources and GKE via JobSet, the
96
+ # Megascale coordinator address is set as the host with process id = 0,
97
+ # so can be used as the jax distributed system coordinator.
98
+ coordinator_address = get_tpu_env_value ('MEGASCALE_COORDINATOR_ADDRESS' )
99
+ else :
100
+ # For both GCE (QueuedResources and TPUVM create) and GKE via Job API,
101
+ # the workers lists are sorted by process ID so the first one can
102
+ # be used as the jax distributed system coordinator.
103
+ coordinator_address = cls ._get_worker_list_in_slice ()[0 ]
114
104
coordinator_address = coordinator_address .split (':' )[0 ]
105
+ logger .debug ("TPU Cluster using coordinator address: %s" , coordinator_address )
106
+ cls .wait_for_coordinator (coordinator_address )
107
+ return f'{ coordinator_address } :{ coordinator_port } '
115
108
109
+ @classmethod
110
+ def wait_for_coordinator (cls , coordinator_address ):
116
111
# The coordinator may not be up before the other hosts try to
117
112
# communicate with it. We check for its existence with retries.
118
113
coordinator_found = False
@@ -126,51 +121,92 @@ def get_coordinator_address(cls) -> str:
126
121
print (f"Failed to recognize coordinator address { coordinator_address } on attempt { lookup_attempt } , retrying..." )
127
122
lookup_attempt += 1
128
123
time .sleep (5 )
129
-
130
124
if not coordinator_found :
131
125
raise RuntimeError (f"Failed to recognize coordinator address { coordinator_address } " )
132
126
133
- # Use a different port for the jax coordinator than the MXLA coordinator,
134
- # which is set to 8080 in multislice GCE.
135
- return f'{ coordinator_address } :{ coordinator_port } '
136
-
137
127
@classmethod
138
128
def get_process_count (cls ) -> int :
139
- processes_per_slice = cls ._get_process_count_per_slice ()
140
- num_slices = int (get_tpu_env_value ('MEGASCALE_NUM_SLICES' ))
141
- return processes_per_slice * num_slices
129
+ processes_per_slice = len (cls ._get_worker_list_in_slice ())
130
+ num_slices = cls ._get_num_slices ()
131
+ total_process_count = processes_per_slice * num_slices
132
+ logger .debug ("Total process count of %s = %s processes per slice and %s slices" , total_process_count , processes_per_slice , num_slices )
133
+ return total_process_count
142
134
143
135
@classmethod
144
136
def get_process_id (cls ) -> int :
145
137
process_id_in_slice = cls ._get_process_id_in_slice ()
146
- slice_id = int (get_tpu_env_value ('MEGASCALE_SLICE_ID' ))
147
- processes_per_slice = cls ._get_process_count_per_slice ()
148
- return process_id_in_slice + slice_id * processes_per_slice
138
+ slice_id = cls ._get_slice_id ()
139
+ processes_per_slice = len (cls ._get_worker_list_in_slice ())
140
+ process_id = process_id_in_slice + slice_id * processes_per_slice
141
+ logger .debug ("Process ID of %s generated by within-slice id %s and slice id %s" , process_id , process_id_in_slice , slice_id )
142
+ return process_id
149
143
150
- @classmethod
151
- def get_local_process_id (cls ) -> int | None :
152
- return None
144
+ @staticmethod
145
+ def _get_num_slices () -> int :
146
+ if has_megascale_address ():
147
+ return int (get_tpu_env_value ('MEGASCALE_NUM_SLICES' ))
148
+ else :
149
+ return 1
153
150
154
151
@staticmethod
155
- def _get_process_count_per_slice () -> int :
156
- return len (get_gce_worker_endpoints ())
152
+ def _get_slice_id () -> int :
153
+ if has_megascale_address ():
154
+ return int (get_tpu_env_value ('MEGASCALE_SLICE_ID' ))
155
+ else :
156
+ return 0
157
157
158
158
@staticmethod
159
159
def _get_process_id_in_slice () -> int :
160
- return int (get_metadata ('agent-worker-number' ))
160
+ """Returns a process ID that is unique within slice."""
161
+ raise NotImplementedError ()
161
162
162
- class GkeTpuCluster (MultisliceGceTpuCluster ):
163
- # This class handles both single and multislice GKE as the environment
164
- # variables are set the same in both cases.
163
+ @staticmethod
164
+ def _get_worker_list_in_slice () -> list [str ]:
165
+ """Returns a list of worker endpoints/hostnames within slice."""
166
+ raise NotImplementedError ()
167
+
168
+ class GceTpuCluster (BaseTpuCluster ):
165
169
@classmethod
166
170
def is_env_present (cls ) -> bool :
167
- return running_in_cloud_tpu_vm and is_gke_env ()
171
+ if not running_in_cloud_tpu_vm :
172
+ logger .debug ("Did not detect cloud TPU VM" )
173
+ return False
174
+ metadata_response , metadata_code = get_metadata ('agent-worker-number' )
175
+ if metadata_code == metadata_response_code_success :
176
+ logger .debug ("Gce Tpu Cluster detected for Jax Distributed System" )
177
+ return True
178
+ else :
179
+ logger .debug ("Did not detect Gce Tpu Cluster since agent-worker-number is not set in metadata" )
180
+ logger .debug ("Metadata code: %s" , metadata_code )
181
+ logger .debug ("Metadata response: %s" , metadata_response )
182
+ return False
168
183
169
184
@staticmethod
170
- def _get_process_count_per_slice () -> int :
171
- tpu_worker_hostnames = str (os .environ .get ('TPU_WORKER_HOSTNAMES' , None ))
172
- return len (tpu_worker_hostnames .split (',' ))
185
+ def _get_process_id_in_slice () -> int :
186
+ return int (get_metadata ('agent-worker-number' )[0 ])
187
+
188
+ @staticmethod
189
+ def _get_worker_list_in_slice () -> list [str ]:
190
+ workers = get_metadata ('worker-network-endpoints' )[0 ].split (',' )
191
+ return [worker .split (':' )[2 ] for worker in workers ]
192
+
193
+ class GkeTpuCluster (BaseTpuCluster ):
194
+ @classmethod
195
+ def is_env_present (cls ) -> bool :
196
+ if running_in_cloud_tpu_vm and os .environ .get ("TPU_WORKER_HOSTNAMES" ) is not None :
197
+ logger .debug ("Gke Tpu Cluster detected for Jax Distributed System" )
198
+ return True
199
+ else :
200
+ if not running_in_cloud_tpu_vm :
201
+ logger .debug ("Did not detect cloud TPU VM" )
202
+ else :
203
+ logger .debug ("Did not detect TPU GKE cluster since TPU_WORKER_HOSTNAMES is not set" )
204
+ return False
173
205
174
206
@staticmethod
175
207
def _get_process_id_in_slice () -> int :
176
208
return int (str (os .environ .get ('TPU_WORKER_ID' )))
209
+
210
+ @staticmethod
211
+ def _get_worker_list_in_slice () -> list [str ]:
212
+ return str (os .environ .get ('TPU_WORKER_HOSTNAMES' , None )).split (',' )
0 commit comments