Fix model parallelism for `BridgeTower` (#23039) · githubhjs/transformers@b6865b9

Commit b6865b9

Fix model parallelism for `BridgeTower` (huggingface#23039)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

1 parent d337631 · commit b6865b9

1 file changed: +8 −4 lines changed

src/transformers/models/bridgetower/modeling_bridgetower.py

Lines changed: 8 additions & 4 deletions
@@ -981,7 +981,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
     config_class = BridgeTowerConfig
     base_model_prefix = "bridgetower"
     supports_gradient_checkpointing = False
-    _no_split_modules = ["BridgeTowerSelfAttention"]
+    _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
 
     def _init_weights(self, module):
         if isinstance(module, BridgeTowerVisionModel):
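The first change extends `_no_split_modules`. When a model is loaded with a device map, accelerate uses this list to decide which modules must be placed whole on a single device; adding `BridgeTowerResidualAttention` keeps each residual attention block of the vision tower from being sharded across GPUs. A minimal usage sketch, assuming a multi-GPU machine and the public `BridgeTower/bridgetower-base` checkpoint (not part of this commit):

```python
from transformers import BridgeTowerModel

# device_map="auto" lets accelerate shard the model across available GPUs.
# Modules named in `_no_split_modules` are kept on one device, so the
# residual additions inside a BridgeTowerResidualAttention block never mix
# tensors from different devices.
model = BridgeTowerModel.from_pretrained(
    "BridgeTower/bridgetower-base",
    device_map="auto",
)
print(model.hf_device_map)  # inspect the resulting module-to-device layout
```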
@@ -1863,12 +1863,16 @@ def forward(
 
         # normalized features
         text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
-        image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2)
-        cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)
+        image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2).to(
+            device=text_embeds.device
+        )
+        cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2).to(
+            device=text_embeds.device
+        )
 
         logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)
 
-        logit_scale = self.logit_scale.exp()
+        logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
         logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
         logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
         logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale
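The remaining changes address the failure the commit title refers to: under model parallelism, `text_embeds`, `image_embeds`, `cross_embeds`, and `logit_scale` can live on different devices, and `torch.stack` / `torch.matmul` require all operands on the same device. Moving everything to `text_embeds.device` before combining resolves this. A standalone illustration of the failure mode (assumes two CUDA devices; not code from the repository):

```python
import torch

# Simulate what device_map="auto" can produce: the text head output on one
# GPU, the image head output on another.
text_embeds = torch.randn(4, 512, device="cuda:0")
image_embeds = torch.randn(4, 512, device="cuda:1")

# Mixing devices raises:
#   RuntimeError: Expected all tensors to be on the same device, ...
# torch.matmul(text_embeds, image_embeds.t())

# The commit's fix: move everything to a single reference device first.
image_embeds = image_embeds.to(device=text_embeds.device)
logits_text_to_image = torch.matmul(text_embeds, image_embeds.t())  # OK, on cuda:0
```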
