Commit 3e03ddf

Fix the logic to calculate the number of workers based on the TPU version. (#51227)
The calculation of the number of workers was incorrect: it did not take the correct number of cores per chip into account.

Signed-off-by: Quinn <qinyiyan@google.com>
1 parent 1d1b1b0 commit 3e03ddf

File tree

3 files changed: +66, -41 lines changed


python/ray/_private/accelerators/tpu.py

Lines changed: 48 additions & 3 deletions
@@ -43,6 +43,22 @@
 TPU_HOST_BOUNDS_ENV_VAR = "TPU_HOST_BOUNDS"
 TPU_SINGLE_HOST_BOUNDS = "1,1,1"
 
+# By default TPU VMs come with 4 chips per host and 2 tensorcores per chip.
+# For more details: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm
+DEFAULT_TPU_NUM_CHIPS_PER_HOST = 4
+DEFAULT_TPU_NUM_CORES_PER_CHIP = 2
+
+# Accelerators that are 4 chips per host: v2, v3, v4, v5p
+# Accelerators that are 8 chips per host: v5e, v6e
+SINGLE_HOST_8_CHIPS_TPU_TYPES = ("v5litepod", "v6e")
+
+# Accelerators that are 2 cores per chip: v2, v3, v4, v5p
+# Accelerators that are 1 core per chip: v5e, v6e
+SINGLE_CORE_TPU_TYPES = ("v5litepod", "v6e")
+
+# The valid TPU types.
+VALID_TPU_TYPES = ("v2", "v3", "v4", "v5p", "v5litepod", "v6e")
+
 
 def _get_tpu_metadata(key: str) -> Optional[str]:
     """Poll and get TPU metadata."""
@@ -67,6 +83,29 @@ def _get_tpu_metadata(key: str) -> Optional[str]:
     return None
 
 
+def _accelerator_type_check(accelerator_type: str):
+    if not accelerator_type.startswith(VALID_TPU_TYPES):
+        raise ValueError(
+            f"Invalid accelerator type: {accelerator_type}. Must start with one of: {VALID_TPU_TYPES}"
+        )
+
+
+def get_num_tpu_visible_chips_per_host(accelerator_type: str) -> int:
+    _accelerator_type_check(accelerator_type)
+    if accelerator_type.startswith(SINGLE_HOST_8_CHIPS_TPU_TYPES):
+        return 8
+
+    return DEFAULT_TPU_NUM_CHIPS_PER_HOST
+
+
+def get_tpu_cores_per_chip(accelerator_type: str) -> int:
+    _accelerator_type_check(accelerator_type)
+    if accelerator_type.startswith(SINGLE_CORE_TPU_TYPES):
+        return 1
+
+    return DEFAULT_TPU_NUM_CORES_PER_CHIP
+
+
 class TPUAcceleratorManager(AcceleratorManager):
     """Google TPU accelerators."""
 
@@ -273,10 +312,16 @@ def _get_current_node_tpu_worker_id() -> Optional[int]:
     def get_num_workers_in_current_tpu_pod() -> Optional[int]:
         """Return the total number of workers in a TPU pod."""
         tpu_pod_type = TPUAcceleratorManager._get_current_node_tpu_pod_type()
-        cores_per_host = TPUAcceleratorManager.get_current_node_num_accelerators()
+        chips_per_host = TPUAcceleratorManager.get_current_node_num_accelerators()
+        cores_per_chip = get_tpu_cores_per_chip(tpu_pod_type)  # Hard-coded map.
+        cores_per_host = chips_per_host * cores_per_chip
         if tpu_pod_type and cores_per_host > 0:
-            num_chips_or_cores = int(tpu_pod_type.split("-")[1])
-            return num_chips_or_cores // cores_per_host
+            num_cores = int(tpu_pod_type.split("-")[1])
+            num_workers = num_cores // cores_per_host
+            # If the chip count doesn't fill a full host, a sub-host is still treated as a host.
+            if num_cores % cores_per_host != 0:
+                num_workers += 1
+            return num_workers
         else:
            logging.debug("Could not get num workers in TPU pod.")
            return None
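
For reference, the corrected math boils down to a ceiling division of the pod's core count by the cores available per host. The sketch below is a standalone illustration, not the Ray API itself: the per-generation tables mirror the constants added in tpu.py above, while the function name is made up for the example.

import math

# Per-generation host geometry, mirroring the constants added in tpu.py.
CHIPS_PER_HOST = {"v2": 4, "v3": 4, "v4": 4, "v5p": 4, "v5litepod": 8, "v6e": 8}
CORES_PER_CHIP = {"v2": 2, "v3": 2, "v4": 2, "v5p": 2, "v5litepod": 1, "v6e": 1}


def num_workers_in_pod(pod_type: str) -> int:
    """Illustrative only: pod_type looks like 'v4-16', where the suffix is a core count."""
    generation, num_cores = pod_type.split("-")[0], int(pod_type.split("-")[1])
    cores_per_host = CHIPS_PER_HOST[generation] * CORES_PER_CHIP[generation]
    # A partially filled host still runs one worker, hence the ceiling division.
    return math.ceil(num_cores / cores_per_host)


assert num_workers_in_pod("v3-32") == 4   # 32 cores / (4 chips * 2 cores) = 4 hosts
assert num_workers_in_pod("v6e-16") == 2  # 16 cores / (8 chips * 1 core) = 2 hosts
assert num_workers_in_pod("v2-4") == 1    # a sub-host slice still counts as one worker

The old code divided the core count by the per-host accelerator (chip) count directly, which over-counted workers on generations with 2 cores per chip; the updated test below reflects this (v4-16 previously expected 4 workers, now 2).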

python/ray/autoscaler/_private/gcp/config.py

Lines changed: 4 additions & 37 deletions
@@ -17,6 +17,7 @@
 from googleapiclient import discovery, errors
 
 from ray._private.accelerators import TPUAcceleratorManager
+from ray._private.accelerators import tpu
 from ray.autoscaler._private.gcp.node import MAX_POLLS, POLL_INTERVAL, GCPNodeType
 from ray.autoscaler._private.util import check_legacy_fields
 
@@ -51,11 +52,6 @@
 # NOTE: iam.serviceAccountUser allows the Head Node to create worker nodes
 # with ServiceAccounts.
 
-# By default TPU VMs come with 4 chips per host and 2 tensorcores per chip.
-# For more details: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm
-DEFAULT_TPU_NUM_CHIPS_PER_HOST = 4
-DEFAULT_TPU_CORES_PER_CHIP = 2
-
 
 def tpu_accelerator_config_to_type(accelerator_config: dict) -> str:
     """Convert a provided accelerator_config to accelerator_type.
@@ -75,17 +71,14 @@ def tpu_accelerator_config_to_type(accelerator_config: dict) -> str:
     # Reduce e.g. "2x2x2" to 8
     chip_dimensions = [int(chip_count) for chip_count in topology.split("x")]
     num_chips = reduce(lambda x, y: x * y, chip_dimensions)
-    num_cores = num_chips * DEFAULT_TPU_CORES_PER_CHIP
 
     # V5LitePod is rendered as "V5LITE_POD" in accelerator configuration but
     # accelerator type uses a format like "v5litepod-{cores}", so we need
     # to manually convert the string here.
     if generation == "v5lite_pod":
         generation = "v5litepod"
-        num_cores = num_chips
 
-    if generation == "v6e":
-        num_cores = num_chips
+    num_cores = tpu.get_tpu_cores_per_chip(generation) * num_chips
 
     return f"{generation}-{num_cores}"
 
@@ -136,39 +129,13 @@ def _validate_tpu_config(node: dict):
         )
 
 
-def _get_num_tpu_visible_chips_per_host(accelerator_type: str) -> int:
-    if accelerator_type == "v5litepod-8":
-        return 8
-
-    # All V6e configurations have 8 chips per host
-    if accelerator_type.startswith("v6e"):
-        return 8
-
-    return DEFAULT_TPU_NUM_CHIPS_PER_HOST
-
-
-def _get_tpu_cores_per_chip(accelerator_type: str) -> int:
-    # accelerator_type is in the form v{generateion}-{cores}
-    accelerator_type = accelerator_type.split("-")[0]
-
-    # V5Litepods have 1 core per chip
-    if accelerator_type == "v5litepod":
-        return 1
-
-    # V6es have 1 core per chip
-    if accelerator_type == "v6e":
-        return 1
-
-    return DEFAULT_TPU_CORES_PER_CHIP
-
-
 def _get_num_tpu_chips(node: dict) -> int:
     chips = 0
     if "acceleratorType" in node:
         accelerator_type = node["acceleratorType"]
         # `acceleratorType` is typically v{generation}-{cores}
         cores = int(accelerator_type.split("-")[1])
-        chips = cores / _get_tpu_cores_per_chip(accelerator_type)
+        chips = cores / tpu.get_tpu_cores_per_chip(accelerator_type)
     if "acceleratorConfig" in node:
         topology = node["acceleratorConfig"]["topology"]
         # `topology` is typically {chips}x{chips}x{chips}
@@ -185,7 +152,7 @@ def _is_single_host_tpu(node: dict) -> bool:
         accelerator_type = node["acceleratorType"]
     else:
         accelerator_type = tpu_accelerator_config_to_type(node["acceleratorConfig"])
-    return _get_num_tpu_chips(node) == _get_num_tpu_visible_chips_per_host(
+    return _get_num_tpu_chips(node) <= tpu.get_num_tpu_visible_chips_per_host(
        accelerator_type
    )
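
To see how the de-duplicated helpers feed the autoscaler path, here is a rough sketch of what tpu_accelerator_config_to_type now computes, under the same assumptions as the diff; the cores-per-chip table, the function name, and the example topologies are illustrative rather than the module's actual layout.

from functools import reduce

# 1 core per chip for v5litepod/v6e, 2 for the older generations (see tpu.py above).
CORES_PER_CHIP = {"v2": 2, "v3": 2, "v4": 2, "v5p": 2, "v5litepod": 1, "v6e": 1}


def accelerator_config_to_type(generation: str, topology: str) -> str:
    # Reduce e.g. "2x2x2" to 8 chips.
    num_chips = reduce(lambda x, y: x * y, (int(d) for d in topology.split("x")))
    # The accelerator-type suffix counts cores, so scale chips by cores per chip.
    num_cores = CORES_PER_CHIP[generation] * num_chips
    return f"{generation}-{num_cores}"


print(accelerator_config_to_type("v4", "2x2x2"))  # v4-16: 8 chips * 2 cores/chip
print(accelerator_config_to_type("v6e", "2x4"))   # v6e-8: 8 chips * 1 core/chip

The single-host check then compares the node's chip count against get_num_tpu_visible_chips_per_host; switching the comparison from == to <= lets slices smaller than a full host still be classified as single-host.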

python/ray/tests/accelerators/test_tpu.py

Lines changed: 14 additions & 1 deletion
@@ -282,8 +282,21 @@ def test_empty_get_current_pod_name_returns_none():
 @pytest.mark.parametrize(
     "test_case",
     [
-        (4, "v4-16", 4),
+        # (number_chips_per_host, accl_type, expected_worker_count)
+        (4, "v2-4", 1),
+        (4, "v3-32", 4),
+        (4, "v4-8", 1),
+        (4, "v4-16", 2),
+        (8, "v5litepod-4", 1),
+        (8, "v5litepod-8", 1),
+        (8, "v5litepod-16", 2),
+        (8, "v5litepod-32", 4),
+        (4, "v5p-4", 1),
+        (4, "v5p-8", 1),
+        (4, "v5p-16", 2),
+        (8, "v6e-4", 1),
         (8, "v6e-8", 1),
+        (8, "v6e-16", 2),
     ],
 )
 @patch("glob.glob")
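
As a quick sanity check on two of the new expected values (plain arithmetic, mirroring the formula sketched earlier):

import math

# v4-16: 16 cores over 4 chips/host * 2 cores/chip = 8 cores/host -> 2 workers.
assert math.ceil(16 / (4 * 2)) == 2
# v5litepod-8: 8 cores over 8 chips/host * 1 core/chip -> exactly 1 full host.
assert math.ceil(8 / (8 * 1)) == 1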
