Skip to content

Commit f7199c2

Browse files
Created TPU Topology Sample (#12719)
1 parent 64cf1d5 commit f7199c2

File tree

2 files changed

+91
-1
lines changed

2 files changed

+91
-1
lines changed

tpu/create_tpu_topology.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from google.cloud.tpu_v2 import Node
18+
19+
20+
def create_cloud_tpu_with_topology(
    project_id: str,
    zone: str,
    tpu_name: str,
    runtime_version: str = "tpu-vm-tf-2.17.0-pjrt",
) -> Node:
    """Create a Cloud TPU node described by an accelerator type and topology.

    The node is configured with an ``AcceleratorConfig`` (type V3, topology
    "2x2") rather than a named accelerator type. The function waits on the
    create operation's result and prints the resulting accelerator config.

    Args:
        project_id (str): ID of the Google Cloud project that owns the node.
        zone (str): Zone in which the node is created.
        tpu_name (str): ID given to the new node.
        runtime_version (str, optional): Runtime version set on the node.

    Returns:
        Node: The node returned by the completed create operation.
    """
    # [START tpu_vm_create_topology]
    from google.cloud import tpu_v2

    # TODO(developer): Update and un-comment below lines
    # project_id = "your-project-id"
    # zone = "us-central1-b"
    # tpu_name = "tpu-name"
    # runtime_version = "tpu-vm-tf-2.17.0-pjrt"

    # Describe a TPU v3-8 through its chip type and a 2x2 topology.
    tpu_node = tpu_v2.Node(
        accelerator_config=tpu_v2.AcceleratorConfig(
            type_=tpu_v2.AcceleratorConfig.Type.V3,
            topology="2x2",
        ),
        runtime_version=runtime_version,
    )

    create_request = tpu_v2.CreateNodeRequest(
        parent=f"projects/{project_id}/locations/{zone}",
        node_id=tpu_name,
        node=tpu_node,
    )

    tpu_client = tpu_v2.TpuClient()
    pending_operation = tpu_client.create_node(request=create_request)
    print("Waiting for operation to complete...")

    # result() returns only once the long-running create operation is done.
    created_node = pending_operation.result()
    print(created_node.accelerator_config)
    # Example response:
    # type_: V3
    # topology: "2x2"

    # [END tpu_vm_create_topology]
    return created_node
70+
71+
72+
if __name__ == "__main__":
    # The project is read from the environment; zone and node name are
    # fixed sample values.
    create_cloud_tpu_with_topology(
        project_id=os.getenv("GOOGLE_CLOUD_PROJECT"),
        zone="us-central1-a",
        tpu_name="tpu-name",
    )

tpu/test_tpu.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,20 @@
1414
import os
1515
import uuid
1616

17-
from google.cloud.tpu_v2.types import Node
17+
from google.cloud.tpu_v2.types import AcceleratorConfig, Node
1818

1919
import pytest
2020

2121
import create_tpu
22+
import create_tpu_topology
2223
import create_tpu_with_script
2324
import delete_tpu
2425
import get_tpu
2526
import list_tpu
2627
import start_tpu
2728
import stop_tpu
2829

30+
2931
TPU_NAME = "test-tpu-" + uuid.uuid4().hex[:10]
3032
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
3133
ZONE = "us-south1-a"
@@ -78,3 +80,16 @@ def test_stop_tpu() -> None:
7880
def test_start_tpu() -> None:
    """Start the TPU named by TPU_NAME and check that it reports READY."""
    started_node = start_tpu.start_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME)
    assert started_node.state == Node.State.READY
83+
84+
85+
def test_with_topology() -> None:
    """Create a TPU with an explicit topology, verify its config, then delete it."""
    zone = "us-central1-a"
    name = f"topology-tpu-{uuid.uuid4().hex[:5]}"
    try:
        created = create_tpu_topology.create_cloud_tpu_with_topology(
            PROJECT_ID, zone, name, TPU_VERSION
        )
        accelerator = created.accelerator_config
        assert accelerator.type_ == AcceleratorConfig.Type.V3
        assert accelerator.topology == "2x2"
    finally:
        # Delete in `finally` so a failed assertion does not leave the node behind.
        delete_tpu.delete_cloud_tpu(PROJECT_ID, zone, name)

0 commit comments

Comments
 (0)