@@ -1414,6 +1414,10 @@ def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad,
1414
1414
constructor_input = FunctionInput ([5 ], 1e-3 ),
1415
1415
forward_input = FunctionInput (make_input ((0 , 5 ))),
1416
1416
desc = '1d_empty_elementwise_affine' ),
1417
+ ModuleInput (
1418
+ constructor_input = FunctionInput ([2 , 2 , 5 ], 1e-3 , elementwise_affine = True , bias = False ),
1419
+ forward_input = FunctionInput (make_input ((4 , 2 , 2 , 5 ))),
1420
+ desc = '3d_elementwise_affine_no_bias' ),
1417
1421
]
1418
1422
1419
1423
@@ -1809,15 +1813,16 @@ def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad
1809
1813
# Samples below are for validating the no-batch-dim support.
1810
1814
key_padding_masks = (None , torch .tensor ([False , False , True ], device = device , dtype = torch .bool ))
1811
1815
attn_masks = (None , torch .tensor ([False , False , True ], device = device , dtype = torch .bool ).expand ((3 , 3 )))
1812
- for mask , key_padding_mask , norm_first in itertools .product (attn_masks , key_padding_masks , (True , False )):
1816
+ for mask , key_padding_mask , norm_first , bias in \
1817
+ itertools .product (attn_masks , key_padding_masks , (True , False ), (True , False )):
1813
1818
# Using same mask for src and tgt
1814
1819
src_mask , tgt_mask = (mask ,) * 2
1815
1820
src_key_padding_mask , tgt_key_padding_mask = (key_padding_mask ,) * 2
1816
1821
samples .append (
1817
1822
ModuleInput (
1818
1823
constructor_input = FunctionInput (d_model = 4 , nhead = 2 , dim_feedforward = 8 ,
1819
1824
num_encoder_layers = 1 , num_decoder_layers = 1 ,
1820
- dropout = 0.0 , batch_first = True , norm_first = norm_first ),
1825
+ dropout = 0.0 , batch_first = True , norm_first = norm_first , bias = bias ),
1821
1826
forward_input = FunctionInput (
1822
1827
make_input ((3 , 4 )), make_input ((3 , 4 )), tgt_mask = tgt_mask , src_mask = src_mask ,
1823
1828
tgt_key_padding_mask = tgt_key_padding_mask , src_key_padding_mask = src_key_padding_mask
0 commit comments