feat(aiplatform): add mask-free editing Imagen code sample and test (… · suztomo/java-docs-samples@a6f99be · GitHub
[go: up one dir, main page]

Skip to content

Commit a6f99be

Browse files
authored
feat(aiplatform): add mask-free editing Imagen code sample and test (GoogleCloudPlatform#9442)
* feat(aiplatform): add mask-free editing Imagen code sample and test
* address feedback: add proper assert statements and remove unnecessary sysout streams
* break early in tests
1 parent 937aaa2 commit a6f99be

File tree

5 files changed

+195
-22
lines changed

5 files changed

+195
-22
lines changed

aiplatform/resources/cat.png

940 KB
Loading
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.imagen;
18+
19+
// [START generativeaionvertexai_imagen_edit_image_mask_free]
20+
21+
import com.google.api.gax.rpc.ApiException;
22+
import com.google.cloud.aiplatform.v1.EndpointName;
23+
import com.google.cloud.aiplatform.v1.PredictResponse;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
25+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
26+
import com.google.gson.Gson;
27+
import com.google.protobuf.InvalidProtocolBufferException;
28+
import com.google.protobuf.Value;
29+
import com.google.protobuf.util.JsonFormat;
30+
import java.io.IOException;
31+
import java.nio.charset.StandardCharsets;
32+
import java.nio.file.Files;
33+
import java.nio.file.Path;
34+
import java.nio.file.Paths;
35+
import java.util.Base64;
36+
import java.util.Collections;
37+
import java.util.HashMap;
38+
import java.util.Map;
39+
40+
public class EditImageMaskFreeSample {
41+
42+
public static void main(String[] args) throws IOException {
43+
// TODO(developer): Replace these variables before running the sample.
44+
String projectId = "my-project-id";
45+
String location = "us-central1";
46+
String inputPath = "/path/to/my-input.png";
47+
String prompt = ""; // The text prompt describing what you want to see.
48+
49+
editImageMaskFree(projectId, location, inputPath, prompt);
50+
}
51+
52+
// Edit an image without using a mask. The edit is applied to the entire image and is saved to a
53+
// new file.
54+
public static PredictResponse editImageMaskFree(
55+
String projectId, String location, String inputPath, String prompt)
56+
throws ApiException, IOException {
57+
final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
58+
PredictionServiceSettings predictionServiceSettings =
59+
PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
60+
61+
// Initialize client that will be used to send requests. This client only needs to be created
62+
// once, and can be reused for multiple requests.
63+
try (PredictionServiceClient predictionServiceClient =
64+
PredictionServiceClient.create(predictionServiceSettings)) {
65+
66+
final EndpointName endpointName =
67+
EndpointName.ofProjectLocationPublisherModelName(
68+
projectId, location, "google", "imagegeneration@002");
69+
70+
// Convert the image to Base64.
71+
byte[] imageData = Base64.getEncoder().encode(Files.readAllBytes(Paths.get(inputPath)));
72+
String image = new String(imageData, StandardCharsets.UTF_8);
73+
Map<String, String> imageMap = new HashMap<>();
74+
imageMap.put("bytesBase64Encoded", image);
75+
76+
Map<String, Object> instancesMap = new HashMap<>();
77+
instancesMap.put("prompt", prompt);
78+
instancesMap.put("image", imageMap);
79+
Value instances = mapToValue(instancesMap);
80+
81+
Map<String, Object> paramsMap = new HashMap<>();
82+
// Optional parameters
83+
paramsMap.put("seed", 1);
84+
// Controls the strength of the prompt.
85+
// 0-9 (low strength), 10-20 (medium strength), 21+ (high strength)
86+
paramsMap.put("guidanceScale", 21);
87+
paramsMap.put("sampleCount", 1);
88+
Value parameters = mapToValue(paramsMap);
89+
90+
PredictResponse predictResponse =
91+
predictionServiceClient.predict(
92+
endpointName, Collections.singletonList(instances), parameters);
93+
94+
for (Value prediction : predictResponse.getPredictionsList()) {
95+
Map<String, Value> fieldsMap = prediction.getStructValue().getFieldsMap();
96+
if (fieldsMap.containsKey("bytesBase64Encoded")) {
97+
String bytesBase64Encoded = fieldsMap.get("bytesBase64Encoded").getStringValue();
98+
Path tmpPath = Files.createTempFile("imagen-", ".png");
99+
Files.write(tmpPath, Base64.getDecoder().decode(bytesBase64Encoded));
100+
System.out.format("Image file written to: %s\n", tmpPath.toUri());
101+
}
102+
}
103+
return predictResponse;
104+
}
105+
}
106+
107+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
108+
Gson gson = new Gson();
109+
String json = gson.toJson(map);
110+
Value.Builder builder = Value.newBuilder();
111+
JsonFormat.parser().merge(json, builder);
112+
return builder.build();
113+
}
114+
}
115+
116+
// [END generativeaionvertexai_imagen_edit_image_mask_free]

aiplatform/src/main/java/aiplatform/imagen/GenerateImageSample.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ public static PredictResponse generateImage(String projectId, String location, S
7070
paramsMap.put("sampleCount", 1);
7171
// You can't use a seed value and watermark at the same time.
7272
// paramsMap.put("seed", 100);
73-
// paramsMap.put("addWatermark", true);
73+
// paramsMap.put("addWatermark", false);
7474
paramsMap.put("aspectRatio", "1:1");
7575
paramsMap.put("safetyFilterLevel", "block_some");
7676
paramsMap.put("personGeneration", "allow_adult");
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.imagen;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static junit.framework.TestCase.assertNotNull;
21+
22+
import com.google.cloud.aiplatform.v1.PredictResponse;
23+
import com.google.protobuf.Value;
24+
import java.io.IOException;
25+
import java.util.Map;
26+
import org.junit.BeforeClass;
27+
import org.junit.Test;
28+
import org.junit.runner.RunWith;
29+
import org.junit.runners.JUnit4;
30+
31+
@RunWith(JUnit4.class)
32+
public class EditImageMaskFreeSampleTest {
33+
34+
private static final String PROJECT = System.getenv("GOOGLE_CLOUD_PROJECT");
35+
private static final String INPUT_FILE = "resources/cat.png";
36+
private static final String PROMPT = "a dog";
37+
38+
private static void requireEnvVar(String varName) {
39+
String errorMessage =
40+
String.format("Environment variable '%s' is required to perform these tests.", varName);
41+
assertNotNull(errorMessage, System.getenv(varName));
42+
}
43+
44+
@BeforeClass
45+
public static void checkRequirements() {
46+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
47+
requireEnvVar("GOOGLE_CLOUD_PROJECT");
48+
}
49+
50+
@Test
51+
public void testEditImageMaskFreeSample() throws IOException {
52+
PredictResponse response =
53+
EditImageMaskFreeSample.editImageMaskFree(PROJECT, "us-central1", INPUT_FILE, PROMPT);
54+
assertThat(response).isNotNull();
55+
56+
Boolean imageBytes = false;
57+
for (Value prediction : response.getPredictionsList()) {
58+
Map<String, Value> fieldsMap = prediction.getStructValue().getFieldsMap();
59+
if (fieldsMap.containsKey("bytesBase64Encoded")) {
60+
imageBytes = true;
61+
break;
62+
}
63+
}
64+
assertThat(imageBytes).isTrue();
65+
}
66+
}

aiplatform/src/test/java/aiplatform/imagen/GenerateImageSampleTest.java

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,9 @@
2020
import static junit.framework.TestCase.assertNotNull;
2121

2222
import com.google.cloud.aiplatform.v1.PredictResponse;
23-
import java.io.ByteArrayOutputStream;
23+
import com.google.protobuf.Value;
2424
import java.io.IOException;
25-
import java.io.PrintStream;
26-
import org.junit.After;
27-
import org.junit.Before;
25+
import java.util.Map;
2826
import org.junit.BeforeClass;
2927
import org.junit.Test;
3028
import org.junit.runner.RunWith;
@@ -35,9 +33,6 @@ public class GenerateImageSampleTest {
3533

3634
private static final String PROJECT = System.getenv("GOOGLE_CLOUD_PROJECT");
3735
private static final String PROMPT = "a dog reading a newspaper";
38-
private ByteArrayOutputStream bout;
39-
private PrintStream out;
40-
private PrintStream originalPrintStream;
4136

4237
private static void requireEnvVar(String varName) {
4338
String errorMessage =
@@ -51,23 +46,19 @@ public static void checkRequirements() {
5146
requireEnvVar("GOOGLE_CLOUD_PROJECT");
5247
}
5348

54-
@Before
55-
public void setUp() {
56-
bout = new ByteArrayOutputStream();
57-
out = new PrintStream(bout);
58-
originalPrintStream = System.out;
59-
System.setOut(out);
60-
}
61-
62-
@After
63-
public void tearDown() {
64-
System.out.flush();
65-
System.setOut(originalPrintStream);
66-
}
67-
6849
@Test
6950
public void testGenerateImageSample() throws IOException {
7051
PredictResponse response = GenerateImageSample.generateImage(PROJECT, "us-central1", PROMPT);
7152
assertThat(response).isNotNull();
53+
54+
Boolean imageBytes = false;
55+
for (Value prediction : response.getPredictionsList()) {
56+
Map<String, Value> fieldsMap = prediction.getStructValue().getFieldsMap();
57+
if (fieldsMap.containsKey("bytesBase64Encoded")) {
58+
imageBytes = true;
59+
break;
60+
}
61+
}
62+
assertThat(imageBytes).isTrue();
7263
}
7364
}

0 commit comments

Comments
 (0)
0