Skip to content

Commit 05b16a2

Browse files
feat(tpu): add tpu vm create startup script sample. (#9612)
* Changed package, added information to CODEOWNERS * Added information to CODEOWNERS * Added timeout * Fixed parameters for test * Fixed DeleteTpuVm and naming * Added comment, created Util class * Fixed naming * Fixed whitespace * Split PR into smaller, deleted redundant code * Implemented tpu_vm_create_startup_script sample, created test * Fixed tests and empty lines * Changed zone * Deleted redundant test classes * Increased timeout * Fixed code
1 parent e6a4c3f commit 05b16a2

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
//[START tpu_vm_create_startup_script]
20+
import com.google.cloud.tpu.v2.CreateNodeRequest;
21+
import com.google.cloud.tpu.v2.Node;
22+
import com.google.cloud.tpu.v2.TpuClient;
23+
import java.io.IOException;
24+
import java.util.HashMap;
25+
import java.util.Map;
26+
import java.util.concurrent.ExecutionException;
27+
28+
public class CreateTpuVmWithStartupScript {
29+
public static void main(String[] args)
30+
throws IOException, ExecutionException, InterruptedException {
31+
// TODO(developer): Replace these variables before running the sample.
32+
// Project ID or project number of the Google Cloud project you want to create a node.
33+
String projectId = "YOUR_PROJECT_ID";
34+
// The zone in which to create the TPU.
35+
// For more information about supported TPU types for specific zones,
36+
// see https://cloud.google.com/tpu/docs/regions-zones
37+
String zone = "us-central1-f";
38+
// The name for your TPU.
39+
String nodeName = "YOUR_TPU_NAME";
40+
// The accelerator type that specifies the version and size of the Cloud TPU you want to create.
41+
// For more information about supported accelerator types for each TPU version,
42+
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
43+
String acceleratorType = "v2-8";
44+
// Software version that specifies the version of the TPU runtime to install.
45+
// For more information, see https://cloud.google.com/tpu/docs/runtimes
46+
String tpuSoftwareVersion = "tpu-vm-tf-2.14.1";
47+
48+
createTpuVmWithStartupScript(projectId, zone, nodeName, acceleratorType, tpuSoftwareVersion);
49+
}
50+
51+
// Create a TPU VM with a startup script.
52+
public static Node createTpuVmWithStartupScript(String projectId, String zone,
53+
String nodeName, String acceleratorType, String tpuSoftwareVersion)
54+
throws IOException, ExecutionException, InterruptedException {
55+
// Initialize client that will be used to send requests. This client only needs to be created
56+
// once, and can be reused for multiple requests.
57+
try (TpuClient tpuClient = TpuClient.create()) {
58+
String parent = String.format("projects/%s/locations/%s", projectId, zone);
59+
60+
String startupScriptContent = "#!/bin/bash\necho \"Hello from the startup script!\"";
61+
// Add startup script to metadata
62+
Map<String, String> metadata = new HashMap<>();
63+
metadata.put("startup-script", startupScriptContent);
64+
65+
Node tpuVm =
66+
Node.newBuilder()
67+
.setName(nodeName)
68+
.setAcceleratorType(acceleratorType)
69+
.setRuntimeVersion(tpuSoftwareVersion)
70+
.putAllMetadata(metadata)
71+
.build();
72+
73+
CreateNodeRequest request =
74+
CreateNodeRequest.newBuilder()
75+
.setParent(parent)
76+
.setNodeId(nodeName)
77+
.setNode(tpuVm)
78+
.build();
79+
80+
return tpuClient.createNodeAsync(request).get();
81+
}
82+
}
83+
}
84+
//[END tpu_vm_create_startup_script]

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,4 +233,34 @@ public void testCreateSpotTpuVm() throws Exception {
233233
assertEquals(returnedNode, mockNode);
234234
}
235235
}
236+
237+
@Test
238+
public void testCreateTpuVmWithStartupScript() throws Exception {
239+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
240+
Node mockNode = Node.newBuilder()
241+
.setName("nodeName")
242+
.setAcceleratorType("acceleratorType")
243+
.setRuntimeVersion("runtimeVersion")
244+
.build();
245+
246+
TpuClient mockTpuClient = mock(TpuClient.class);
247+
OperationFuture mockFuture = mock(OperationFuture.class);
248+
249+
mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient);
250+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
251+
.thenReturn(mockFuture);
252+
when(mockFuture.get()).thenReturn(mockNode);
253+
254+
Node returnedNode = CreateTpuVmWithStartupScript.createTpuVmWithStartupScript(
255+
PROJECT_ID, ZONE, NODE_NAME,
256+
TPU_TYPE, TPU_SOFTWARE_VERSION);
257+
258+
verify(mockTpuClient, times(1))
259+
.createNodeAsync(any(CreateNodeRequest.class));
260+
verify(mockFuture, times(1)).get();
261+
assertEquals(returnedNode.getName(), mockNode.getName());
262+
assertEquals(returnedNode.getAcceleratorType(), mockNode.getAcceleratorType());
263+
assertEquals(returnedNode.getRuntimeVersion(), mockNode.getRuntimeVersion());
264+
}
265+
}
236266
}

0 commit comments

Comments
 (0)