8000 feat(firebaseai): add imagen safetysetting attributes (#17707) · firebase/flutterfire@f7070f0 · GitHub
[go: up one dir, main page]

Skip to content

Commit f7070f0

Browse files
authored
feat(firebaseai): add imagen safetysetting attributes (#17707)
* add safety parameters * Add model test * add imagen editing test * Make safetySetting go through the toJson inside the class * fix imagen test
1 parent 73fffaf commit f7070f0

File tree

4 files changed

+299
-23
lines changed

4 files changed

+299
-23
lines changed

packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_api.dart

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ final class ImagenSafetySettings {
108108
final ImagenPersonFilterLevel? personFilterLevel;
109109

110110
// ignore: public_member_api_docs
111-
Object toJson() => {
111+
Map<String, Object?> toJson() => {
112112
if (safetyFilterLevel != null)
113113
'safetySetting': safetyFilterLevel!.toJson(),
114114
if (personFilterLevel != null)
@@ -194,7 +194,7 @@ final class ImagenGenerationConfig {
194194
// ignore: public_member_api_docs
195195
Map<String, dynamic> toJson() => {
196196
if (negativePrompt != null) 'negativePrompt': negativePrompt,
197-
if (numberOfImages != null) 'numberOfImages': numberOfImages,
197+
'sampleCount': numberOfImages ?? 1,
198198
if (aspectRatio != null) 'aspectRatio': aspectRatio!.toJson(),
199199
if (addWatermark != null) 'addWatermark': addWatermark,
200200
if (imageFormat != null) 'outputOptions': imageFormat!.toJson(),

packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_model.dart

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,17 @@ final class ImagenModel extends BaseApiClientModel {
6161
if (gcsUri != null) 'storageUri': gcsUri,
6262
'sampleCount': _generationConfig?.numberOfImages ?? 1,
6363
if (_generationConfig?.aspectRatio case final aspectRatio?)
64-
'aspectRatio': aspectRatio,
64+
'aspectRatio': aspectRatio.toJson(),
6565
if (_generationConfig?.negativePrompt case final negativePrompt?)
6666
'negativePrompt': negativePrompt,
6767
if (_generationConfig?.addWatermark case final addWatermark?)
6868
'addWatermark': addWatermark,
6969
if (_generationConfig?.imageFormat case final imageFormat?)
7070
'outputOption': imageFormat.toJson(),
71-
if (_safetySettings?.personFilterLevel case final personFilterLevel?)
72-
'personGeneration': personFilterLevel.toJson(),
73-
if (_safetySettings?.safetyFilterLevel case final safetyFilterLevel?)
74-
'safetySetting': safetyFilterLevel.toJson(),
71+
if (_safetySettings case final safetySettings?)
72+
...safetySettings.toJson(),
73+
'includeRaiReason': true,
74+
'includeSafetyAttributes': true,
7575
};
7676

7777
return {
@@ -170,10 +170,10 @@ final class ImagenModel extends BaseApiClientModel {
170170
'addWatermark': addWatermark,
171171
if (_generationConfig?.imageFormat case final imageFormat?)
172172
'outputOption': imageFormat.toJson(),
173-
if (_safetySettings?.personFilterLevel case final personFilterLevel?)
174-
'personGeneration': personFilterLevel.toJson(),
175-
if (_safetySettings?.safetyFilterLevel case final safetyFilterLevel?)
176-
'safetySetting': safetyFilterLevel.toJson(),
173+
if (_safetySettings case final safetySettings?)
174+
...safetySettings.toJson(),
175+
'includeRaiReason': true,
176+
'includeSafetyAttributes': true,
177177
};
178178

179179
return {
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import 'dart:typed_data';
16+
17+
import 'package:firebase_ai/firebase_ai.dart';
18+
import 'package:flutter_test/flutter_test.dart';
19+
20+
// Copied from imagen_model.dart for testing purposes as it is a private method.
21+
Map<String, Object?> generateImagenRequest(
22+
String prompt, {
23+
String? gcsUri,
24+
ImagenGenerationConfig? generationConfig,
25+
ImagenSafetySettings? safetySettings,
26+
}) {
27+
final parameters = <String, Object?>{
28+
if (gcsUri != null) 'storageUri': gcsUri,
29+
'sampleCount': generationConfig?.numberOfImages ?? 1,
30+
if (generationConfig?.aspectRatio case final aspectRatio?)
31+
'aspectRatio': aspectRatio.toJson(),
32+
if (generationConfig?.negativePrompt case final negativePrompt?)
33+
'negativePrompt': negativePrompt,
34+
if (generationConfig?.addWatermark case final addWatermark?)
35+
'addWatermark': addWatermark,
36+
if (generationConfig?.imageFormat case final imageFormat?)
37+
'outputOption': imageFormat.toJson(),
38+
if (safetySettings case final safetySettings?) ...safetySettings.toJson(),
39+
'includeRaiReason': true,
40+
'includeSafetyAttributes': true,
41+
};
42+
43+
return {
44+
'instances': [
45+
{'prompt': prompt}
46+
],
47+
'parameters': parameters,
48+
};
49+
}
50+
51+
// Copied from imagen_model.dart for testing
52+
Map<String, Object?> generateImagenEditRequest(
53+
List<ImagenReferenceImage> images,
54+
String prompt, {
55+
bool useVertexBackend = true, // Added for testing the throw
56+
ImagenEditingConfig? config,
57+
ImagenGenerationConfig? generationConfig,
58+
ImagenSafetySettings? safetySettings,
59+
}) {
60+
if (!useVertexBackend) {
61+
throw FirebaseAIException(
62+
'Image editing for Imagen is only supported on Vertex AI backend.');
63+
}
64+
final parameters = <String, Object?>{
65+
'sampleCount': generationConfig?.numberOfImages ?? 1,
66+
if (config?.editMode case final editMode?) 'editMode': editMode.toJson(),
67+
if (config?.editSteps case final editSteps?)
68+
'editConfig': {'baseSteps': editSteps},
69+
if (generationConfig?.negativePrompt case final negativePrompt?)
70+
'negativePrompt': negativePrompt,
71+
if (generationConfig?.addWatermark case final addWatermark?)
72+
'addWatermark': addWatermark,
73+
if (generationConfig?.imageFormat case final imageFormat?)
74+
'outputOption': imageFormat.toJson(),
75+
if (safetySettings case final safetySettings?) ...safetySettings.toJson(),
76+
'includeRaiReason': true,
77+
'includeSafetyAttributes': true,
78+
};
79+
80+
return {
81+
'parameters': parameters,
82+
'instances': [
83+
{
84+
'prompt': prompt,
85+
'referenceImages': images.asMap().entries.map((entry) {
86+
int index = entry.key;
87+
var image = entry.value;
88+
return image.toJson(referenceIdOverrideIfNull: index + images.length);
89+
}).toList(),
90+
}
91+
],
92+
};
93+
}
94+
95+
void main() {
96+
group('ImagenModel request generation', () {
97+
group('generateImagenRequest', () {
98+
test('creates a basic request with default parameters', () {
99+
final request = generateImagenRequest('a beautiful landscape');
100+
expect(request['instances'], [
101+
{'prompt': 'a beautiful landscape'}
102+
]);
103+
final params = request['parameters']! as Map<String, Object?>;
104+
expect(params['sampleCount'], 1);
105+
expect(params['includeRaiReason'], true);
106+
expect(params['includeSafetyAttributes'], true);
107+
expect(params.containsKey('storageUri'), isFalse);
108+
expect(params.containsKey('aspectRatio'), isFalse);
109+
expect(params.containsKey('negativePrompt'), isFalse);
110+
expect(params.containsKey('addWatermark'), isFalse);
111+
expect(params.containsKey('outputOption'), isFalse);
112+
expect(params.containsKey('personGeneration'), isFalse);
113+
expect(params.containsKey('safetySetting'), isFalse);
114+
});
115+
116+
test('includes all generation config parameters', () {
117+
final config = ImagenGenerationConfig(
118+
numberOfImages: 4,
119+
aspectRatio: ImagenAspectRatio.landscape16x9,
120+
negativePrompt: 'text, watermark',
121+
addWatermark: false,
122+
imageFormat: ImagenFormat.png(),
123+
);
124+
final request = generateImagenRequest('a futuristic city',
125+
generationConfig: config);
126+
final params = request['parameters']! as Map<String, Object?>;
127+
expect(params['sampleCount'], 4);
128+
expect(params['aspectRatio'], '16:9');
129+
expect(params['negativePrompt'], 'text, watermark');
130+
expect(params['addWatermark'], false);
131+
expect(params['outputOption'], {'mimeType': 'image/png'});
132+
expect(params['includeRaiReason'], true);
133+
expect(params['includeSafetyAttributes'], true);
134+
});
135+
136+
test('includes all safety settings parameters', () {
137+
final settings = ImagenSafetySettings(
138+
ImagenSafetyFilterLevel.blockNone,
139+
ImagenPersonFilterLevel.allowAdult,
140+
);
141+
final request =
142+
generateImagenRequest('a robot army', safetySettings: settings);
143+
final params = request['parameters']! as Map<String, Object?>;
144+
expect(params['personGeneration'], 'allow_adult');
145+
expect(params['safetySetting'], 'block_none');
146+
expect(params['includeRaiReason'], true);
147+
expect(params['includeSafetyAttributes'], true);
148+
});
149+
150+
test('includes gcsUri when provided', () {
151+
const uri = 'gs://my-test-bucket/image.png';
152+
final request = generateImagenRequest('a photo of a cat', gcsUri: uri);
153+
final params = request['parameters']! as Map<String, Object?>;
154+
expect(params['storageUri'], uri);
155+
expect(params['includeRaiReason'], true);
156+
expect(params['includeSafetyAttributes'], true);
157+
});
158+
159+
test('combines all parameters correctly', () {
160+
final config = ImagenGenerationConfig(
161+
numberOfImages: 2,
162+
negativePrompt: 'dark',
163+
);
164+
final settings = ImagenSafetySettings(
165+
ImagenSafetyFilterLevel.blockLowAndAbove,
166+
ImagenPersonFilterLevel.blockAll,
167+
);
168+
const uri = 'gs://my-test-bucket/output/';
169+
final request = generateImagenRequest(
170+
'a sunny beach',
171+
gcsUri: uri,
172+
generationConfig: config,
173+
safetySettings: settings,
174+
);
175+
176+
final params = request['parameters']! as Map<String, Object?>;
177+
expect(params['storageUri'], uri);
178+
expect(params['sampleCount'], 2);
179+
expect(params['negativePrompt'], 'dark');
180+
expect(params['safetySetting'], 'block_low_and_above');
181+
expect(params['includeRaiReason'], true);
182+
expect(params['includeSafetyAttributes'], true);
183+
expect(request['instances'], [
184+
{'prompt': 'a sunny beach'}
185+
]);
186+
});
187+
});
188+
189+
group('generateImagenEditRequest', () {
190+
late List<ImagenReferenceImage> referenceImages;
191+
192+
setUp(() {
193+
final dummyBytes = Uint8List.fromList([1, 2, 3]);
194+
final dummyInlineImage = ImagenInlineImage(
195+
bytesBase64Encoded: dummyBytes, mimeType: 'image/jpeg');
196+
referenceImages = [ImagenRawImage(image: dummyInlineImage)];
197+
});
198+
199+
test('creates a basic edit request', () {
200+
final request =
201+
generateImagenEditRequest(referenceImages, 'make it sunny');
202+
final params = request['parameters']! as Map<String, Object?>;
203+
expect(params['sampleCount'], 1);
204+
expect(params.containsKey('editMode'), isFalse);
205+
expect(params['includeRaiReason'], true);
206+
expect(params['includeSafetyAttributes'], true);
207+
208+
final instances = request['instances']! as List;
209+
expect(instances, hasLength(1));
210+
final instance = instances.first as Map<String, Object?>;
211+
expect(instance['prompt'], 'make it sunny');
212+
expect(instance['referenceImages'], isNotNull);
213+
});
214+
215+
test('does not include aspectRatio from generation config', () {
216+
final config = ImagenGenerationConfig(
217+
numberOfImages: 2, // This should be included as sampleCount
218+
aspectRatio: ImagenAspectRatio.square1x1, // This should be ignored
219+
);
220+
final request = generateImagenEditRequest(
221+
referenceImages,
222+
'add a rainbow',
223+
generationConfig: config,
224+
);
225+
final params = request['parameters']! as Map<String, Object?>;
226+
expect(params['sampleCount'], 2);
227+
expect(params.containsKey('aspectRatio'), isFalse,
228+
reason: 'aspectRatio is not a valid parameter for edit requests.');
229+
expect(params['includeRaiReason'], true);
230+
expect(params['includeSafetyAttributes'], true);
231+
});
232+
233+
test('includes other valid generation config values', () {
234+
final config = ImagenGenerationConfig(
235+
negativePrompt: 'rain',
236+
addWatermark: true,
237+
imageFormat: ImagenFormat.jpeg(),
238+
);
239+
final request = generateImagenEditRequest(
240+
referenceImages,
241+
'make it brighter',
242+
generationConfig: config,
243+
);
244+
final params = request['parameters']! as Map<String, Object?>;
245+
expect(params['negativePrompt'], 'rain');
246+
expect(params['addWatermark'], true);
247+
expect(params['outputOption'], {'mimeType': 'image/jpeg'});
248+
expect(params['includeRaiReason'], true);
249+
expect(params['includeSafetyAttributes'], true);
250+
});
251+
252+
test('includes editing config', () {
253+
final editConfig = ImagenEditingConfig(
254+
editMode: ImagenEditMode.inpaintInsertion,
255+
editSteps: 10,
256+
);
257+
final request = generateImagenEditRequest(
258+
referenceImages,
259+
'remove the background',
260+
config: editConfig,
261+
);
262+
final params = request['parameters']! as Map<String, Object?>;
263+
expect(params['editMode'], 'EDIT_MODE_INPAINT_INSERTION');
264+
expect(params['editConfig'], {'baseSteps': 10});
265+
expect(params['includeRaiReason'], true);
266+
expect(params['includeSafetyAttributes'], true);
267+
});
268+
269+
test('throws exception if not using Vertex backend', () {
270+
expect(
271+
() => generateImagenEditRequest(
272+
referenceImages,
273+
'a prompt',
274+
useVertexBackend: false,
275+
),
276+
throwsA(isA<FirebaseAIException>()),
277+
);
278+
});
279+
});
280+
});
281+
}

0 commit comments

Comments
 (0)
0