Skip to content

Commit 7070190

Browse files
feat(tpu): add tpu vm stop/start samples. (#9607)
* Changed package, added information to CODEOWNERS * Added information to CODEOWNERS * Added timeout * Fixed parameters for test * Fixed DeleteTpuVm and naming * Added comment, created Util class * Fixed naming * Fixed whitespace * Split PR into smaller, deleted redundant code * Implemented tpu_vm_stop and tpu_vm_start samples, created tests * Changed zone * Fixed empty lines and tests, deleted cleanup method * Fixed tests * Fixed comment
1 parent 437d4c7 commit 7070190

File tree

4 files changed

+162
-1
lines changed

4 files changed

+162
-1
lines changed

tpu/src/main/java/tpu/GetTpuVm.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public class GetTpuVm {
2727

2828
public static void main(String[] args) throws IOException {
2929
// TODO(developer): Replace these variables before running the sample.
30-
// Project ID or project number of the Google Cloud project you want to create a node.
30+
// Project ID or project number of the Google Cloud project you want to use.
3131
String projectId = "YOUR_PROJECT_ID";
3232
// The zone in which to create the TPU.
3333
// For more information about supported TPU types for specific zones,

tpu/src/main/java/tpu/StartTpuVm.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
//[START tpu_vm_start]
20+
import com.google.cloud.tpu.v2.Node;
21+
import com.google.cloud.tpu.v2.NodeName;
22+
import com.google.cloud.tpu.v2.StartNodeRequest;
23+
import com.google.cloud.tpu.v2.TpuClient;
24+
import java.io.IOException;
25+
import java.util.concurrent.ExecutionException;
26+
27+
public class StartTpuVm {
28+
29+
public static void main(String[] args)
30+
throws IOException, ExecutionException, InterruptedException {
31+
// TODO(developer): Replace these variables before running the sample.
32+
// Project ID or project number of the Google Cloud project you want to use.
33+
String projectId = "YOUR_PROJECT_ID";
34+
// The zone where the TPU is located.
35+
// For more information about supported TPU types for specific zones,
36+
// see https://cloud.google.com/tpu/docs/regions-zones
37+
String zone = "us-central1-f";
38+
// The name for your TPU.
39+
String nodeName = "YOUR_TPU_NAME";
40+
41+
startTpuVm(projectId, zone, nodeName);
42+
}
43+
44+
// Starts a TPU VM with the specified name in the given project and zone.
45+
public static Node startTpuVm(String projectId, String zone, String nodeName)
46+
throws IOException, ExecutionException, InterruptedException {
47+
// Initialize client that will be used to send requests. This client only needs to be created
48+
// once, and can be reused for multiple requests.
49+
try (TpuClient tpuClient = TpuClient.create()) {
50+
String name = NodeName.of(projectId, zone, nodeName).toString();
51+
52+
StartNodeRequest request = StartNodeRequest.newBuilder().setName(name).build();
53+
54+
return tpuClient.startNodeAsync(request).get();
55+
}
56+
}
57+
}
58+
//[END tpu_vm_start]

tpu/src/main/java/tpu/StopTpuVm.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
//[START tpu_vm_stop]
20+
import com.google.cloud.tpu.v2.Node;
21+
import com.google.cloud.tpu.v2.NodeName;
22+
import com.google.cloud.tpu.v2.StopNodeRequest;
23+
import com.google.cloud.tpu.v2.TpuClient;
24+
import java.io.IOException;
25+
import java.util.concurrent.ExecutionException;
26+
27+
public class StopTpuVm {
28+
29+
public static void main(String[] args)
30+
throws IOException, ExecutionException, InterruptedException {
31+
// TODO(developer): Replace these variables before running the sample.
32+
// Project ID or project number of the Google Cloud project you want to use.
33+
String projectId = "YOUR_PROJECT_ID";
34+
// The zone where the TPU is located.
35+
// For more information about supported TPU types for specific zones,
36+
// see https://cloud.google.com/tpu/docs/regions-zones
37+
String zone = "us-central1-f";
38+
// The name for your TPU.
39+
String nodeName = "YOUR_TPU_NAME";
40+
41+
stopTpuVm(projectId, zone, nodeName);
42+
}
43+
44+
// Stops a TPU VM with the specified name in the given project and zone.
45+
public static Node stopTpuVm(String projectId, String zone, String nodeName)
46+
throws IOException, ExecutionException, InterruptedException {
47+
// Initialize client that will be used to send requests. This client only needs to be created
48+
// once, and can be reused for multiple requests.
49+
try (TpuClient tpuClient = TpuClient.create()) {
50+
String name = NodeName.of(projectId, zone, nodeName).toString();
51+
52+
StopNodeRequest request = StopNodeRequest.newBuilder().setName(name).build();
53+
54+
return tpuClient.stopNodeAsync(request).get();
55+
}
56+
}
57+
}
58+
//[END tpu_vm_stop]
59+

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
import com.google.cloud.tpu.v2.GetNodeRequest;
3333
import com.google.cloud.tpu.v2.ListNodesRequest;
3434
import com.google.cloud.tpu.v2.Node;
35+
import com.google.cloud.tpu.v2.StartNodeRequest;
36+
import com.google.cloud.tpu.v2.StopNodeRequest;
3537
import com.google.cloud.tpu.v2.TpuClient;
3638
import com.google.cloud.tpu.v2.TpuSettings;
3739
import java.io.ByteArrayOutputStream;
@@ -166,4 +168,46 @@ public void testListTpuVm() throws IOException {
166168
verify(mockTpuClient, times(1)).listNodes(any(ListNodesRequest.class));
167169
}
168170
}
171+
172+
@Test
173+
public void testStartTpuVm() throws IOException, ExecutionException, InterruptedException {
174+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
175+
TpuClient mockClient = mock(TpuClient.class);
176+
Node mockNode = mock(Node.class);
177+
OperationFuture mockFuture = mock(OperationFuture.class);
178+
179+
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
180+
when(mockClient.startNodeAsync(any(StartNodeRequest.class)))
181+
.thenReturn(mockFuture);
182+
when(mockFuture.get()).thenReturn(mockNode);
183+
184+
Node returnedNode = StartTpuVm.startTpuVm(PROJECT_ID, ZONE, NODE_NAME);
185+
186+
verify(mockClient, times(1))
187+
.startNodeAsync(any(StartNodeRequest.class));
188+
verify(mockFuture, times(1)).get();
189+
assertEquals(returnedNode, mockNode);
190+
}
191+
}
192+
193+
@Test
194+
public void testStopTpuVm() throws IOException, ExecutionException, InterruptedException {
195+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
196+
TpuClient mockClient = mock(TpuClient.class);
197+
Node mockNode = mock(Node.class);
198+
OperationFuture mockFuture = mock(OperationFuture.class);
199+
200+
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
201+
when(mockClient.stopNodeAsync(any(StopNodeRequest.class)))
202+
.thenReturn(mockFuture);
203+
when(mockFuture.get()).thenReturn(mockNode);
204+
205+
Node returnedNode = StopTpuVm.stopTpuVm(PROJECT_ID, ZONE, NODE_NAME);
206+
207+
verify(mockClient, times(1))
208+
.stopNodeAsync(any(StopNodeRequest.class));
209+
verify(mockFuture, times(1)).get();
210+
assertEquals(returnedNode, mockNode);
211+
}
212+
}
169213
}

0 commit comments

Comments
 (0)