Update · pytorch/pytorch@595a19d · GitHub

Commit 595a19d

Update
1 parent 477d0ff commit 595a19d

File tree

2 files changed: +24 -8 lines

torch/nn/modules/_nn_docs.py

Lines changed: 10 additions & 1 deletion
@@ -1,5 +1,14 @@
+# mypy: allow-untyped-defs
+"""Adds docstrings to functions defined in the torch.nn module."""
+import re
+
+from torch._torch_docs import parse_kwargs
+
+
 # Common parameter documentation for nn modules
-COMMON_ARGS = """
+common_args = parse_kwargs(
+    """
     device: the device on which the parameters will be allocated. Default: None
     dtype: the data type of the parameters. Default: None
 """
+)
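In the new module, `parse_kwargs` (from `torch/_torch_docs.py`) turns the block of shared argument descriptions into a dict keyed by argument name, so each docstring can pull entries in by name via `str.format`. A simplified stand-in that captures the idea (the real helper also handles multi-line descriptions; the parsing details below are assumptions, not the actual implementation):

    def parse_kwargs(desc: str) -> dict[str, str]:
        # One "name: description" entry per line; key each entry by its name.
        entries = [line.strip() for line in desc.splitlines() if line.strip()]
        return {e.split(":")[0]: e for e in entries}

    common_args = parse_kwargs("""
        device: the device on which the parameters will be allocated. Default: None
        dtype: the data type of the parameters. Default: None
    """)
    assert common_args["device"].startswith("device: the device on which")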

torch/nn/modules/normalization.py

Lines changed: 14 additions & 7 deletions
@@ -4,14 +4,14 @@
 
 import torch
 from torch import Size, Tensor
+from torch._C import _add_docstr as add_docstr
 from torch.nn import functional as F, init
+from torch.nn.modules._nn_docs import common_args
 from torch.nn.parameter import Parameter
 
 from ._functions import CrossMapLRN2d as _cross_map_lrn2d
-from ._nn_docs import COMMON_ARGS
 from .module import Module
 
-
 __all__ = ["LocalResponseNorm", "CrossMapLRN2d", "LayerNorm", "GroupNorm", "RMSNorm"]
@@ -135,6 +135,8 @@ class LayerNorm(Module):
             and zeros (for biases). Default: ``True``.
         bias: If set to ``False``, the layer will not learn an additive bias (only relevant if
             :attr:`elementwise_affine` is ``True``). Default: ``True``.
+        {device}
+        {dtype}
 
     Attributes:
         weight: the learnable weights of the module of shape
@@ -168,9 +170,7 @@ class LayerNorm(Module):
     .. image:: ../_static/img/nn/layer_norm.jpg
         :scale: 50 %
 
-    """.format(
-        COMMON_ARGS
-    )
+    """
 
     __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
     normalized_shape: tuple[int, ...]
@@ -228,6 +228,9 @@ def extra_repr(self) -> str:
         )
 
 
+add_docstr(LayerNorm, LayerNorm.__doc__.format(**common_args))
+
+
 class GroupNorm(Module):
     r"""Applies Group Normalization over a mini-batch of inputs.
 
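LayerNorm is handled differently from GroupNorm and RMSNorm below: its docstring keeps the `{device}`/`{dtype}` placeholders in the class body, and the substitution is applied once after the class is defined, via `torch._C._add_docstr`. For a plain Python class this appears equivalent to reassigning `__doc__`; a minimal sketch of the effect (an assumption about what the C-level helper achieves here, not its implementation):

    # Roughly what the add_docstr call amounts to for a pure-Python class;
    # the C-level helper also covers objects whose __doc__ cannot simply
    # be reassigned from Python.
    LayerNorm.__doc__ = LayerNorm.__doc__.format(**common_args)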
@@ -256,6 +259,8 @@ class GroupNorm(Module):
         affine: a boolean value that when set to ``True``, this module
             has learnable per-channel affine parameters initialized to ones (for weights)
             and zeros (for biases). Default: ``True``.
+        {device}
+        {dtype}
 
     Shape:
         - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
@@ -273,7 +278,7 @@ class GroupNorm(Module):
     >>> # Activating the module
     >>> output = m(input)
     """.format(
-        COMMON_ARGS
+        **common_args
     )
 
     __constants__ = ["num_groups", "num_channels", "eps", "affine"]
@@ -351,6 +356,8 @@ class RMSNorm(Module):
         eps: a value added to the denominator for numerical stability. Default: :func:`torch.finfo(x.dtype).eps`
         elementwise_affine: a boolean value that when set to ``True``, this module
             has learnable per-element affine parameters initialized to ones (for weights). Default: ``True``.
+        {device}
+        {dtype}
 
     Shape:
         - Input: :math:`(N, *)`
@@ -363,7 +370,7 @@ class RMSNorm(Module):
     >>> rms_norm(input)
 
     """.format(
-        COMMON_ARGS
+        **common_args
     )
 
     __constants__ = ["normalized_shape", "eps", "elementwise_affine"]

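After substitution, the shared device/dtype descriptions should appear verbatim in each rendered docstring. A quick sanity check, assuming a torch build that includes this commit:

    import torch.nn as nn

    # All three classes touched here should expose the shared entries.
    for cls in (nn.LayerNorm, nn.GroupNorm, nn.RMSNorm):
        assert "device: the device on which the parameters" in cls.__doc__
        assert "dtype: the data type of the parameters" in cls.__doc__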