Fix LayerNorm(bias=False) error (#108060) · pytorch/pytorch@584a01b · GitHub

Commit 584a01b

mikaylagawarecki authored and pytorchmergebot committed
Fix LayerNorm(bias=False) error (#108060)
Fixes #108048

- [ ] Cherry pick this [here](#108055)

Pull Request resolved: #108060
Approved by: https://github.com/jbschlosser, https://github.com/albanD, https://github.com/malfet
1 parent 054f3f1 commit 584a01b
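
For context, a minimal repro sketch of the failure this commit addresses (assuming a PyTorch build where nn.LayerNorm accepts the bias keyword): with elementwise_affine=True and bias=False, self.bias is None, and reset_parameters() previously called init.zeros_(self.bias) unconditionally, so construction raised an error.

```python
import torch
import torch.nn as nn

# Before this fix, this constructor call failed inside reset_parameters(),
# because init.zeros_(self.bias) ran even though bias=False leaves
# self.bias set to None.
ln = nn.LayerNorm(5, elementwise_affine=True, bias=False)

x = torch.randn(4, 5)
out = ln(x)

print(ln.weight.shape)  # torch.Size([5]) -- affine weight is still created
print(ln.bias)          # None            -- no bias parameter is registered
print(out.shape)        # torch.Size([4, 5])
```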

2 files changed (+9, -3 lines)

torch/nn/modules/normalization.py

Lines changed: 2 additions & 1 deletion
@@ -189,7 +189,8 @@ def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_af
     def reset_parameters(self) -> None:
         if self.elementwise_affine:
             init.ones_(self.weight)
-            init.zeros_(self.bias)
+            if self.bias is not None:
+                init.zeros_(self.bias)

     def forward(self, input: Tensor) -> Tensor:
         return F.layer_norm(
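
A small usage sketch of the patched reset_parameters (illustrative only; again assumes the bias keyword on nn.LayerNorm): with the guard above, re-initialization touches only the weight when no bias is registered.

```python
import torch
import torch.nn as nn

# With the bias guard, reset_parameters() re-initializes the affine weight
# to ones and simply skips the bias when none is registered.
ln = nn.LayerNorm([2, 3], elementwise_affine=True, bias=False)
with torch.no_grad():
    ln.weight.fill_(0.5)   # perturb the weight

ln.reset_parameters()      # previously failed when self.bias was None

assert torch.all(ln.weight == 1.0)
assert ln.bias is None
```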

torch/testing/_internal/common_modules.py

Lines changed: 7 additions & 2 deletions
@@ -1414,6 +1414,10 @@ def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad,
             constructor_input=FunctionInput([5], 1e-3),
             forward_input=FunctionInput(make_input((0, 5))),
             desc='1d_empty_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3, elementwise_affine=True, bias=False),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_elementwise_affine_no_bias'),
     ]

@@ -1809,15 +1813,16 @@ def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad
     # Samples below are for validating the no-batch-dim support.
     key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
     attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
-    for mask, key_padding_mask, norm_first in itertools.product(attn_masks, key_padding_masks, (True, False)):
+    for mask, key_padding_mask, norm_first, bias in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False)):
         # Using same mask for tgt and memory
         src_mask , tgt_mask = (mask,) * 2
         src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask,) * 2
         samples.append(
             ModuleInput(
                 constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
                                                 num_encoder_layers=1, num_decoder_layers=1,
-                                                dropout=0.0, batch_first=True, norm_first=norm_first),
+                                                dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias),
                 forward_input=FunctionInput(
                     make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask,
                     tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
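
To show roughly what the new test samples cover, here is a standalone sketch (shapes and constructor arguments taken from the samples above; assumes nn.LayerNorm and nn.Transformer accept the bias keyword, as the new inputs rely on):

```python
import itertools
import torch
import torch.nn as nn

# Roughly what the new LayerNorm sample exercises: normalization over the
# last three dims with an affine weight but no bias parameter.
ln = nn.LayerNorm([2, 2, 5], eps=1e-3, elementwise_affine=True, bias=False)
out = ln(torch.randn(4, 2, 2, 5))
assert out.shape == (4, 2, 2, 5) and ln.bias is None

# The Transformer samples now sweep bias alongside norm_first, mirroring the
# extended itertools.product in the diff above (masks omitted here for brevity).
for norm_first, bias in itertools.product((True, False), (True, False)):
    model = nn.Transformer(d_model=4, nhead=2, dim_feedforward=8,
                           num_encoder_layers=1, num_decoder_layers=1,
                           dropout=0.0, batch_first=True,
                           norm_first=norm_first, bias=bias)
    out = model(torch.randn(2, 3, 4), torch.randn(2, 3, 4))
    assert out.shape == (2, 3, 4)
```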
