FIX Use correct argument name in MHA forward (#2510) · huggingface/peft@eb5e9bc · GitHub

Commit eb5e9bc

FIX Use correct argument name in MHA forward (#2510)
The arguments of the forward method of MultiheadAttention are called query etc., but PEFT used x. Therefore, if a caller passes keyword arguments only, the argument is not assigned, resulting in an error. This was initially reported in #761 (comment). Note: Other layers' forward methods (like Linear) also use incorrect names, like x instead of input, but so far no issues have been reported, so I'll leave it as is for now.
1 parent 1fb98f1 commit eb5e9bc
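
To illustrate the failure described in the commit message, here is a minimal, hypothetical sketch (not PEFT's actual classes): a wrapper whose forward names its first parameter x cannot bind the query= keyword that nn.MultiheadAttention documents, so Python reports the positional argument as missing.

```python
# Hypothetical wrapper mimicking lora.MultiheadAttention.forward *before* the
# fix: the first parameter is named `x`, while callers of nn.MultiheadAttention
# are documented to pass `query`, `key`, `value`.
import torch
import torch.nn as nn


class WrapperBeforeFix(nn.Module):
    def __init__(self, base: nn.MultiheadAttention):
        super().__init__()
        self.base_layer = base

    def forward(self, x, *args, **kwargs):
        # A `query=` keyword lands in **kwargs; `x` stays unbound.
        return self.base_layer(x, *args, **kwargs)


mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
wrapped = WrapperBeforeFix(mha)
q = torch.randn(2, 5, 32)

out, _ = wrapped(q, q, q)  # positional call: works

try:
    wrapped(query=q, key=q, value=q)  # keyword-only call
except TypeError as exc:
    print(exc)  # forward() missing 1 required positional argument: 'x'
```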

File tree

1 file changed: +6 -6 lines changed


src/peft/tuners/lora/layer.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -1648,16 +1648,16 @@ def _check_forward_args(self, x, *args, **kwargs):
             raise TypeError(f"lora.{self.__class__.__name__} does not support mixed adapter batches.")
         super()._check_forward_args(x, *args, **kwargs)
 
-    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
-        previous_dtype = x.dtype
-        self._check_forward_args(x, *args, **kwargs)
+    def forward(self, query: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+        previous_dtype = query.dtype
+        self._check_forward_args(query, *args, **kwargs)
 
         if self.disable_adapters:
             if self.merged:
                 self.unmerge()
-            result = self.base_layer(x, *args, **kwargs)
+            result = self.base_layer(query, *args, **kwargs)
         elif self.merged:
-            result = self.base_layer(x, *args, **kwargs)
+            result = self.base_layer(query, *args, **kwargs)
         else:
             out_proj = self.get_base_layer().out_proj
             if out_proj.active_adapters != self.active_adapters:
@@ -1680,7 +1680,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             active_adapters = [a for a in self.active_adapters if a in self.lora_A]
             try:
                 self.merge(adapter_names=active_adapters)
-                result = self.base_layer(x, *args, **kwargs)
+                result = self.base_layer(query, *args, **kwargs)
             finally:
                 # it's safe to call unmerge(), which unmerges all adapters, because we checked that not self.merged,
                 # i.e. there is was no merged layer before
```
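
With the rename, the wrapper's first parameter matches the name used by nn.MultiheadAttention.forward, so the same keyword-only call binds correctly. A matching hypothetical sketch of the fixed signature:

```python
# Hypothetical wrapper mimicking lora.MultiheadAttention.forward *after* the
# fix: the first parameter is now `query`, matching nn.MultiheadAttention.
import torch
import torch.nn as nn


class WrapperAfterFix(nn.Module):
    def __init__(self, base: nn.MultiheadAttention):
        super().__init__()
        self.base_layer = base

    def forward(self, query, *args, **kwargs):
        return self.base_layer(query, *args, **kwargs)


mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
fixed = WrapperAfterFix(mha)
q = torch.randn(2, 5, 32)
out, attn_weights = fixed(query=q, key=q, value=q)  # keyword-only call now binds
print(out.shape)  # torch.Size([2, 5, 32])
```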

0 commit comments
