Skip to content

Commit dc30570

Browse files
feat: New Samples for TPU: Start, Stop, List (#12702)
* Added Samples for TPU: List, Start, Stop * Updated auto-label.yaml
1 parent a458bc8 commit dc30570

File tree

5 files changed

+197
-0
lines changed

5 files changed

+197
-0
lines changed

.github/auto-label.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ path:
8282
storagetransfer: "storagetransfer"
8383
talent: "jobs"
8484
texttospeech: "texttospeech"
85+
tpu: "tpu"
8586
trace: "cloudtrace"
8687
translate: "translate"
8788
vision: "vision"

tpu/list_tpu.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from google.cloud.tpu_v2.services.tpu.pagers import ListNodesPager
18+
19+
20+
def list_cloud_tpu(project_id: str, zone: str) -> ListNodesPager:
    """Print and return the TPU nodes found in a project and zone.

    Args:
        project_id (str): The ID of the Google Cloud project.
        zone (str): The zone whose TPU nodes are listed.

    Returns:
        ListNodesPager: The pager returned by ``TpuClient.list_nodes``.
            Note that it has already been iterated once inside this
            function in order to print every node.
    """
    # [START tpu_vm_list]
    from google.cloud import tpu_v2

    # TODO(developer): Update and un-comment below lines
    # project_id = "your-project-id"
    # zone = "us-central1-b"

    client = tpu_v2.TpuClient()

    # TPU nodes are listed per location: projects/<project>/locations/<zone>.
    parent = f"projects/{project_id}/locations/{zone}"
    nodes = client.list_nodes(parent=parent)
    for node in nodes:
        # One attribute per output line: name, state, accelerator type.
        print(node.name, node.state, node.accelerator_type, sep="\n")
    # Example response:
    # projects/[project_id]/locations/[zone]/nodes/node-name
    # State.READY
    # v2-8

    # [END tpu_vm_list]
    return nodes
49+
50+
51+
if __name__ == "__main__":
    # Script entry point: the project comes from the environment, the zone
    # is fixed for the sample.
    PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
    ZONE = "us-central1-b"
    list_cloud_tpu(project_id=PROJECT_ID, zone=ZONE)

tpu/start_tpu.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from google.cloud.tpu_v2 import Node
18+
19+
20+
def start_cloud_tpu(project_id: str, zone: str, tpu_name: str = "tpu-name") -> Node:
    """Start a TPU node and wait for the operation to finish.

    Args:
        project_id (str): The ID of the Google Cloud project.
        zone (str): The zone of the TPU node.
        tpu_name (str): The name of the TPU node to start.

    Returns:
        Node: The node returned by the completed start operation.

    Raises:
        Exception: Any error raised while requesting the start or while
            waiting for the operation is printed and then re-raised, so the
            function never returns ``None`` in place of a ``Node``.
    """
    # [START tpu_vm_start]
    from google.cloud import tpu_v2

    # TODO(developer): Update and un-comment below lines
    # project_id = "your-project-id"
    # zone = "us-central1-b"
    # tpu_name = "tpu-name"

    client = tpu_v2.TpuClient()

    request = tpu_v2.StartNodeRequest(
        name=f"projects/{project_id}/locations/{zone}/nodes/{tpu_name}",
    )
    try:
        operation = client.start_node(request=request)
        print("Waiting for start operation to complete...")
        response = operation.result()
    except Exception as e:
        # Show the failure in the sample output, then propagate it instead
        # of silently falling through and returning None.
        print(e)
        raise

    print(f"TPU {tpu_name} has been started")
    print(response.state)
    # Example response:
    # State.READY

    # [END tpu_vm_start]
    return response
56+
57+
58+
if __name__ == "__main__":
    # Script entry point: the project comes from the environment, the zone
    # and node name are fixed for the sample.
    PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
    ZONE = "us-central1-b"
    start_cloud_tpu(project_id=PROJECT_ID, zone=ZONE, tpu_name="tpu-name")

tpu/stop_tpu.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from google.cloud.tpu_v2 import Node
18+
19+
20+
def stop_cloud_tpu(project_id: str, zone: str, tpu_name: str = "tpu-name") -> Node:
    """Stop a TPU node and wait for the operation to finish.

    Args:
        project_id (str): The ID of the Google Cloud project.
        zone (str): The zone of the TPU node.
        tpu_name (str): The name of the TPU node to stop.

    Returns:
        Node: The node returned by the completed stop operation.

    Raises:
        Exception: Any error raised while requesting the stop or while
            waiting for the operation is printed and then re-raised, so the
            function never returns ``None`` in place of a ``Node``.
    """
    # [START tpu_vm_stop]
    from google.cloud import tpu_v2

    # TODO(developer): Update and un-comment below lines
    # project_id = "your-project-id"
    # zone = "us-central1-b"
    # tpu_name = "tpu-name"

    client = tpu_v2.TpuClient()

    request = tpu_v2.StopNodeRequest(
        name=f"projects/{project_id}/locations/{zone}/nodes/{tpu_name}",
    )
    try:
        operation = client.stop_node(request=request)
        print("Waiting for stop operation to complete...")
        response = operation.result()
    except Exception as e:
        # Show the failure in the sample output, then propagate it instead
        # of silently falling through and returning None.
        print(e)
        raise

    print(f"This TPU {tpu_name} has been stopped")
    print(response.state)
    # Example response:
    # State.STOPPED

    # [END tpu_vm_stop]
    return response
57+
58+
59+
if __name__ == "__main__":
    # Script entry point: the project comes from the environment, the zone
    # and node name are fixed for the sample.
    PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
    ZONE = "us-central1-b"
    stop_cloud_tpu(project_id=PROJECT_ID, zone=ZONE, tpu_name="tpu-name")

tpu/test_tpu.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
import create_tpu_with_script
2323
import delete_tpu
2424
import get_tpu
25+
import list_tpu
26+
import start_tpu
27+
import stop_tpu
2528

2629
TPU_NAME = "test-tpu-" + uuid.uuid4().hex[:10]
2730
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
@@ -52,10 +55,26 @@ def test_creating_with_startup_script() -> None:
5255
)
5356
assert "--upgrade numpy" in tpu_with_script.metadata["startup-script"]
5457
finally:
58+
print(f"\n\n ------------ Deleting TPU {TPU_NAME}\n ------------")
5559
delete_tpu.delete_cloud_tpu(PROJECT_ID, ZONE, tpu_name_with_script)
5660

5761

5862
def test_get_tpu() -> None:
    """Fetch the TPU named TPU_NAME and check its state and resource name."""
    expected_name = f"projects/{PROJECT_ID}/locations/{ZONE}/nodes/{TPU_NAME}"

    tpu = get_tpu.get_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME)

    assert tpu.state == Node.State.READY
    assert tpu.name == expected_name
66+
67+
68+
def test_list_tpu() -> None:
    """Check that listing TPU nodes in the test zone yields at least one node.

    NOTE(review): presumably relies on a TPU created by setup code outside
    this view -- confirm.
    """
    found = list(list_tpu.list_cloud_tpu(PROJECT_ID, ZONE))
    assert len(found) > 0
71+
72+
73+
def test_stop_tpu() -> None:
    """Stop the TPU named TPU_NAME and check that it reports STOPPED."""
    stopped = stop_tpu.stop_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME)
    assert stopped.state == Node.State.STOPPED
76+
77+
78+
def test_start_tpu() -> None:
    """Start the TPU named TPU_NAME and check that it reports READY."""
    started = start_tpu.start_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME)
    assert started.state == Node.State.READY

0 commit comments

Comments
 (0)