feat(aiplatform): add generativeaionvertexai embedding batch sample (… · ludoch/java-docs-samples@d53570d · GitHub
[go: up one dir, main page]

Skip to content

Commit d53570d

Browse files
feat(aiplatform): add generativeaionvertexai embedding batch sample (GoogleCloudPlatform#9559)
* Implemented generativeaionvertexai_embedding_batch sample, created test * Enabled CreateTrainingPipelineTextClassificationSampleTest which had error * fixed according to commits
1 parent 4697d0d commit d53570d

File tree

2 files changed

+173
-0
lines changed

2 files changed

+173
-0
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
// [START generativeaionvertexai_embedding_batch]
20+
21+
import com.google.cloud.aiplatform.v1.BatchPredictionJob;
22+
import com.google.cloud.aiplatform.v1.GcsDestination;
23+
import com.google.cloud.aiplatform.v1.GcsSource;
24+
import com.google.cloud.aiplatform.v1.JobServiceClient;
25+
import com.google.cloud.aiplatform.v1.JobServiceSettings;
26+
import com.google.cloud.aiplatform.v1.LocationName;
27+
import java.io.IOException;
28+
29+
public class EmbeddingBatchSample {
30+
31+
public static void main(String[] args) throws IOException, InterruptedException {
32+
// TODO(developer): Replace these variables before running the sample.
33+
String project = "YOUR_PROJECT_ID";
34+
String location = "us-central1";
35+
// inputUri: URI of the input dataset.
36+
// Could be a BigQuery table or a Google Cloud Storage file.
37+
// E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
38+
String inputUri = "gs://cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl";
39+
// outputUri: URI where the output will be stored.
40+
// Could be a BigQuery table or a Google Cloud Storage file.
41+
// E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
42+
String outputUri = "gs://YOUR_BUCKET/embedding_batch_output";
43+
String textEmbeddingModel = "textembedding-gecko@003";
44+
45+
embeddingBatchSample(project, location, inputUri, outputUri, textEmbeddingModel);
46+
}
47+
48+
// Generates embeddings from text using batch processing.
49+
// Read more: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/batch-prediction-genai-embeddings
50+
public static BatchPredictionJob embeddingBatchSample(
51+
String project, String location, String inputUri, String outputUri, String textEmbeddingModel)
52+
throws IOException {
53+
BatchPredictionJob response;
54+
JobServiceSettings jobServiceSettings = JobServiceSettings.newBuilder()
55+
.setEndpoint("us-central1-aiplatform.googleapis.com:443").build();
56+
LocationName parent = LocationName.of(project, location);
57+
String modelName = String.format("projects/%s/locations/%s/publishers/google/models/%s",
58+
project, location, textEmbeddingModel);
59+
60+
// Initialize client that will be used to send requests. This client only needs to be created
61+
// once, and can be reused for multiple requests.
62+
try (JobServiceClient client = JobServiceClient.create(jobServiceSettings)) {
63+
BatchPredictionJob batchPredictionJob =
64+
BatchPredictionJob.newBuilder()
65+
.setDisplayName("my embedding batch job " + System.currentTimeMillis())
66+
.setModel(modelName)
67+
.setInputConfig(
68+
BatchPredictionJob.InputConfig.newBuilder()
69+
.setGcsSource(GcsSource.newBuilder().addUris(inputUri).build())
70+
.setInstancesFormat("jsonl")
71+
.build())
72+
.setOutputConfig(
73+
BatchPredictionJob.OutputConfig.newBuilder()
74+
.setGcsDestination(GcsDestination.newBuilder()
75+
.setOutputUriPrefix(outputUri).build())
76+
.setPredictionsFormat("jsonl")
77+
.build())
78+
.build();
79+
80+
response = client.createBatchPredictionJob(parent, batchPredictionJob);
81+
82+
System.out.format("response: %s\n", response);
83+
System.out.format("\tName: %s\n", response.getName());
84+
}
85+
return response;
86+
}
87+
}
88+
// [END generativeaionvertexai_embedding_batch]
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
import com.google.cloud.aiplatform.v1.BatchPredictionJob;
20+
import com.google.cloud.storage.Bucket;
21+
import com.google.cloud.storage.BucketInfo;
22+
import com.google.cloud.storage.Storage;
23+
import com.google.cloud.storage.StorageOptions;
24+
import java.io.IOException;
25+
import java.util.UUID;
26+
import junit.framework.TestCase;
27+
import org.junit.AfterClass;
28+
import org.junit.BeforeClass;
29+
import org.junit.Test;
30+
import org.junit.jupiter.api.Assertions;
31+
import org.junit.runner.RunWith;
32+
import org.junit.runners.JUnit4;
33+
34+
@RunWith(JUnit4.class)
35+
public class EmbeddingBatchSampleTest extends TestCase {
36+
37+
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
38+
private static final String LOCATION = "us-central1";
39+
private static String BUCKET_NAME;
40+
private static final String GCS_SOURCE_URI =
41+
"gs://cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl";
42+
private static final String GCS_OUTPUT_URI =
43+
String.format("gs://%s/embedding_batch_output", BUCKET_NAME);
44+
private static final String MODEL_ID = "textembedding-gecko@003";
45+
static Storage storage;
46+
static Bucket bucket;
47+
48+
private static void requireEnvVar(String varName) {
49+
String errorMessage =
50+
String.format("Environment variable '%s' is required to perform these tests.", varName);
51+
assertNotNull(errorMessage, System.getenv(varName));
52+
}
53+
54+
@BeforeClass
55+
public static void checkRequirements() throws IOException {
56+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
57+
requireEnvVar("GOOGLE_CLOUD_PROJECT");
58+
BUCKET_NAME = "my-new-test-bucket" + UUID.randomUUID();
59+
60+
// Create a Google Cloud Storage bucket for UsageReports
61+
storage = StorageOptions.newBuilder().setProjectId(PROJECT_ID).build().getService();
62+
storage.create(BucketInfo.of(BUCKET_NAME));
63+
}
64+
65+
@AfterClass
66+
public static void afterClass() {
67+
// Delete the Google Cloud Storage bucket created for usage reports.
68+
storage = StorageOptions.newBuilder().setProjectId(PROJECT_ID).build().getService();
69+
bucket = storage.get(BUCKET_NAME);
70+
bucket.delete();
71+
}
72+
73+
@Test
74+
public void testEmbeddingBatchSample() throws IOException {
75+
76+
BatchPredictionJob batchPredictionJob =
77+
EmbeddingBatchSample.embeddingBatchSample(PROJECT_ID, LOCATION, GCS_SOURCE_URI,
78+
GCS_OUTPUT_URI, MODEL_ID);
79+
80+
Assertions.assertNotNull(batchPredictionJob);
81+
assertTrue(batchPredictionJob.getDisplayName().contains("my embedding batch job "));
82+
assertTrue(batchPredictionJob.getModel()
83+
.contains("publishers/google/models/textembedding-gecko"));
84+
}
85+
}

0 commit comments

Comments
 (0)
0