feat(aiplatform): add gemma2 samples for Model Garden deployments to … · ludoch/java-docs-samples@39ee185 · GitHub
[go: up one dir, main page]

Skip to content

Commit 39ee185

Browse files
feat(aiplatform): add gemma2 samples for Model Garden deployments to VertexAI endpoints (GoogleCloudPlatform#9527)
* Added generativeaionvertexai_gemma2_predict_tpu and generativeaionvertexai_gemma2_predict_gpu sample, created test * Fixed instance format * Fixed test and instance format for Gemma2PredictTpu * Deleted class for vertexai package * Added generativeaionvertexai_gemma2_predict_gpu and generativeaionvertexai_gemma2_predict_tpu samples, created test * Fixed comments * added comments, created new test * Fixed parameters, created test to check parameters * Fixed comments * Fixed comments
1 parent 611b32b commit 39ee185

File tree

5 files changed

+384
-0
lines changed

5 files changed

+384
-0
lines changed

aiplatform/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@
8989
<version>1.7.1</version>
9090
<scope>test</scope>
9191
</dependency>
92+
<dependency>
93+
<groupId>org.mockito</groupId>
94+
<artifactId>mockito-core</artifactId>
95+
<version>5.13.0</version>
96+
<scope>test</scope>
97+
</dependency>
9298
<dependency>
9399
<groupId>org.junit.jupiter</groupId>
94100
<artifactId>junit-jupiter</artifactId>
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
// [START generativeaionvertexai_gemma2_predict_gpu]
20+
21+
import com.google.cloud.aiplatform.v1.EndpointName;
22+
import com.google.cloud.aiplatform.v1.PredictResponse;
23+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
25+
import com.google.gson.Gson;
26+
import com.google.protobuf.InvalidProtocolBufferException;
27+
import com.google.protobuf.Value;
28+
import com.google.protobuf.util.JsonFormat;
29+
import java.io.IOException;
30+
import java.util.ArrayList;
31+
import java.util.HashMap;
32+
import java.util.List;
33+
import java.util.Map;
34+
35+
public class Gemma2PredictGpu {
36+
37+
private final PredictionServiceClient predictionServiceClient;
38+
39+
// Constructor to inject the PredictionServiceClient
40+
public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) {
41+
this.predictionServiceClient = predictionServiceClient;
42+
}
43+
44+
public static void main(String[] args) throws IOException {
45+
// TODO(developer): Replace these variables before running the sample.
46+
String projectId = "YOUR_PROJECT_ID";
47+
String endpointRegion = "us-east4";
48+
String endpointId = "YOUR_ENDPOINT_ID";
49+
50+
PredictionServiceSettings predictionServiceSettings =
51+
PredictionServiceSettings.newBuilder()
52+
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
53+
.build();
54+
PredictionServiceClient predictionServiceClient =
55+
PredictionServiceClient.create(predictionServiceSettings);
56+
Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient);
57+
58+
creator.gemma2PredictGpu(projectId, endpointRegion, endpointId);
59+
}
60+
61+
// Demonstrates how to run inference on a Gemma2 model
62+
// deployed to a Vertex AI endpoint with GPU accelerators.
63+
public String gemma2PredictGpu(String projectId, String region,
64+
String endpointId) throws IOException {
65+
Map<String, Object> paramsMap = new HashMap<>();
66+
paramsMap.put("temperature", 0.9);
67+
paramsMap.put("maxOutputTokens", 1024);
68+
paramsMap.put("topP", 1.0);
69+
paramsMap.put("topK", 1);
70+
Value parameters = mapToValue(paramsMap);
71+
72+
// Prompt used in the prediction
73+
String instance = "{ \"inputs\": \"Why is the sky blue?\"}";
74+
Value.Builder instanceValue = Value.newBuilder();
75+
JsonFormat.parser().merge(instance, instanceValue);
76+
// Encapsulate the prompt in a correct format for GPUs
77+
// Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}]
78+
List<Value> instances = new ArrayList<>();
79+
instances.add(instanceValue.build());
80+
81+
EndpointName endpointName = EndpointName.of(projectId, region, endpointId);
82+
83+
PredictResponse predictResponse = this.predictionServiceClient
84+
.predict(endpointName, instances, parameters);
85+
String textResponse = predictResponse.getPredictions(0).getStringValue();
86+
System.out.println(textResponse);
87+
return textResponse;
88+
}
89+
90+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
91+
Gson gson = new Gson();
92+
String json = gson.toJson(map);
93+
Value.Builder builder = Value.newBuilder();
94+
JsonFormat.parser().merge(json, builder);
95+
return builder.build();
96+
}
97+
}
98+
// [END generativeaionvertexai_gemma2_predict_gpu]
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
// [START generativeaionvertexai_gemma2_predict_tpu]
20+
21+
import com.google.cloud.aiplatform.v1.EndpointName;
22+
import com.google.cloud.aiplatform.v1.PredictResponse;
23+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
25+
import com.google.gson.Gson;
26+
import com.google.protobuf.InvalidProtocolBufferException;
27+
import com.google.protobuf.Value;
28+
import com.google.protobuf.util.JsonFormat;
29+
import java.io.IOException;
30+
import java.util.ArrayList;
31+
import java.util.HashMap;
32+
import java.util.List;
33+
import java.util.Map;
34+
35+
public class Gemma2PredictTpu {
36+
private final PredictionServiceClient predictionServiceClient;
37+
38+
// Constructor to inject the PredictionServiceClient
39+
public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
40+
this.predictionServiceClient = predictionServiceClient;
41+
}
42+
43+
public static void main(String[] args) throws IOException {
44+
// TODO(developer): Replace these variables before running the sample.
45+
St 10669 ring projectId = "YOUR_PROJECT_ID";
46+
String endpointRegion = "us-west1";
47+
String endpointId = "YOUR_ENDPOINT_ID";
48+
49+
PredictionServiceSettings predictionServiceSettings =
50+
PredictionServiceSettings.newBuilder()
51+
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
52+
.build();
53+
PredictionServiceClient predictionServiceClient =
54+
PredictionServiceClient.create(predictionServiceSettings);
55+
Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);
56+
57+
creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
58+
}
59+
60+
// Demonstrates how to run inference on a Gemma2 model
61+
// deployed to a Vertex AI endpoint with TPU accelerators.
62+
public String gemma2PredictTpu(String projectId, String region,
63+
String endpointId) throws IOException {
64+
Map<String, Object> paramsMap = new HashMap<>();
65+
paramsMap.put("temperature", 0.9);
66+
paramsMap.put("maxOutputTokens", 1024);
67+
paramsMap.put("topP", 1.0);
68+
paramsMap.put("topK", 1);
69+
Value parameters = mapToValue(paramsMap);
70+
// Prompt used in the prediction
71+
String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
72+
Value.Builder instanceValue = Value.newBuilder();
73+
JsonFormat.parser().merge(instance, instanceValue);
74+
// Encapsulate the prompt in a correct format for TPUs
75+
// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
76+
List<Value> instances = new ArrayList<>();
77+
instances.add(instanceValue.build());
78+
79+
EndpointName endpointName = EndpointName.of(projectId, region, endpointId);
80+
81+
PredictResponse predictResponse = this.predictionServiceClient
82+
.predict(endpointName, instances, parameters);
83+
String textResponse = predictResponse.getPredictions(0).getStringValue();
84+
System.out.println(textResponse);
85+
return textResponse;
86+
}
87+
88+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
89+
Gson gson = new Gson();
90+
String json = gson.toJson(map);
91+
Value.Builder builder = Value.newBuilder();
92+
JsonFormat.parser().merge(json, builder);
93+
return builder.build();
94+
}
95+
}
96+
// [END generativeaionvertexai_gemma2_predict_tpu]
97+
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
import static org.junit.jupiter.api.Assertions.assertTrue;
20+
21+
import com.google.cloud.aiplatform.v1.EndpointName;
22+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
23+
import com.google.gson.Gson;
24+
import com.google.protobuf.InvalidProtocolBufferException;
25+
import com.google.protobuf.Value;
26+
import com.google.protobuf.util.JsonFormat;
27+
import java.util.ArrayList;
28+
import java.util.HashMap;
29+
import java.util.List;
30+
import java.util.Map;
31+
import org.junit.jupiter.api.Test;
32+
import org.mockito.Mockito;
33+
import org.mockito.stubbing.Answer;
34+
35+
public class Gemma2ParametersTest {
36+
37+
static PredictionServiceClient mockGpuPredictionServiceClient;
38+
static PredictionServiceClient mockTpuPredictionServiceClient;
39+
private static final String INSTANCE_GPU = "{ \"inputs\": \"Why is the sky blue?\"}";
40+
private static final String INSTANCE_TPU = "{ \"prompt\": \"Why is the sky blue?\"}";
41+
42+
@Test
43+
public void parametersTest() throws InvalidProtocolBufferException {
44+
// Mock GPU and TPU PredictionServiceClient and its response
45+
mockGpuPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
46+
mockTpuPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
47+
48+
Value.Builder instanceValueGpu = Value.newBuilder();
49+
JsonFormat.parser().merge(INSTANCE_GPU, instanceValueGpu);
50+
List<Value> instancesGpu = new ArrayList<>();
51+
instancesGpu.add(instanceValueGpu.build());
52+
53+
Value.Builder instanceValueTpu = Value.newBuilder();
54+
JsonFormat.parser().merge(INSTANCE_TPU, instanceValueTpu);
55+
List<Value> instancesTpu = new ArrayList<>();
56+
instancesTpu.add(instanceValueTpu.build());
57+
58+
Map<String, Object> paramsMap = new HashMap<>();
59+
paramsMap.put("temperature", 0.9);
60+
paramsMap.put("maxOutputTokens", 1024);
61+
paramsMap.put("topP", 1.0);
62+
paramsMap.put("topK", 1);
63+
Value parameters = mapToValue(paramsMap);
64+
65+
Mockito.when(mockGpuPredictionServiceClient.predict(
66+
Mockito.any(EndpointName.class),
67+
Mockito.any(List.class),
68+
Mockito.any(Value.class)))
69+
.thenAnswer(invocation ->
70+
mockGpuResponse(instancesGpu, parameters));
71+
72+
Mockito.when(mockTpuPredictionServiceClient.predict(
73+
Mockito.any(EndpointName.class),
74+
Mockito.any(List.class),
75+
Mockito.any(Value.class)))
76+
.thenAnswer(invocation ->
77+
mockTpuResponse(instancesTpu, parameters));
78+
}
79+
80+
public static Answer<?> mockGpuResponse(List<Value> instances, Value parameter) {
81+
82+
assertTrue(instances.get(0).getStructValue().getFieldsMap().containsKey("inputs"));
83+
assertTrue(parameter.getStructValue().containsFields("temperature"));
84+
assertTrue(parameter.getStructValue().containsFields("maxOutputTokens"));
85+
assertTrue(parameter.getStructValue().containsFields("topP"));
86+
assertTrue(parameter.getStructValue().containsFields("topK"));
87+
return null;
88+
}
89+
90+
public static Answer<?> mockTpuResponse(List<Value> instances, Value parameter) {
91+
92+
assertTrue(instances.get(0).getStructValue().getFieldsMap().containsKey("prompt"));
93+
assertTrue(parameter.getStructValue().containsFields("temperature"));
94+
assertTrue(parameter.getStructValue().containsFields("maxOutputTokens"));
95+
assertTrue(parameter.getStructValue().containsFields("topP"));
96+
assertTrue(parameter.getStructValue().containsFields("topK"));
97+
return null;
98+
}
99+
100+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
101+
Gson gson = new Gson();
102+
String json = gson.toJson(map);
103+
Value.Builder builder = Value.newBuilder();
104+
JsonFormat.parser().merge(json, builder);
105+
return builder.build();
106+
}
107+
}
108+
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
import static org.junit.jupiter.api.Assertions.assertEquals;
20+
21+
import com.google.cloud.aiplatform.v1.EndpointName;
22+
import com.google.cloud.aiplatform.v1.PredictResponse;
23+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
24+
import com.google.protobuf.Value;
25+
import java.io.IOException;
26+
import java.util.List;
27+
import org.junit.jupiter.api.BeforeAll;
28+
import org.junit.jupiter.api.Test;
29+
import org.mockito.Mockito;
30+
31+
public class Gemma2PredictTest {
32+
static String mockedResponse = "The sky appears blue due to a phenomenon "
33+
+ "called **Rayleigh scattering**.\n"
34+
+ "**Here's how it works:**\n"
35+
+ "* **Sunlight is white:** Sunlight actually contains all the colors of the rainbow.\n"
36+
+ "* **Scattering:** When sunlight enters the Earth's atmosphere, it collides with tiny gas"
37+
+ " molecules (mostly nitrogen and oxygen). These collisions cause the light to scatter "
38+
+ "in different directions.\n"
39+
+ "* **Blue light scatters most:** Blue light has a shorter wavelength";
40+
String projectId = "your-project-id";
41+
String region = "us-central1";
42+
String endpointId = "your-endpoint-id";
43+
static PredictionServiceClient mockPredictionServiceClient;
44+
45+
@BeforeAll
46+
public static void setUp() {
47+
// Mock PredictionServiceClient and its response
48+
mockPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
49+
PredictResponse predictResponse =
50+
PredictResponse.newBuilder()
51+
.addPredictions(Value.newBuilder().setStringValue(mockedResponse).build())
52+
.build();
53+
Mockito.when(mockPredictionServiceClient.predict(
54+
Mockito.any(EndpointName.class),
55+
Mockito.any(List.class),
56+
Mockito.any(Value.class)))
57+
.thenReturn(predictResponse);
58+
}
59+
60+
@Test
61+
public void testGemma2PredictTpu() throws IOException {
62+
Gemma2PredictTpu creator = new Gemma2PredictTpu(mockPredictionServiceClient);
63+
String response = creator.gemma2PredictTpu(projectId, region, endpointId);
64+
65+
assertEquals(mockedResponse, response);
66+
}
67+
68+
@Test
69+
public void testGemma2PredictGpu() throws IOException {
70+
Gemma2PredictGpu creator = new Gemma2PredictGpu(mockPredictionServiceClient);
71+
String response = creator.gemma2PredictGpu(projectId, region, endpointId);
72+
73+
assertEquals(mockedResponse, response);
74+
}
75+
}

0 commit comments

Comments
 (0)
0