10000 feat(tpu): add tpu queued resources create spot (#9615) · invertase/java-docs-samples@38eb8d3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 38eb8d3

Browse files
feat(tpu): add tpu queued resources create spot (GoogleCloudPlatform#9615)
Add a code sample for tpu_queued_resources_create_spot
1 parent 772c39c commit 38eb8d3

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
// [START tpu_queued_resources_create_spot]
20+
import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest;
21+
import com.google.cloud.tpu.v2alpha1.Node;
22+
import com.google.cloud.tpu.v2alpha1.QueuedResource;
23+
import com.google.cloud.tpu.v2alpha1.SchedulingConfig;
24+
import com.google.cloud.tpu.v2alpha1.TpuClient;
25+
import java.io.IOException;
26+
import java.util.concurrent.ExecutionException;
27+
28+
public class CreateSpotQueuedResource {
29+
public static void main(String[] args)
30+
throws IOException, ExecutionException, InterruptedException {
31+
// TODO(developer): Replace these variables before running the sample.
32+
// Project ID or project number of the Google Cloud project you want to create a node.
33+
String projectId = "YOUR_PROJECT_ID";
34+
// The zone in which to create the TPU.
35+
// For more information about supported TPU types for specific zones,
36+
// see https://cloud.google.com/tpu/docs/regions-zones
37+
String zone = "us-central1-f";
38+
// The name for your TPU.
39+
String nodeName = "YOUR_TPU_NAME";
40+
// The accelerator type that specifies the version and size of the Cloud TPU you want to create.
41+
// For more information about supported accelerator types for each TPU version,
42+
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
43+
String tpuType = "v2-8";
44+
// Software version that specifies the version of the TPU runtime to install.
45+
// For more information see https://cloud.google.com/tpu/docs/runtimes
46+
String tpuSoftwareVersion = "tpu-vm-tf-2.14.1";
47+
// The name for your Queued Resource.
48+
String queuedResourceId = "QUEUED_RESOURCE_ID";
49+
50+
createQueuedResource(
51+
projectId, zone, queuedResourceId, nodeName, tpuType, tpuSoftwareVersion);
52+
}
53+
54+
// Creates a Queued Resource with --preemptible flag.
55+
public static QueuedResource createQueuedResource(
56+
String projectId, String zone, String queuedResourceId,
57+
String nodeName, String tpuType, String tpuSoftwareVersion)
58+
throws IOException, ExecutionException, InterruptedException {
59+
// Initialize client that will be used to send requests. This client only needs to be created
60+
// once, and can be reused for multiple requests.
61+
try (TpuClient tpuClient = TpuClient.create()) {
62+
String parent = String.format("projects/%s/locations/%s", projectId, zone);
63+
String resourceName = String.format("projects/%s/locations/%s/queuedResources/%s",
64+
projectId, zone, queuedResourceId);
65+
SchedulingConfig schedulingConfig = SchedulingConfig.newBuilder()
66+
.setPreemptible(true)
67+
.build();
68+
69+
Node node =
70+
Node.newBuilder()
71+
.setName(nodeName)
72+
.setAcceleratorType(tpuType)
73+
.setRuntimeVersion(tpuSoftwareVersion)
74+
.setSchedulingConfig(schedulingConfig)
75+
.setQueuedResource(resourceName)
76+
.build();
77+
78+
QueuedResource queuedResource =
79+
QueuedResource.newBuilder()
80+
.setName(queuedResourceId)
81+
.setTpu(
82+
QueuedResource.Tpu.newBuilder()
83+
.addNodeSpec(
84+
QueuedResource.Tpu.NodeSpec.newBuilder()
85+
.setParent(parent)
86+
.setNode(node)
87+
.setNodeId(nodeName)
88+
.build())
89+
.build())
90+
.build();
91+
92+
CreateQueuedResourceRequest request =
93+
CreateQueuedResourceRequest.newBuilder()
94+
.setParent(parent)
95+
.setQueuedResourceId(queuedResourceId)
96+
.setQueuedResource(queuedResource)
97+
.build();
98+
99+
return tpuClient.createQueuedResourceAsync(request).get();
100+
}
101+
}
102+
}
103+
// [END tpu_queued_resources_create_spot]

tpu/src/test/java/tpu/QueuedResourceIT.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,30 @@ public void testCreateQueuedResourceWithStartupScript() throws Exception {
181181
assertEquals(returnedQueuedResource, mockQueuedResource);
182182
}
183183
}
184+
185+
@Test
186+
public void testCreateSpotQueuedResource() throws Exception {
187+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
188+
QueuedResource mockQueuedResource = QueuedResource.newBuilder()
189+
.setName("QueuedResourceName")
190+
.build();
191+
TpuClient mockedClientInstance = mock(TpuClient.class);
192+
OperationFuture mockFuture = mock(OperationFuture.class);
193+
194+
mockedTpuClient.when(TpuClient::create).thenReturn(mockedClientInstance);
195+
when(mockedClientInstance.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)))
196+
.thenReturn(mockFuture);
197+
when(mockFuture.get()).thenReturn(mockQueuedResource);
198+
199+
QueuedResource returnedQueuedResource =
200+
CreateSpotQueuedResource.createQueuedResource(
201+
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
202+
TPU_TYPE, TPU_SOFTWARE_VERSION);
203+
204+
verify(mockedClientInstance, times(1))
205+
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
206+
verify(mockFuture, times(1)).get();
207+
assertEquals(returnedQueuedResource.getName(), mockQueuedResource.getName());
208+
}
209+
}
184210
}

0 commit comments

Comments
 (0)
0