4
4
5
5
import torch
6
6
from torch import Size , Tensor
7
+ from torch ._C import _add_docstr as add_docstr
7
8
from torch .nn import functional as F , init
9
+ from torch .nn .modules ._nn_docs import common_args
8
10
from torch .nn .parameter import Parameter
9
11
10
12
from ._functions import CrossMapLRN2d as _cross_map_lrn2d
11
- from ._nn_docs import COMMON_ARGS
12
13
from .module import Module
13
14
14
-
15
15
__all__ = ["LocalResponseNorm" , "CrossMapLRN2d" , "LayerNorm" , "GroupNorm" , "RMSNorm" ]
16
16
17
17
@@ -135,6 +135,8 @@ class LayerNorm(Module):
135
135
and zeros (for biases). Default: ``True``.
136
136
bias: If set to ``False``, the layer will not learn an additive bias (only relevant if
137
137
:attr:`elementwise_affine` is ``True``). Default: ``True``.
138
+ {device}
139
+ {dtype}
138
140
139
141
Attributes:
140
142
weight: the learnable weights of the module of shape
@@ -168,9 +170,7 @@ class LayerNorm(Module):
168
170
.. image:: ../_static/img/nn/layer_norm.jpg
169
171
:scale: 50 %
170
172
171
- """ .format (
172
- COMMON_ARGS
173
- )
173
+ """
174
174
175
175
__constants__ = ["normalized_shape" , "eps" , "elementwise_affine" ]
176
176
normalized_shape : tuple [int , ...]
@@ -228,6 +228,9 @@ def extra_repr(self) -> str:
228
228
)
229
229
230
230
231
+ add_docstr (LayerNorm , LayerNorm .__doc__ .format (** common_args ))
232
+
233
+
231
234
class GroupNorm (Module ):
232
235
r"""Applies Group Normalization over a mini-batch of inputs.
233
236
@@ -256,6 +259,8 @@ class GroupNorm(Module):
256
259
affine: a boolean value that when set to ``True``, this module
257
260
has learnable per-channel affine parameters initialized to ones (for weights)
258
261
and zeros (for biases). Default: ``True``.
262
+ {device}
263
+ {dtype}
259
264
260
265
Shape:
261
266
- Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
@@ -273,7 +278,7 @@ class GroupNorm(Module):
273
278
>>> # Activating the module
274
279
>>> output = m(input)
275
280
""" .format (
276
- COMMON_ARGS
281
+ ** common_args
277
282
)
278
283
279
284
__constants__ = ["num_groups" , "num_channels" , "eps" , "affine" ]
@@ -351,6 +356,8 @@ class RMSNorm(Module):
351
356
eps: a value added to the denominator for numerical stability. Default: :func:`torch.finfo(x.dtype).eps`
352
357
elementwise_affine: a boolean value that when set to ``True``, this module
353
358
has learnable per-element affine parameters initialized to ones (for weights). Default: ``True``.
359
+ {device}
360
+ {dtype}
354
361
355
362
Shape:
356
363
- Input: :math:`(N, *)`
@@ -363,7 +370,7 @@ class RMSNorm(Module):
363
370
>>> rms_norm(input)
364
371
365
372
""" .format (
366
- COMMON_ARGS
373
+ ** common_args
367
374
)
368
375
369
376
__constants__ = ["normalized_shape" , "eps" , "elementwise_affine" ]
0 commit comments