19
19
20
20
import numpy as np
21
21
22
- from transformers import BridgeTowerConfig , is_torch_available , is_vision_available
22
+ from transformers import (
23
+ BridgeTowerConfig ,
24
+ BridgeTowerTextConfig ,
25
+ BridgeTowerVisionConfig ,
26
+ is_torch_available ,
27
+ is_vision_available ,
28
+ )
23
29
from transformers .testing_utils import require_torch , require_vision , slow , torch_device
24
30
from transformers .utils import cached_property
25
31
54
60
from transformers import BridgeTowerProcessor
55
61
56
62
57
- class BridgeTowerModelTester :
63
+ class BridgeTowerTextModelTester :
58
64
def __init__ (
59
65
self ,
60
66
parent ,
61
- share_cross_modal_transformer_layers = True ,
62
- drop_rate = 0.1 ,
63
- head_hidden_scale = 2 ,
64
67
hidden_act = "gelu" ,
65
- hidden_size = 768 ,
68
+ hidden_size = 128 ,
66
69
initializer_factor = 1 ,
67
- is_encoder_decoder = False ,
68
70
layer_norm_eps = 1e-05 ,
69
- share_link_tower_layers = False ,
70
- link_tower_type = "add" ,
71
- num_attention_heads = 12 ,
72
- num_hidden_layers = 6 ,
71
+ num_attention_heads = 4 ,
72
+ num_hidden_layers = 2 ,
73
+ intermediate_size = 256 ,
73
74
tie_word_embeddings = False ,
74
- init_layernorm_from_vision_encoder = False ,
75
75
output_hidden_states = False ,
76
- text_config = None ,
77
- vision_config = None ,
78
- image_size = 288 ,
79
- contrastive_hidden_size = 512 ,
80
- logit_scale_init_value = 2.6592 ,
81
76
):
82
77
self .parent = parent
83
- self .share_cross_modal_transformer_layers = share_cross_modal_transformer_layers
84
- self .drop_rate = drop_rate
85
- self .head_hidden_scale = head_hidden_scale
86
78
self .hidden_act = hidden_act
87
79
self .hidden_size = hidden_size
88
80
self .initializer_factor = initializer_factor
89
- self .is_encoder_decoder = is_encoder_decoder
90
81
self .layer_norm_eps = layer_norm_eps
91
- self .share_link_tower_layers = share_link_tower_layers
92
- self .link_tower_type = link_tower_type
93
82
self .num_attention_heads = num_attention_heads
94
83
self .num_hidden_layers = num_hidden_layers
84
+ self .intermediate_size = intermediate_size
95
85
self .tie_word_embeddings = tie_word_embeddings
96
- self .init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
97
86
self .vocab_size = 99
98
- self .num_channels = 3
99
87
self .seq_length = 4
100
- self .num_image_features = 325
101
88
self .batch_size = 1
102
- self .image_size = image_size
103
89
self .is_training = False
104
- self .expected_num_hidden_layers = 32
105
90
self .output_hidden_states = output_hidden_states
106
- self .contrastive_hidden_size = contrastive_hidden_size
107
- self .logit_scale_init_value = logit_scale_init_value
108
91
109
92
def prepare_config_and_inputs (self ):
110
93
input_ids = ids_tensor ([self .batch_size , self .seq_length ], self .vocab_size )
111
94
attention_mask = random_attention_mask ([self .batch_size , self .seq_length ])
112
- pixel_values = floats_tensor ([self .batch_size , self .num_channels , self .image_size , self .image_size ])
113
- pixel_mask = random_attention_mask ([self .batch_size , self .image_size , self .image_size ])
95
+
114
96
config = self .get_config ()
115
- return (config , input_ids , attention_mask , pixel_values , pixel_mask )
97
+
98
+ return config , input_ids , attention_mask
116
99
117
100
def get_config (self ):
118
- return BridgeTowerConfig (
119
- share_cross_modal_transformer_layers = self .share_cross_modal_transformer_layers ,
120
- drop_rate = self .drop_rate ,
121
- head_hidden_scale = self .head_hidden_scale ,
101
+ return BridgeTowerTextConfig (
122
102
hidden_act = self .hidden_act ,
123
103
hidden_size = self .hidden_size ,
124
104
initializer_factor = self .initializer_factor ,
125
- image_size = self .image_size ,
126
- is_encoder_decoder = self .is_encoder_decoder ,
127
105
layer_norm_eps = self .layer_norm_eps ,
128
- share_link_tower_layers = self .share_link_tower_layers ,
129
- link_tower_type = self .link_tower_type ,
130
106
num_attention_heads = self .num_attention_heads ,
131
107
num_hidden_layers = self .num_hidden_layers ,
108
+ intermediate_size = self .intermediate_size ,
132
109
tie_word_embeddings = self .
F438
tie_word_embeddings ,
110
+ output_hidden_states = self .output_hidden_states ,
111
+ )
112
+
113
+
114
+ class BridgeTowerImageModelTester :
115
+ def __init__ (
116
+ self ,
117
+ parent ,
118
+ hidden_size = 128 ,
119
+ initializer_factor = 1 ,
120
+ layer_norm_eps = 1e-05 ,
121
+ num_hidden_layers = 2 ,
122
+ init_layernorm_from_vision_encoder = False ,
123
+ output_hidden_states = False ,
124
+ image_size = 64 ,
125
+ ):
126
+ self .parent = parent
127
+ self .hidden_size = hidden_size
128
+ self .initializer_factor = initializer_factor
129
+ self .layer_norm_eps = layer_norm_eps
130
+ self .num_hidden_layers = num_hidden_layers
131
+ self .init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
132
+ self .num_channels = 3
133
+ self .num_image_features = 17
134
+ self .batch_size = 1
135
+ self .image_size = image_size
136
+ self .is_training = False
137
+ self .output_hidden_states = output_hidden_states
138
+
139
+ def prepare_config_and_inputs (self ):
140
+ pixel_values = floats_tensor ([self .batch_size , self .num_channels , self .image_size , self .image_size ])
141
+ pixel_mask = random_attention_mask ([self .batch_size , self .image_size , self .image_size ])
142
+ config = self .get_config ()
143
+
144
+ return config , pixel_values , pixel_mask
145
+
146
+ def get_config (self ):
147
+ return BridgeTowerVisionConfig (
148
+ hidden_size = self .hidden_size ,
149
+ initializer_factor = self .initializer_factor ,
150
+ layer_norm_eps = self .layer_norm_eps ,
151
+ num_hidden_layers = self .num_hidden_layers ,
133
152
init_layernorm_from_vision_encoder = self .init_layernorm_from_vision_encoder ,
134
153
num_channels = self .num_channels ,
154
+ num_image_features = self .num_image_features ,
155
+ batch_size = self .batch_size ,
156
+ image_size = self .image_size ,
157
+ is_training = self .is_training ,
135
158
output_hidden_states = self .output_hidden_states ,
159
+ )
160
+
161
+
162
+ class BridgeTowerModelTester :
163
+ def __init__ (
164
+ self ,
165
+ parent ,
166
+ text_kwargs = None ,
167
+ vision_kwargs = None ,
168
+ share_cross_modal_transformer_layers = True ,
169
+ share_link_tower_layers = False ,
170
+ link_tower_type = "add" ,
171
+ init_layernorm_from_vision_encoder = False ,
172
+ contrastive_hidden_size = 512 ,
173
+ logit_scale_init_value = 2.6592 ,
174
+ hidden_size = 128 ,
175
+ num_hidden_layers = 2 ,
176
+ num_attention_heads = 4 ,
177
+ intermediate_size = 256 ,
178
+ ):
179
+ if text_kwargs is None :
180
+ text_kwargs = {}
181
+ if vision_kwargs is None :
182
+ vision_kwargs = {}
183
+
184
+ self .parent = parent
185
+ self .text_model_tester = BridgeTowerTextModelTester (parent , ** text_kwargs )
186
+ self .vision_model_tester = BridgeTowerImageModelTester (parent , ** vision_kwargs )
187
+
188
+ self .share_cross_modal_transformer_layers = share_cross_modal_transformer_layers
189
+ self .share_link_tower_layers = share_link_tower_layers
190
+ self .link_tower_type = link_tower_type
191
+ self .init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
192
+ self .contrastive_hidden_size = contrastive_hidden_size
193
+ self .logit_scale_init_value = logit_scale_init_value
194
+
195
+ self .batch_size = 1
196
+ self .expected_num_hidden_layers = 8
197
+ self .is_training = False
198
+
199
+ self .hidden_size = hidden_size
200
+ self .num_hidden_layers = num_hidden_layers
201
+ self .num_attention_heads = num_attention_heads
202
+ self .intermediate_size = intermediate_size
203
+
204
+ def prepare_config_and_inputs (self ):
205
+ text_config , input_ids , attention_mask = self .text_model_tester .prepare_config_and_inputs ()
206
+ vision_config , pixel_values , pixel_mask = self .vision_model_tester .prepare_config_and_inputs ()
207
+
208
+ config = self .get_config ()
209
+
210
+ return (config , input_ids , attention_mask , pixel_values , pixel_mask )
211
+
212
+ def get_config (self ):
213
+ return BridgeTowerConfig .from_text_vision_configs (
214
+ text_config = self .text_model_tester .get_config (),
215
+ vision_config = self .vision_model_tester .get_config (),
216
+ share_cross_modal_transformer_layers = self .share_cross_modal_transformer_layers ,
217
+ share_link_tower_layers = self .share_link_tower_layers ,
218
+ link_tower_type = self .link_tower_type ,
219
+ init_layernorm_from_vision_encoder = self .init_layernorm_from_vision_encoder ,
136
220
contrastive_hidden_size = self .contrastive_hidden_size ,
137
221
logit_scale_init_value = self .logit_scale_init_value ,
222
+ hidden_size = self .hidden_size ,
223
+ num_hidden_layers = self .num_hidden_layers ,
224
+ num_attention_heads = self .num_attention_heads ,
225
+ intermediate_size = self .intermediate_size ,
138
226
)
139
227
140
228
def create_and_check_model (
@@ -150,11 +238,18 @@ def create_and_check_model(
150
238
model .eval ()
151
239
result = model (input_ids , attention_mask = attention_mask , pixel_values = pixel_values , pixel_mask = pixel_mask )
152
240
result = model (input_ids , attention_mask = attention_mask , pixel_values = pixel_values )
153
- self .parent .assertEqual (result ["text_features" ].shape , (self .batch_size , self .seq_length , self .hidden_size ))
154
241
self .parent .assertEqual (
155
- result ["image_features" ].shape , (self .batch_size , self .num_image_features , self .hidden_size )
242
+ result ["text_features" ].shape ,
243
+ (self .batch_size , self .text_model_tester .seq_length , self .text_model_tester .hidden_size ),
244
+ )
245
+ self .parent .assertEqual (
246
+ result ["image_features" ].shape ,
247
+ (self .batch_size , self .vision_model_tester .num_image_features , self .vision_model_tester .hidden_size ),
248
+ )
249
+ self .parent .assertEqual (
250
+ result ["pooler_output" ].shape ,
251
+ (self .batch_size , self .text_model_tester .hidden_size + self .vision_model_tester .hidden_size ),
156
252
)
157
- self .parent .assertEqual (result ["pooler_output" ].shape , (self .batch_size , 2 * self .hidden_size ))
158
253
159
254
def create_and_check_for_image_and_text_retrieval (
160
255
self ,
@@ -188,7 +283,7 @@ def create_and_check_for_masked_language_modeling(
188
283
result = model (input_ids , attention_mask = attention_mask , pixel_values = pixel_values , pixel_mask = pixel_mask )
189
284
result = model (input_ids , attention_mask = attention_mask , pixel_values = pixel_values )
190
285
191
- self .parent .assertEqual (result .logits .shape , (self .batch_size , self .seq_length , 50265 ))
286
+ self .parent .assertEqual (result .logits .shape , (self .batch_size , self .text_model_tester . seq_length , 50265 ))
192
287
193
288
def prepare_config_and_inputs_for_common (self ):
194
289
config_and_inputs = self .prepare_config_and_inputs ()
@@ -202,7 +297,6 @@ def prepare_config_and_inputs_for_common(self):
202
297
return config , inputs_dict
203
298
204
299
205
- @slow
206
300
@require_torch
207
301
@unittest .skipIf (not is_torch_greater_or_equal_than_1_10 , "BridgeTower is only available in torch v1.10+" )
208
302
class BridgeTowerModelTest (ModelTesterMixin , PipelineTesterMixin , unittest .TestCase ):
@@ -225,6 +319,18 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
225
319
test_resize_embeddings = False
226
320
has_attentions = False
227
321
322
+ @unittest .skip (reason = "Does not work on the tiny model as we keep hitting edge cases." )
323
+ def test_cpu_offload (self ):
324
+ pass
325
+
326
+ @unittest .skip (reason = "Does not work on the tiny model as we keep hitting edge cases." )
327
+ def test_disk_offload (self ):
328
+ pass
329
+
330
+ @unittest .skip (reason = "Does not work on the tiny model as we keep hitting edge cases." )
331
+ def test_model_parallelism (self ):
332
+ pass
333
+
228
334
# function to extract meaningful tensor from output per different model_class
229
335
def extract_output (self , outputs , model_class ):
230
336
return outputs ["pooler_output" ] if model_class == "BridgeTowerModel" else outputs ["logits" ]
@@ -301,32 +407,30 @@ def check_hidden_states_output(inputs_dict, config, model_class):
301
407
outputs .encoder_hidden_states if config .is_encoder_decoder else outputs .hidden_states
302
408
)
303
409
304
- expected_num_layers = getattr (
305
- self .model_tester , "expected_num_hidden_layers" , self .model_tester .num_hidden_layers + 1
306
- )
410
+ expected_num_layers = self .model_tester .expected_num_hidden_layers
307
411
self .assertEqual (
308
412
sum ((len (hidden_states_text ), len (hidden_states_vision ), len (hidden_states_cross ))),
309
413
expected_num_layers ,
310
414
)
311
415
312
- seq_length = self .model_tester .seq_length
313
- num_image_features = self .model_tester .num_image_features
416
+ seq_length = self .model_tester .text_model_tester . seq_length
417
+ num_image_features = self .model_tester .vision_model_tester . num_image_features
314
418
315
419
self .assertListEqual (
316
420
list (hidden_states_text [0 ].shape [- 2 :]),
317
- [seq_length , self .model_tester .hidden_size ],
421
+ [seq_length , self .model_tester .text_model_tester . hidden_size ],
318
422
)
319
423
self .assertListEqual (
320
424
list (hidden_states_vision [0 ].shape ),
321
- [num_image_features , 1 , self .model_tester .hidden_size ],
425
+ [num_image_features , 1 , self .model_tester .vision_model_tester . hidden_size ],
322
426
)
323
427
self .assertListEqual (
324
428
list (hidden_states_cross [0 ][0 ].shape [- 2 :]),
325
- [seq_length , self .model_tester .hidden_size ],
429
+ [seq_length , self .model_tester .text_model_tester . hidden_size ],
326
430
)
327
431
self .assertListEqual (
328
432
list (hidden_states_cross [0 ][1 ].shape [- 2 :]),
329
- [num_image_features , self .model_tester .hidden_size ],
433
+ [num_image_features , self .model_tester .vision_model_tester . hidden_size ],
330
434
)
331
435
332
436
config , inputs_dict = self .model_tester .prepare_config_and_inputs_for_common ()
0 commit comments