8000 feat(aiplatform): add batch code predict sample (#9519) · ludoch/java-docs-samples@a002e61 · GitHub
[go: up one dir, main page]

Skip to content

Commit a002e61

Browse files
feat(aiplatform): add batch code predict sample (GoogleCloudPlatform#9519)
* Implemented aiplatform_batch_code_predict, added test * Fixed test * Fixed comments and test * deleted redundant code * Fixed comment * Fixed version in pom * deleted redundant dependency * Fixed imports * Fixed variables in test * Fixed code according to comments * Added aiplatform_batch_text_predict sample, created test * Revert "Added aiplatform_batch_text_predict sample, created test" This reverts commit 603da7b. * Fixed comment * Disabled test. Sample should be migrated to Vertex AI * Changed tags * Fixed tag * Fixed test, added return response * Fixed HashMap for parameters, added return value, fixed tests * Fixed comment
1 parent 613df56 commit a002e61

File tree

4 files changed

+266
-78
lines changed

4 files changed

+266
-78
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
// [START generativeaionvertexai_batch_code_predict]
20+
21+
import com.google.cloud.aiplatform.v1.BatchPredictionJob;
22+
import com.google.cloud.aiplatform.v1.GcsDestination;
23+
import com.google.cloud.aiplatform.v1.GcsSource;
24+
import com.google.cloud.aiplatform.v1.JobServiceClient;
25+
import com.google.cloud.aiplatform.v1.JobServiceSettings;
26+
import com.google.cloud.aiplatform.v1.LocationName;
27+
import com.google.gson.Gson;
28+
import com.google.protobuf.InvalidProtocolBufferException;
29+
import com.google.protobuf.Value;
30+
import com.google.protobuf.util.JsonFormat;
31+
import java.io.IOException;
32+
import java.util.HashMap;
33+
import java.util.Map;
34+
35+
public class BatchCodePredictionSample {
36+
37+
public static void main(String[] args) throws IOException, InterruptedException {
38+
// TODO(developer): Replace these variables before running the sample.
39+
String project = "YOUR_PROJECT_ID";
40+
String location = "us-central1";
41+
// inputUri: URI of the input dataset.
42+
// Could be a BigQuery table or a Google Cloud Storage file.
43+
// E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
44+
String inputUri = "gs://cloud-samples-data/batch/prompt_for_batch_code_predict.jsonl";
45+
// outputUri: URI where the output will be stored.
46+
// Could be a BigQuery table or a Google Cloud Storage file.
47+
// E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
48+
String outputUri = "gs://YOUR_BUCKET/batch_code_predict_output";
49+
String codeModel = "code-bison";
50+
51+
batchCodePredictionSample(project, location, inputUri, outputUri, codeModel);
52+
}
53+
54+
// Perform batch code prediction using a pre-trained code generation model.
55+
// Example of using Google Cloud Storage bucket as the input and output data source
56+
public static BatchPredictionJob batchCodePredictionSample(
57+
String project, String location, String inputUri, String outputUri, String codeModel)
58+
throws IOException {
59+
BatchPredictionJob response;
60+
JobServiceSettings jobServiceSettings = JobServiceSettings.newBuilder()
61+
.setEndpoint("us-central1-aiplatform.googleapis.com:443").build();
62+
LocationName parent = LocationName.of(project, location);
63+
String modelName = String.format(
64+
"projects/%s/locations/%s/publishers/google/models/%s", project, location, codeModel);
65+
// Construct your modelParameters
66+
Map<String, String> modelParameters = new HashMap<>();
67+
modelParameters.put("maxOutputTokens", "200");
68+
modelParameters.put("temperature", "0.2");
69+
modelParameters.put("topP", "0.95");
70+
modelParameters.put("topK", "40");
71+
Value parameterValue = mapToValue(modelParameters);
72+
73+
// Initialize client that will be used to send requests. This client only needs to be created
74+
// once, and can be reused for multiple requests.
75+
try (JobServiceClient client = JobServiceClient.create(jobServiceSettings)) {
76+
BatchPredictionJob batchPredictionJob =
77+
BatchPredictionJob.newBuilder()
78+
.setDisplayName("my batch code prediction job " + System.currentTimeMillis())
79+
.setModel(modelName)
80+
.setInputConfig(
81+
BatchPredictionJob.InputConfig.newBuilder()
82+
.setGcsSource(GcsSource.newBuilder().addUris(inputUri).build())
83+
.setInstancesFormat("jsonl")
84+
.build())
85+
.setOutputConfig(
86+
BatchPredictionJob.OutputConfig.newBuilder()
87+
.setGcsDestination(GcsDestination.newBuilder()
88+
.setOutputUriPrefix(outputUri).build())
89+
.setPredictionsFormat("jsonl")
90+
.build())
91+
.setModelParameters(parameterValue)
92+
.build();
93+
94+
response = client.createBatchPredictionJob(parent, batchPredictionJob);
95+
96+
System.out.format("response: %s\n", response);
97+
System.out.format("\tName: %s\n", response.getName());
98+
}
99+
return response;
100+
}
101+
102+
private static Value mapToValue(Map<String, String> map) throws InvalidProtocolBufferException {
103+
Gson gson = new Gson();
104+
String json = gson.toJson(map);
105+
Value.Builder builder = Value.newBuilder();
106+
JsonFormat.parser().merge(json, builder);
107+
return builder.build();
108+
}
109+
}
110+
// [END generativeaionvertexai_batch_code_predict]

aiplatform/src/main/java/aiplatform/BatchTextPredictionSample.java

Lines changed: 53 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at
77
*
8-
* http://www.apache.org/licenses/LICENSE-2.0
8+
* http://www.apache.org/licenses/LICENSE-2.0
99
*
1010
* Unless required by applicable law or agreed to in writing, software
1111
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -16,98 +16,100 @@
1616

1717
package aiplatform;
1818

19-
// [START aiplatform_batch_text_predict]
19+
// [START generativeaionvertexai_batch_text_predict]
2020

2121
import com.google.cloud.aiplatform.v1.BatchPredictionJob;
22-
import com.google.cloud.aiplatform.v1.BatchPredictionJob.InputConfig;
23-
import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputConfig;
2422
import com.google.cloud.aiplatform.v1.GcsDestination;
2523
import com.google.cloud.aiplatform.v1.GcsSource;
2624
import com.google.cloud.aiplatform.v1.JobServiceClient;
2725
import com.google.cloud.aiplatform.v1.JobServiceSettings;
28-
import com.google.cloud.aiplatform.v1.LocationName;
26+
import com.google.gson.Gson;
2927
import com.google.protobuf.InvalidProtocolBufferException;
3028
import com.google.protobuf.Value;
3129
import com.google.protobuf.util.JsonFormat;
3230
import java.io.IOException;
33-
31+
import java.util.HashMap;
32+
import java.util.Map;
33+
import java.util.concurrent.ExecutionException;
34+
import java.util.concurrent.TimeoutException;
3435

3536
public class BatchTextPredictionSample {
3637

37-
public static void main(String[] args) throws IOException {
38-
// TODO (Developer): Replace the input_uri and output_uri with your own GCS paths
38+
public static void main(String[] args)
39+
throws IOException, InterruptedException, ExecutionException, TimeoutException {
40+
// TODO(developer): Replace these variables before running the sample.
3941
String project = "YOUR_PROJECT_ID";
4042
String location = "us-central1";
41-
// inputUri (str, optional): URI of the input dataset.
43+
// inputUri: URI of the input dataset.
4244
// Could be a BigQuery table or a Google Cloud Storage file.
4345
// E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
4446
String inputUri = "gs://cloud-samples-data/batch/prompt_for_batch_text_predict.jsonl";
45-
// outputUri (str, optional): URI where the output will be stored.
47+
// outputUri: URI where the output will be stored.
4648
// Could be a BigQuery table or a Google Cloud Storage file.
4749
// E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
48-
String outputUri = "gs://batch-bucket-testing/batch_text_predict_output";
49-
String codeModel = "text-bison";
50+
String outputUri = "gs://YOUR_BUCKET/batch_text_predict_output";
51+
String textModel = "text-bison";
5052

51-
batchTextPrediction(project, location, inputUri, outputUri, codeModel);
53+
batchTextPrediction(project, inputUri, outputUri, textModel, location);
5254
}
5355

5456
// Perform batch text prediction using a pre-trained text generation model.
5557
// Example of using Google Cloud Storage bucket as the input and output data source
56-
public static void batchTextPrediction(
57-
String project, String location, String inputUri,
58-
String outputUri, String codeModel) throws IOException {
59-
String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
60-
JobServiceSettings jobServiceSettings =
61-
JobServiceSettings.newBuilder().setEndpoint(endpoint).build();
62-
// Construct your modelParameters
63-
String parameters =
64-
"{\n" + " \"temperature\": 0.2,\n" + " \"maxOutputTokens\": 200\n" + "}";
65-
Value parameterValue = stringToValue(parameters);
58+
static BatchPredictionJob batchTextPrediction(
59+
String projectId, String inputUri, String outputUri, String textModel, String location)
60+
throws IOException {
61+
BatchPredictionJob response;
62+
JobServiceSettings jobServiceSettings = JobServiceSettings.newBuilder()
63+
.setEndpoint("us-central1-aiplatform.googleapis.com:443").build();
64+
String parent = String.format("projects/%s/locations/%s", projectId, location);
6665
String modelName = String.format(
67-
"projects/%s/locations/%s/publishers/google/models/%s", project, location, codeModel);
66+
"projects/%s/locations/%s/publishers/google/models/%s", projectId, location, textModel);
67+
// Construct model parameters
68+
Map<String, String> modelParameters = new HashMap<>();
69+
modelParameters.put("maxOutputTokens", "200");
70+
modelParameters.put("temperature", "0.2");
71+
modelParameters.put("topP", "0.95");
72+
modelParameters.put("topK", "40");
73+
Value parameterValue = mapToValue(modelParameters);
6874

6975
// Initialize client that will be used to send requests. This client only needs to be created
7076
// once, and can be reused for multiple requests.
7177
try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) {
7278

73-
GcsSource.Builder gcsSource = GcsSource.newBuilder();
74-
gcsSource.addUris(inputUri);
75-
InputConfig inputConfig =
76-
InputConfig.newBuilder()
77-
.setGcsSource(gcsSource)
78-
.setInstancesFormat("jsonl")
79-
.build();
80-
81-
GcsDestination.Builder gcsDestination = GcsDestination.newBuilder();
82-
gcsDestination.setOutputUriPrefix(outputUri);
F438 83-
OutputConfig outputConfig =
84-
OutputConfig.newBuilder()
85-
.setGcsDestination(gcsDestination)
86-
.setPredictionsFormat("jsonl")
87-
.build();
88-
89-
BatchPredictionJob.Builder batchPredictionJob =
79+
BatchPredictionJob batchPredictionJob =
9080
BatchPredictionJob.newBuilder()
9181
.setDisplayName("my batch text prediction job " + System.currentTimeMillis())
9282
.setModel(modelName)
93-
.setInputConfig(inputConfig)
94-
.setOutputConfig(outputConfig)
95-
.setModelParameters(parameterValue);
83+
.setInputConfig(
84+
BatchPredictionJob.InputConfig.newBuilder()
85+
.setGcsSource(GcsSource.newBuilder().addUris(inputUri).build())
86+
.setInstancesFormat("jsonl")
87+
.build())
88+
.setOutputConfig(
89+
BatchPredictionJob.OutputConfig.newBuilder()
90+
.setGcsDestination(GcsDestination.newBuilder()
91+
.setOutputUriPrefix(outputUri).build())
92+
.setPredictionsFormat("jsonl")
93+
.build())
94+
.setModelParameters(parameterValue)
95+
.build();
9696

97-
LocationName parent = LocationName.of(project, location);
98-
BatchPredictionJob response =
99-
jobServiceClient.createBatchPredictionJob(parent, batchPredictionJob.build());
97+
// Create the batch prediction job
98+
response =
99+
jobServiceClient.createBatchPredictionJob(parent, batchPredictionJob);
100100

101101
System.out.format("response: %s\n", response);
102102
System.out.format("\tName: %s\n", response.getName());
103103
}
104+
return response;
104105
}
105106

106-
// Convert a Json string to a protobuf.Value
107-
static Value stringToValue(String value) throws InvalidProtocolBufferException {
107+
private static Value mapToValue(Map<String, String> map) throws InvalidProtocolBufferException {
108+
Gson gson = new Gson();
109+
String json = gson.toJson(map);
108110
Value.Builder builder = Value.newBuilder();
109-
JsonFormat.parser().merge(value, builder);
111+
JsonFormat.parser().merge(json, builder);
110112
return builder.build();
111113
}
112114
}
113-
// [END aiplatform_batch_text_predict]
115+
// [END generativeaionvertexai_batch_text_predict]
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
import static junit.framework.TestCase.assertNotNull;
20+
import static org.junit.Assert.assertTrue;
21+
22+
import com.google.cloud.aiplatform.v1.BatchPredictionJob;
23+
import com.google.cloud.storage.Bucket;
24+
import com.google.cloud.storage.BucketInfo;
25+
import com.google.cloud.storage.Storage;
26+
import com.google.cloud.storage.StorageOptions;
27+
import java.io.IOException;
28+
import java.util.UUID;
29+
import org.junit.AfterClass;
30+
import org.junit.BeforeClass;
31+
import org.junit.Test;
32+
import org.junit.jupiter.api.Assertions;
33+
import org.junit.runner.RunWith;
34+
import org.junit.runners.JUnit4;
35+
36+
@RunWith(JUnit4.class)
37+
public class BatchCodePredictionSampleTest {
38+
39+
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
40+
private static final String LOCATION = "us-central1";
41+
private static String BUCKET_NAME;
42+
private static final String GCS_SOURCE_URI =
43+
"gs://cloud-samples-data/batch/prompt_for_batch_code_predict.jsonl";
44+
private static final String GCS_DESTINATION_OUTPUT_PREFIX =
45+
String.format("gs://%s/batch-code-predict", BUCKET_NAME);
46+
private static final String MODEL_ID = "code-bison";
47+
static Storage storage;
48+
static Bucket bucket;
49+
50+
private static void requireEnvVar(String varName) {
51+
String errorMessage =
52+
String.format("Environment variable '%s' is required to perform these tests.", varName);
53+
assertNotNull(errorMessage, System.getenv(varName));
54+
}
55+
56+
@BeforeClass
57+
public static void checkRequirements() throws IOException {
58+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
59+
requireEnvVar("GOOGLE_CLOUD_PROJECT");
60+
BUCKET_NAME = "my-new-test-bucket" + UUID.randomUUID();
61+
62+
// Create a Google Cloud Storage bucket for UsageReports
63+
storage = StorageOptions.newBuilder().setProjectId(PROJECT_ID).build().getService();
64+
storage.create(BucketInfo.of(BUCKET_NAME));
65+
}
66+
67+
@AfterClass
68+
public static void afterClass() {
69+
// Delete the Google Cloud Storage bucket created for usage reports.
70+
storage = StorageOptions.newBuilder().setProjectId(PROJECT_ID).build().getService();
71+
bucket = storage.get(BUCKET_NAME);
72+
bucket.delete();
73+
}
74+
75+
@Test
76+
public void testBatchCodePredictionSample() throws IOException {
77+
78+
BatchPredictionJob batchPredictionJob =
79+
BatchCodePredictionSample.batchCodePredictionSample(PROJECT_ID, LOCATION, GCS_SOURCE_URI,
80+
GCS_DESTINATION_OUTPUT_PREFIX, MODEL_ID);
81+
82+
Assertions.assertNotNull(batchPredictionJob);
83+
assertTrue(batchPredictionJob.getDisplayName().contains("my batch code prediction job"));
84+
assertTrue(batchPredictionJob.getModel().contains("publishers/google/models/code-bison"));
85+
}
86+
}

0 commit comments

Comments
 (0)
0