8000 Update rnn.py, fix `torch.nn.RNN` document error by AIboy996 · Pull Request #153620 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Update rnn.py, fix torch.nn.RNN document error #153620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

AIboy996
Copy link
@AIboy996 AIboy996 commented May 15, 2025

I found the same issue as #147490 (@jibril-b-coulibaly).

There's an equivalent in the doc-string of torch.nn.RNN:

# Efficient implementation equivalent to the following with bidirectional=False
def forward(x, hx=None):
    if batch_first:
        x = x.transpose(0, 1)
    seq_len, batch_size, _ = x.size()
    if hx is None:
        hx = torch.zeros(num_layers, batch_size, hidden_size)
    h_t_minus_1 = hx
    h_t = hx
    output = []
    for t in range(seq_len):
        for layer in range(num_layers):
            h_t[layer] = torch.tanh(
                x[t] @ weight_ih[layer].T
                + bias_ih[layer]
                + h_t_minus_1[layer] @ weight_hh[layer].T
                + bias_hh[layer]
            )
        output.append(h_t[-1])
        h_t_minus_1 = h_t
    output = torch.stack(output)
    if batch_first:
        output = output.transpose(0, 1)
    return output, h_t

However there's something wrong.

  1. Like mentioned in Documentation: fix RNN example for multiple layers #147490, line 499 is wrong

x[t] @ weight_ih[layer].T

The input for RNNCell should be different for different layers.

  1. The code contains several hidden reference-related issues that may result in unintended modifications to tensors. For example in line 504, this causes all elements in the final output list to point to the same tensor.

output.append(h_t[-1])

  1. Some variable is not defined. Despite being a relatively minor issue in annotation, it can lead to significant confusion for those who are new to the concept. For example weight_ih in line 499

x[t] @ weight_ih[layer].T

So, i write a runnable version to make it more clear:

# Efficient implementation equivalent to the following with bidirectional=False
rnn = nn.RNN(input_size, hidden_size, num_layers)
params = dict(rnn.named_parameters())
def forward(x, hx=None, batch_first=False):
    if batch_first:
        x = x.transpose(0, 1)
    seq_len, batch_size, _ = x.size()
    if hx is None:
        hx = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size)
    h_t_minus_1 = hx.clone()
    h_t = hx.clone()
    output = []
    for t in range(seq_len):
        for layer in range(rnn.num_layers):
            input_t = x[t] if layer == 0 else h_t[layer - 1]
            h_t[layer] = torch.tanh(
                input_t @ params[f"weight_ih_l{layer}"].T
                + h_t_minus_1[layer] @ params[f"weight_hh_l{layer}"].T
                + params[f"bias_hh_l{layer}"]
                + params[f"bias_ih_l{layer}"]
            )
        output.append(h_t[-1].clone())
        h_t_minus_1 = h_t.clone()
    output = torch.stack(output)
    if batch_first:
        output = output.transpose(0, 1)
    return output, h_t

This code can reproduce the computation of torch.nn.RNN.

For example:

import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, num_layers = 3, 5, 2
rnn = nn.RNN(input_size, hidden_size, num_layers)
params = dict(rnn.named_parameters())
x = torch.randn(10, 4, 3)


official_imp = rnn(x)
my_imp = forward(x)

assert torch.allclose(official_imp[0], my_imp[0])
assert torch.allclose(official_imp[1], my_imp[1])

cc @svekars @sekyondaMeta @AlannaBurke @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

Copy link
pytorch-bot bot commented May 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153620

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit eadd6c0 with merge base 7482eb2 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
linux-foundation-easycla bot commented May 15, 2025

CLA Signed


The committers listed above are authorized under a signed CLA.

@AIboy996
Copy link
Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label May 15, 2025
@AIboy996
Copy link
Author

@pytorchbot label "documentation"

Copy link
pytorch-bot bot commented May 15, 2025

Didn't find following labels among repository labels: documentation

@AIboy996
Copy link
Author

@pytorchbot label "module: docs"

@pytorch-bot pytorch-bot bot added the module: docs Related to our documentation, both in docs/ and docblocks label May 15, 2025
@AIboy996
Copy link
Author

@pytorchbot label "module: nn"

@pytorch-bot pytorch-bot bot added the module: nn Related to torch.nn label May 15, 2025
@albanD albanD removed their request for review May 15, 2025 14:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: docs Related to our documentation, both in docs/ and docblocks module: nn Related to torch.nn open source topic: not user facing topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants
0