8
8
from torch .nn .parameter import Parameter
9
9
10
10
from ._functions import CrossMapLRN2d as _cross_map_lrn2d
11
+ from ._nn_docs import COMMON_ARGS
11
12
from .module import Module
12
13
13
14
@@ -134,9 +135,6 @@ class LayerNorm(Module):
134
135
and zeros (for biases). Default: ``True``.
135
136
bias: If set to ``False``, the layer will not learn an additive bias (only relevant if
136
137
:attr:`elementwise_affine` is ``True``). Default: ``True``.
137
- device: the device on which the parameters will be allocated. Default: None
138
- dtype: the data type of the parameters. Default: None
139
-
140
138
141
139
Attributes:
142
140
weight: the learnable weights of the module of shape
@@ -170,7 +168,9 @@ class LayerNorm(Module):
170
168
.. image:: ../_static/img/nn/layer_norm.jpg
171
169
:scale: 50 %
172
170
173
- """
171
+ """ .format (
172
+ COMMON_ARGS
173
+ )
174
174
175
175
__constants__ = ["normalized_shape" , "eps" , "elementwise_affine" ]
176
176
normalized_shape : tuple [int , ...]
@@ -272,7 +272,9 @@ class GroupNorm(Module):
272
272
>>> m = nn.GroupNorm(1, 6)
273
273
>>> # Activating the module
274
274
>>> output = m(input)
275
- """
275
+ """ .format (
276
+ COMMON_ARGS
277
+ )
276
278
277
279
__constants__ = ["num_groups" , "num_channels" , "eps" , "affine" ]
278
280
num_groups : int
@@ -349,8 +351,6 @@ class RMSNorm(Module):
349
351
eps: a value added to the denominator for numerical stability. Default: :func:`torch.finfo(x.dtype).eps`
350
352
elementwise_affine: a boolean value that when set to ``True``, this module
351
353
has learnable per-element affine parameters initialized to ones (for weights). Default: ``True``.
352
- device: the device on which the parameters will be allocated. Default: None.
353
- dtype: the data type of the parameters. Default: None.
354
354
355
355
Shape:
356
356
- Input: :math:`(N, *)`
@@ -362,7 +362,9 @@ class RMSNorm(Module):
362
362
>>> input = torch.randn(2, 2, 3)
363
363
>>> rms_norm(input)
364
364
365
- """
365
+ """ .format (
366
+ COMMON_ARGS
367
+ )
366
368
367
369
__constants__ = ["normalized_shape" , "eps" , "elementwise_affine" ]
368
370
normalized_shape : tuple [int , ...]
0 commit comments