@@ -981,7 +981,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
     config_class = BridgeTowerConfig
     base_model_prefix = "bridgetower"
     supports_gradient_checkpointing = False
-    _no_split_modules = ["BridgeTowerSelfAttention"]
+    _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
 
     def _init_weights(self, module):
         if isinstance(module, BridgeTowerVisionModel):
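
The added BridgeTowerResidualAttention entry in _no_split_modules presumably keeps Accelerate's device_map="auto" from sharding a vision-side residual attention block across devices. A minimal loading sketch under that assumption (the checkpoint id is illustrative, and accelerate must be installed):

from transformers import BridgeTowerForContrastiveLearning

# Sketch: with both attention classes in _no_split_modules, auto device
# mapping keeps each attention block intact on a single device.
model = BridgeTowerForContrastiveLearning.from_pretrained(
    "BridgeTower/bridgetower-large-itm-mlm-itc",  # assumed checkpoint id
    device_map="auto",
)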
@@ -1863,12 +1863,16 @@ def forward(
 
         # normalized features
         text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
-        image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2)
-        cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)
+        image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2).to(
+            device=text_embeds.device
+        )
+        cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2).to(
+            device=text_embeds.device
+        )
 
         logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)
 
-        logit_scale = self.logit_scale.exp()
+        logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
         logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
         logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
         logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale
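
The new .to(device=text_embeds.device) calls appear to guard multi-device inference: under a device map the text, image, and cross-modal heads can end up on different GPUs, and torch.stack / torch.matmul require all operands on the same device. A self-contained sketch of the failure mode and the fix, assuming at least two CUDA devices are available:

import torch

if torch.cuda.device_count() >= 2:
    text_embeds = torch.randn(4, 512, device="cuda:0")
    image_embeds = torch.randn(4, 512, device="cuda:1")
    # torch.matmul(text_embeds, image_embeds.t())  # would raise: tensors on different devices
    image_embeds = image_embeds.to(device=text_embeds.device)  # mirrors the patch above
    logits = torch.matmul(text_embeds, image_embeds.t())  # OK: both on cuda:0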