8000 Update · pytorch/pytorch@477d0ff · GitHub
[go: up one dir, main page]

Skip to content

Commit 477d0ff

Browse files
committed
Update
1 parent f2e981b commit 477d0ff

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

torch/nn/modules/_nn_docs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Common parameter documentation for nn modules
2+
COMMON_ARGS = """
3+
device: the device on which the parameters will be allocated. Default: None
4+
dtype: the data type of the parameters. Default: None
5+
"""

torch/nn/modules/normalization.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torch.nn.parameter import Parameter
99

1010
from ._functions import CrossMapLRN2d as _cross_map_lrn2d
11+
from ._nn_docs import COMMON_ARGS
1112
from .module import Module
1213

1314

@@ -134,9 +135,6 @@ class LayerNorm(Module):
134135
and zeros (for biases). Default: ``True``.
135136
bias: If set to ``False``, the layer will not learn an additive bias (only relevant if
136137
:attr:`elementwise_affine` is ``True``). Default: ``True``.
137-
device: the device on which the parameters will be allocated. Default: None
138-
dtype: the data type of the parameters. Default: None
139-
140138
141139
Attributes:
142140
weight: the learnable weights of the module of shape
@@ -170,7 +168,9 @@ class LayerNorm(Module):
170168
.. image:: ../_static/img/nn/layer_norm.jpg
171169
:scale: 50 %
172170
173-
"""
171+
""".format(
172+
COMMON_ARGS
173+
)
174174

175175
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
176176
normalized_shape: tuple[int, ...]
@@ -272,7 +272,9 @@ class GroupNorm(Module):
272272
>>> m = nn.GroupNorm(1, 6)
273273
>>> # Activating the module
274274
>>> output = m(input)
275-
"""
275+
""".format(
276+
COMMON_ARGS
277+
)
276278

277279
__constants__ = ["num_groups", "num_channels", "eps", "affine"]
278280
num_groups: int
@@ -349,8 +351,6 @@ class RMSNorm(Module):
349351
eps: a value added to the denominator for numerical stability. Default: :func:`torch.finfo(x.dtype).eps`
350352
elementwise_affine: a boolean value that when set to ``True``, this module
351353
has learnable per-element affine parameters initialized to ones (for weights). Default: ``True``.
352-
device: the device on which the parameters will be allocated. Default: None.
353-
dtype: the data type of the parameters. Default: None.
354354
355355
Shape:
356356
- Input: :math:`(N, *)`
@@ -362,7 +362,9 @@ class RMSNorm(Module):
362362
>>> input = torch.randn(2, 2, 3)
363363
>>> rms_norm(input)
364364
365-
"""
365+
""".format(
366+
COMMON_ARGS
367+
)
366368

367369
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
368370
normalized_shape: tuple[int, ...]

0 commit comments

Comments (0)