8000 feat(aiplatform): add batch text predict sample (#9520) · suztomo/java-docs-samples@5cb4201 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5cb4201

Browse files
feat(aiplatform): add batch text predict sample (GoogleCloudPlatform#9520)
* Implemented aiplatform_batch_text_predict sample, created test * Fixed imports * Fixed code according to comments for aiplatform_batch_code_predict sample * Fixed comment * Disabled test. Sample should be migrated to Vertex AI * Fixed po file
1 parent ba36764 commit 5cb4201

File tree

4 files changed

+216
-0
lines changed

4 files changed

+216
-0
lines changed

aiplatform/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,5 +89,11 @@
8989
<version>1.7.1</version>
9090
<scope>test</scope>
9191
</dependency>
92+
<dependency>
93+
<groupId>org.junit.jupiter</groupId>
94+
<artifactId>junit-jupiter</artifactId>
95+
<version>RELEASE</version>
96+
<scope>test</scope>
97+
</dependency>
9298
</dependencies>
9399
</project>
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
// [START aiplatform_batch_text_predict]
20+
21+
import com.google.cloud.aiplatform.v1.BatchPredictionJob;
22+
import com.google.cloud.aiplatform.v1.BatchPredictionJob.InputConfig;
23+
import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputConfig;
24+
import com.google.cloud.aiplatform.v1.GcsDestination;
25+
import com.google.cloud.aiplatform.v1.GcsSource;
26+
import com.google.cloud.aiplatform.v1.JobServiceClient;
27+
import com.google.cloud.aiplatform.v1.JobServiceSettings;
28+
import com.google.cloud.aiplatform.v1.LocationName;
29+
import com.google.protobuf.InvalidProtocolBufferException;
30+
import com.google.protobuf.Value;
31+
import com.google.protobuf.util.JsonFormat;
32+
import java.io.IOException;
33+
34+
35+
public class BatchTextPredictionSample {
36+
37+
public static void main(String[] args) throws IOException {
38+
// TODO (Developer): Replace the input_uri and output_uri with your own GCS paths
39+
String project = "YOUR_PROJECT_ID";
40+
String location = "us-central1";
41+
// inputUri (str, optional): URI of the input dataset.
42+
// Could be a BigQuery table or a Google Cloud Storage file.
43+
// E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
44+
String inputUri = "gs://cloud-samples-data/batch/prompt_for_batch_text_predict.jsonl";
45+
// outputUri (str, optional): URI where the output will be stored.
46+
// Could be a BigQuery table or a Google Cloud Storage file.
47+
// E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
48+
String outputUri = "gs://batch-bucket-testing/batch_text_predict_output";
49+
String codeModel = "text-bison";
50+
51+
batchTextPrediction(project, location, inputUri, outputUri, codeModel);
52+
}
53+
54+
// Perform batch text prediction using a pre-trained text generation model.
55+
// Example of using Google Cloud Storage bucket as the input and output data source
56+
public static void batchTextPrediction(
57+
String project, String location, String inputUri,
58+
String outputUri, String codeModel) throws IOException {
59+
String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
60+
JobServiceSettings jobServiceSettings =
61+
JobServiceSettings.newBuilder().setEndpoint(endpoint).build();
62+
// Construct your modelParameters
63+
String parameters =
64+
"{\n" + " \"temperature\": 0.2,\n" + " \"maxOutputTokens\": 200\n" + "}";
65+
Value parameterValue = stringToValue(parameters);
66+
String modelName = String.format(
67+
"projects/%s/locations/%s/publishers/google/models/%s", project, location, codeModel);
68+
69+
// Initialize client that will be used to send requests. This client only needs to be created
70+
// once, and can be reused for multiple requests.
71+
try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) {
72+
73+
GcsSource.Builder gcsSource = GcsSource.newBuilder();
74+
gcsSource.addUris(inputUri);
75+
InputConfig inputConfig =
76+
InputConfig.newBuilder()
77+
.setGcsSource(gcsSource)
78+
.setInstancesFormat("jsonl")
79+
.build();
80+
81+
GcsDestination.Builder gcsDestination = GcsDestination.newBuilder();
82+
gcsDestination.setOutputUriPrefix(outputUri);
83+
OutputConfig outputConfig =
84+
OutputConfig.newBuilder()
85+
.setGcsDestination(gcsDestination)
86+
.setPredictionsFormat("jsonl")
87+
.build();
88+
89+
BatchPredictionJob.Builder batchPredictionJob =
90+
BatchPredictionJob.newBuilder()
91+
.setDisplayName("my batch text prediction job " + System.currentTimeMillis())
92+
.setModel(modelName)
93+
.setInputConfig(inputConfig)
94+
.setOutputConfig(outputConfig)
95+
.setModelParameters(parameterValue);
96+
97+
LocationName parent = LocationName.of(project, location);
98+
BatchPredictionJob response =
99+
jobServiceClient.createBatchPredictionJob(parent, batchPredictionJob.build());
100+
101+
System.out.format("response: %s\n", response);
102+
System.out.format("\tName: %s\n", response.getName());
103+
}
104+
}
105+
106+
// Convert a Json string to a protobuf.Value
107+
static Value stringToValue(String value) throws InvalidProtocolBufferException {
108+
Value.Builder builder = Value.newBuilder();
109+
JsonFormat.parser().merge(value, builder);
110+
return builder.build();
111+
}
112+
}
113+
// [END aiplatform_batch_text_predict]
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static org.junit.Assert.assertNotNull;
21+
22+
import com.google.cloud.storage.Bucket;
23+
import com.google.cloud.storage.BucketInfo;
24+
import com.google.cloud.storage.Storage;
25+
import com.google.cloud.storage.StorageOptions;
26+
import java.io.ByteArrayOutputStream;
27+
import java.io.IOException;
28+
import java.io.PrintStream;
29+
imp 10000 ort java.util.UUID;
30+
import org.junit.After;
31+
import org.junit.AfterClass;
32+
import org.junit.Before;
33+
import org.junit.BeforeClass;
34+
import org.junit.Test;
35+
import org.junit.runner.RunWith;
36+
import org.junit.runners.JUnit4;
37+
38+
@RunWith(JUnit4.class)
39+
public class BatchTextPredictionSampleTest {
40+
41+
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
42+
private static final String LOCATION = "us-central1";
43+
private static String BUCKET_NAME;
44+
private static final String GCS_SOURCE_URI =
45+
"gs://cloud-samples-data/batch/prompt_for_batch_code_predict.jsonl";
46+
private static final String GCS_DESTINATION_OUTPUT_PREFIX =
47+
String.format("gs://%s/ucaip-test-output/", BUCKET_NAME);
48+
private static final String MODEL_ID = "text-bison";
49+
private ByteArrayOutputStream stdOut;
50+
51+
private static void requireEnvVar(String varName) {
52+
String errorMessage =
53+
String.format("Environment variable '%s' is required to perform these tests.", varName);
54+
assertNotNull(errorMessage, System.getenv(varName));
55+
}
56+
57+
@BeforeClass
58+
public static void checkRequirements() throws IOException {
59+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
60+
requireEnvVar("GOOGLE_CLOUD_PROJECT");
61+
BUCKET_NAME = "my-new-test-bucket" + UUID.randomUUID();
62+
63+
// Create a Google Cloud Storage bucket for UsageReports
64+
Storage storage = StorageOptions.newBuilder().setProjectId(PROJECT_ID).build().getService();
65+
storage.create(BucketInfo.of(BUCKET_NAME));
66+
}
67+
68+
@AfterClass
69+
public static void afterClass() {
70+
// Delete the Google Cloud Storage bucket created for usage reports.
71+
Storage storage = StorageOptions.newBuilder().setProjectId(PROJECT_ID).build().getService();
72+
Bucket bucket = storage.get(BUCKET_NAME);
73+
bucket.delete();
74+
}
75+
76+
@Before
77+
public void beforeEach() {
78+
stdOut = new ByteArrayOutputStream();
79+
System.setOut(new PrintStream(stdOut));
80+
}
81+
82+
@After
83+
public void afterEach() {
84+
stdOut = null;
85+
System.setOut(null);
86+
}
87+
88+
@Test
89+
public void testBatchTextPredictionSample() throws IOException {
90+
BatchTextPredictionSample.batchTextPrediction(PROJECT_ID, LOCATION, GCS_SOURCE_URI,
91+
GCS_DESTINATION_OUTPUT_PREFIX, MODEL_ID);
92+
assertThat(stdOut.toString()).contains("publishers/google/models/text-bison");
93+
assertThat(stdOut.toString()).contains("my batch text prediction job");
94+
}
95+
}

aiplatform/src/test/java/aiplatform/CreateTrainingPipelineTextClassificationSampleTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.junit.BeforeClass;
3333
import org.junit.Rule;
3434
import org.junit.Test;
35+
import org.junit.jupiter.api.Disabled;
3536
import org.junit.runner.RunWith;
3637
import org.junit.runners.JUnit4;
3738

@@ -88,6 +89,7 @@ public void tearDown()
8889
System.setOut(originalPrintStream);
8990
}
9091

92+
@Disabled // Sample should be migrated to Vertex AI https://softserve-jirasw.atlassian.net/browse/DSAM-173
9193
@Test
9294
public void testCreateTrainingPipelineTextClassificationSample() throws IOException {
9395
// Act

0 commit comments

Comments
 (0)
0