8000 feat: text embeddings sample with Vertex Gen AI (#8228) · nes-a-cti/java-docs-samples@243a471 · GitHub
[go: up one dir, main page]

Skip to content

Commit 243a471

Browse files
authored
feat: text embeddings sample with Vertex Gen AI (GoogleCloudPlatform#8228)
* feat: text embeddings sample with Vertex Gen AI * addressed review comments * addressed review comments (overwrote the code formatter) * add retries to mitigate unstable tests
1 parent 51ed511 commit 243a471

File tree

2 files changed

+162
-0
lines changed

2 files changed

+162
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright 2023 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
// [START aiplatform_sdk_embedding]
20+
21+
import com.google.cloud.aiplatform.util.ValueConverter;
22+
import com.google.cloud.aiplatform.v1beta1.EndpointName;
23+
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
24+
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
25+
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
26+
import com.google.protobuf.Value;
27+
import com.google.protobuf.util.JsonFormat;
28+
import java.io.IOException;
29+
import java.util.ArrayList;
30+
import java.util.List;
31+
32+
public class PredictTextEmbeddingsSample {
33+
34+
public static void main(String[] args) throws IOException {
35+
// TODO(developer): Replace these variables before running the sample.
36+
// Details about text embedding request structure and supported models are available in:
37+
// https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings
38+
String instance = "{ \"content\": \"What is life?\"}";
39+
String project = "YOUR_PROJECT_ID";
40+
String location = "us-central1";
41+
String publisher = "google";
42+
String model = "textembedding-gecko@001";
43+
44+
predictTextEmbeddings(instance, project, location, publisher, model);
45+
}
46+
47+
// Get text embeddings from a supported embedding model
48+
public static void predictTextEmbeddings(
49+
String instance, String project, String location, String publisher, String model)
50+
throws IOException {
51+
String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
52+
PredictionServiceSettings predictionServiceSettings =
53+
PredictionServiceSettings.newBuilder()
54+
.setEndpoint(endpoint)
55+
.build();
56+
57+
// Initialize client that will be used to send requests. This client only needs to be created
58+
// once, and can be reused for multiple requests.
59+
try (PredictionServiceClient predictionServiceClient =
60+
PredictionServiceClient.create(predictionServiceSettings)) {
61+
EndpointName endpointName =
62+
EndpointName.ofProjectLocationPublisherModelName(project, location, publisher, model);
63+
64+
// Use Value.Builder to convert instance to a dynamically typed value that can be
65+
// processed by the service.
66+
Value.Builder instanceValue = Value.newBuilder();
67+
JsonFormat.parser().merge(instance, instanceValue);
68+
List<Value> instances = new ArrayList<>();
69+
instances.add(instanceValue.build());
70+
71+
PredictResponse predictResponse =
72+
predictionServiceClient.predict(endpointName, instances, ValueConverter.EMPTY_VALUE);
73+
System.out.println("Predict Response");
74+
for (Value prediction : predictResponse.getPredictionsList()) {
75+
System.out.format("\tPrediction: %s\n", prediction);
76+
}
77+
}
78+
}
79+
}
80+
// [END aiplatform_sdk_embedding]
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright 2023 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static junit.framework.TestCase.assertNotNull;
21+
22+
import com.google.cloud.testing.junit4.MultipleAttemptsRule;
23+
import java.io.ByteArrayOutputStream;
24+
import java.io.IOException;
25+
import java.io.PrintStream;
26+
import org.junit.After;
27+
import org.junit.Before;
28+
import org.junit.BeforeClass;
29+
import org.junit.Rule;
30+
import org.junit.Test;
31+
32+
public class PredictTextEmbeddingsSampleTest {
33+
34+
@Rule public final MultipleAttemptsRule multipleAttemptsRule = new MultipleAttemptsRule(3);
35+
36+
private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
37+
private static final String LOCATION = "us-central1";
38+
private static final String INSTANCE = "{ \"content\": \"What is life?\"}";
39+
private static final String PUBLISHER = "google";
40+
private static final String MODEL = "textembedding-gecko@001";
41+
42+
private ByteArrayOutputStream bout;
43+
private PrintStream out;
44+
private PrintStream originalPrintStream;
45+
46+
private static void requireEnvVar(String varName) {
47+
String errorMessage =
48+
String.format("Environment variable '%s' is required to perform these tests.", varName);
49+
assertNotNull(errorMessage, System.getenv(varName));
50+
}
51+
52+
@BeforeClass
53+
public static void checkRequirements() {
54+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
55+
requireEnvVar("UCAIP_PROJECT_ID");
56+
}
57+
58+
@Before
59+
public void setUp() {
60+
bout = new ByteArrayOutputStream();
61+
out = new PrintStream(bout);
62+
originalPrintStream = System.out;
63+
System.setOut(out);
64+
}
65+
66+
@After
67+
public void tearDown() {
68+
System.out.flush();
69+
System.setOut(originalPrintStream);
70+
}
71+
72+
@Test
73+
public void testPredictTextEmbeddings() throws IOException {
74+
// Act
75+
PredictTextEmbeddingsSample.predictTextEmbeddings(
76+
INSTANCE, PROJECT, LOCATION, PUBLISHER, MODEL);
77+
78+
// Assert
79+
String got = bout.toString();
80+
assertThat(got).contains("Predict Response");
81+
}
82+
}

0 commit comments

Comments
 (0)
0