8000 Add missing arg descriptions for class RMSNorm(Module) by svekars · Pull Request #153738 · pytorch/pytorch · GitHub

Add missing arg descriptions for class RMSNorm(Module) #153738


Open · wants to merge 20 commits into base: main
25 changes: 25 additions & 0 deletions torch/nn/modules/_nn_docs.py
@@ -0,0 +1,25 @@
# mypy: allow-untyped-defs
"""Adds docstrings to functions defined in the torch.nn module."""
from torch._torch_docs import parse_kwargs


# Common parameter documentation for nn modules
common_args = parse_kwargs(
    """
    device: the device on which the parameters will be allocated. Default: None
    dtype: the data type of the parameters. Default: None
"""
)

layernorm_args = parse_kwargs(
    """
    normalized_shape: input shape from an expected input of size
        [* x normalized_shape[0] x normalized_shape[1] x ... x normalized_shape[-1]]
        If a single integer is used, it is treated as a singleton list, and this module will
        normalize over the last dimension which is expected to be of that specific size.
    eps: a value added to the denominator for numerical stability. Default: 1e-5
    elementwise_affine: a boolean value that when set to ``True``, this module
        has learnable per-element affine parameters initialized to ones (for weights)
        and zeros (for biases). Default: ``True``.
"""
)
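For context on the pattern this new file enables: `parse_kwargs` (from `torch._torch_docs`) turns a block of `name: description` entries into a dict keyed by argument name, so multiple docstrings can interpolate the same descriptions with `str.format`. The sketch below is a simplified re-implementation for illustration only, not the PR's code; the split-on-indentation rule and the key normalization are assumptions about the helper's behavior.

import re


def parse_kwargs_sketch(desc: str) -> dict[str, str]:
    """Illustrative stand-in for torch._torch_docs.parse_kwargs (assumed behavior)."""
    # A new entry begins after a newline plus exactly four spaces;
    # deeper indentation continues the previous entry.
    sections = [s.strip() for s in re.split(r"\n\s{4}(?!\s)", desc)]
    # Key each entry by its first token; the trailing colon is stripped
    # here so the keys match {device}/{dtype} format fields.
    return {s.split(" ")[0].rstrip(":"): s for s in sections if s}


common = parse_kwargs_sketch(
    """
    device: the device on which the parameters will be allocated. Default: None
    dtype: the data type of the parameters. Default: None
"""
)

template = "Args:\n    {device}\n    {dtype}"
print(template.format(**common))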
43 changes: 19 additions & 24 deletions torch/nn/modules/normalization.py
@@ -4,7 +4,9 @@

import torch
from torch import Size, Tensor
from torch._C import _add_docstr as add_docstr
from torch.nn import functional as F, init
from torch.nn.modules._nn_docs import common_args, layernorm_args
from torch.nn.parameter import Parameter

from ._functions import CrossMapLRN2d as _cross_map_lrn2d
@@ -92,13 +94,14 @@


class LayerNorm(Module):
    r"""Applies Layer Normalization over a mini-batch of inputs.
    __doc__ = r"""Applies Layer Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
        y = \frac{{x - \mathrm{{E}}[x]}}{{ \sqrt{{\mathrm{{Var}}[x] + \epsilon}} }} * \gamma + \beta


    The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
@@ -119,28 +122,15 @@
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.
        bias: If set to ``False``, the layer will not learn an additive bias (only relevant if
            :attr:`elementwise_affine` is ``True``). Default: ``True``.
        {layernorm_args}
        {common_args}

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
            :math:`\\text{{normalized\\_shape}}` when :attr:`elementwise_affine` is set to ``True``.
            The values are initialized to 1.
        bias: the learnable bias of the module of shape
            :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
            :math:`\\text{{normalized\\_shape}}` when :attr:`elementwise_affine` is set to ``True``.
            The values are initialized to 0.

    Shape:
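Not part of the diff: the singleton behavior described under `normalized_shape` above is easy to verify interactively. A brief sketch (shapes chosen arbitrarily):

import torch
from torch import nn

x = torch.randn(4, 10, 5)

# A single int is treated as the singleton list [5]: statistics are
# computed over the last dimension only.
ln_last = nn.LayerNorm(5)

# A list normalizes over the trailing len(list) dimensions, here (10, 5).
ln_two = nn.LayerNorm([10, 5])

print(ln_last(x).shape, ln_two(x).shape)  # both torch.Size([4, 10, 5])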
@@ -167,7 +157,9 @@
    .. image:: ../_static/img/nn/layer_norm.jpg
        :scale: 50 %

    """
    """.format(
        layernorm_args=layernorm_args.__doc__, common_args=common_args.__doc__

Check failure on line 161 in torch/nn/modules/normalization.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [has-type]: Cannot determine type of "common_args"
Check failure on line 161 in torch/nn/modules/normalization.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [has-type]: Cannot determine type of "layernorm_args"
    )

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: tuple[int, ...]
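A subtlety the mypy annotations sit next to: once `__doc__` is assembled with `str.format`, every literal brace in the reST math must be doubled (`{{`/`}}`) so that only intended fields such as `{layernorm_args}` get substituted, which is why the added math lines above double their braces. A minimal sketch of the pattern with hypothetical names (`demo_args`, `DemoNorm`), using keyword unpacking rather than the `.__doc__` lookup the diff uses:

demo_args = {
    "eps": "eps: a value added to the denominator for numerical stability. Default: 1e-5",
}


class DemoNorm:
    __doc__ = r"""Applies a toy normalization.

    .. math::
        y = \frac{{x}}{{\sqrt{{\mathrm{{Var}}[x] + \epsilon}}}}

    Args:
        {eps}
    """.format(**demo_args)


# The doubled braces collapse to single ones after formatting:
assert r"\frac{x}" in DemoNorm.__doc__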
@@ -325,8 +317,8 @@
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/pdf/1910.07467.pdf>`__

    .. math::
        y_i = \frac{x_i}{\mathrm{RMS}(x)} * \gamma_i, \quad
        \text{where} \quad \text{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{i=1}^{n} x_i^2}
        y_i = \frac{x_i}{{\mathrm{{RMS}}(x)}} * \gamma_i, \quad
        \text{{where}} \quad \text{{RMS}}(x) = \sqrt{{\epsilon + \frac{{1}}{{n}} \sum_{{i=1}}^{{n}} x_i^2}}

    The RMS is taken over the last ``D`` dimensions, where ``D``
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
@@ -338,14 +330,16 @@
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]
                [* \times \text{{normalized\_shape}}[0] \times \text{{normalized\_shape}}[1]
                    \times \ldots \times \text{{normalized\_shape}}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: :func:`torch.finfo(x.dtype).eps`
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights). Default: ``True``.
        {device}
        {dtype}

    Shape:
        - Input: :math:`(N, *)`
@@ -358,6 +352,7 @@
        >>> rms_norm(input)

    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: tuple[int, ...]
    eps: Optional[float]
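Not part of the PR, but useful for checking the formula in the docstring above: the snippet below compares `nn.RMSNorm` (available since PyTorch 2.4) against a hand-rolled computation of RMS over the last `D` dimensions. Shapes and the use of `torch.testing.assert_close` are arbitrary choices for the demo.

import torch
from torch import nn

x = torch.randn(2, 2, 3)

# normalized_shape has length 2, so RMS is taken over the last two dims.
rms_norm = nn.RMSNorm([2, 3])

# Hand-rolled y_i = x_i / RMS(x) * gamma_i with
# RMS(x) = sqrt(eps + mean(x_i^2)); eps defaults to torch.finfo(x.dtype).eps.
eps = torch.finfo(x.dtype).eps
rms = torch.sqrt(x.pow(2).mean(dim=(-2, -1), keepdim=True) + eps)
manual = x / rms * rms_norm.weight

torch.testing.assert_close(rms_norm(x), manual)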