Inconsistent gradient from different backend of CTCLoss #26797

Closed

Alexander-H-Liu opened this issue Sep 25, 2019 · 2 comments
Labels
module: cudnn (Related to torch.backends.cudnn and CuDNN support)
module: loss (Problem is related to loss function)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

Alexander-H-Liu commented Sep 25, 2019

🐛 Bug

I'm switching to the cuDNN backend of CTCLoss because the native backend is not fully reproducible. However, the exact same model that used to train fine with PyTorch's native CUDA CTC loss now fails. With a simple example, I found a huge difference in the gradients (both direction and magnitude) between the two backends. I'm not sure whether the bug is in PyTorch or in cuDNN, but as far as I know, TensorFlow also uses cuDNN's CTC and does not show a similar issue. Thanks in advance.

To Reproduce

import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

batch_size = 16
seq_len = 50
target_len = seq_len//2
latent_dim = 128
vocab_size = 50

x = torch.randn((seq_len,batch_size,latent_dim)).cuda()
y = torch.randint(1,vocab_size,(batch_size*target_len,),dtype=torch.long).cuda()
x_len = seq_len*torch.ones((batch_size,),dtype=torch.long).cuda()
y_len = target_len*torch.ones((batch_size,),dtype=torch.long).cuda()
w = torch.nn.Linear(latent_dim,vocab_size).cuda()

def compute_ctc(x,y,x_len,y_len,use_cudnn):
    for p in w.parameters():
        if p.grad is not None:
            p.grad.zero_()
    # Forward
    output = w(x).log_softmax(dim=-1)
    if use_cudnn:
        # int32 targets/lengths (moved to CPU here) route to the cuDNN CTC implementation
        loss = torch.nn.functional.ctc_loss(output,
                                            y.to('cpu',torch.int32),
                                            x_len.to('cpu',torch.int32),
                                            y_len.to('cpu',torch.int32))
    else:
        # int64 CUDA targets/lengths use PyTorch's native CUDA implementation
        loss = torch.nn.functional.ctc_loss(output,y,x_len,y_len)
    # backward
    loss.backward()
    m, b = w.parameters()
    print('loss = {}\ngrad_norm = {}'.format(loss, m.grad.view(-1).norm()))
    return m.grad.clone()

print("===== Pytorch CTC =====")
torch_grad = compute_ctc(x,y,x_len,y_len,False)
print("===== Cudnn CTC =====")
cudnn_grad = compute_ctc(x,y,x_len,y_len,True)
print("===== Grad diff. =====")
print("Cos Sim. = ",torch.nn.functional.cosine_similarity(torch_grad.view(-1),cudnn_grad.view(-1),dim=0))
print("Magnitude  = ",cudnn_grad.view(-1).norm() / torch_grad.view(-1).norm())

Behavior

===== Pytorch CTC =====
loss = 6.073373794555664
grad_norm = 0.4746553599834442
===== Cudnn CTC =====
loss = 6.073373794555664
grad_norm = 10431.4375
===== Grad diff. =====
Cos Sim. =  tensor(0.6508, device='cuda:0')
Magnitude  =  tensor(21976.8672, device='cuda:0')
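
One way to narrow this down (a diagnostic sketch, not part of the original report) is to compare the gradients with respect to the log-probabilities directly, for both reduction='mean' and reduction='sum'; this separates a disagreement in the per-frame gradients from a problem with how the reduction factor is propagated. The snippet below reuses x, y, x_len, y_len, and w from the reproduction above.

def logprob_grad(use_cudnn, reduction):
    output = w(x).log_softmax(dim=-1)
    if use_cudnn:
        loss = torch.nn.functional.ctc_loss(output,
                                            y.to('cpu', torch.int32),
                                            x_len.to('cpu', torch.int32),
                                            y_len.to('cpu', torch.int32),
                                            reduction=reduction)
    else:
        loss = torch.nn.functional.ctc_loss(output, y, x_len, y_len,
                                            reduction=reduction)
    # gradient with respect to the log-probabilities, bypassing the Linear layer
    return torch.autograd.grad(loss, output)[0]

for reduction in ('mean', 'sum'):
    g_native = logprob_grad(False, reduction)
    g_cudnn = logprob_grad(True, reduction)
    print(reduction, (g_native - g_cudnn).abs().max().item())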

Environment

PyTorch version: 1.2.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: 
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti

Nvidia driver version: 430.50
cuDNN version: /usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7

Versions of relevant libraries:
[pip3] numpy==1.17.2
[pip3] torch==1.2.0
[pip3] torchaudio==0.3.0
[pip3] torchvision==0.4.0
[conda] Could not collect

cc @ezyang @gchanan @zou3519

Alexander-H-Liu changed the title from "Inconsistent gradient from the different backend of CTCLoss" to "Inconsistent gradient from different backend of CTCLoss" on Sep 25, 2019
yf225 (Contributor) commented Sep 26, 2019

Likely related issues:
#25833
#17798
#12201

I can reproduce this. @t-vi Do you know if this difference is expected?

yf225 added the module: loss and triage review labels on Sep 26, 2019
t-vi added a commit to t-vi/pytorch that referenced this issue Sep 29, 2019
t-vi (Collaborator) commented Sep 29, 2019

I think this and #25833 are the same; the others I'm not so sure about.

vincentqb added the module: cudnn label and removed the triage review label on Sep 30, 2019
ailzhang added the high priority and triaged labels and removed the high priority label on Oct 1, 2019
zdevito pushed a commit to zdevito/ATen that referenced this issue Oct 15, 2019
Summary:
Using grad_out for CuDNN CTC loss fixes: pytorch/pytorch#26797, pytorch/pytorch#25833.

We also fix an incompatible cuDNN change that surfaced during testing: as of CuDNN 7.6, the semantics of the CTC loss gradients are different.
This leads us to disable CuDNN CTC for CuDNN < 7.6. To mitigate the impact on users, we convert the parameters and fall back to the native implementation if CuDNN isn't applicable (previously this would give an error). A rough sketch of the gradient change follows the commit message below.
Pull Request resolved: pytorch/pytorch#27039

Differential Revision: D17910815

Pulled By: ngimel

fbshipit-source-id: 465b33612d3402f10c355aa7026a7e1ffaef3073
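
A rough Python sketch of the grad_out part of this fix (an illustration only; the actual change is in the ATen C++ backward touched by pytorch/pytorch#27039): the backward of the cuDNN CTC loss scales the raw per-frame gradient buffer produced by cuDNN by the incoming grad_out, instead of returning it unscaled, so reduction='mean' and any upstream chain-rule factor are applied.

import torch

class CudnnCTCGradSketch(torch.autograd.Function):
    # Hypothetical stand-in for the real op: 'raw_grad' plays the role of the
    # gradient buffer that cuDNN's CTC kernel computes during the forward pass.
    @staticmethod
    def forward(ctx, log_probs, neg_log_likelihood, raw_grad):
        ctx.save_for_backward(raw_grad)
        return neg_log_likelihood          # per-sample loss, shape (N,)

    @staticmethod
    def backward(ctx, grad_out):
        (raw_grad,) = ctx.saved_tensors    # shape (T, N, C)
        # The fix: scale by the incoming per-sample gradient instead of
        # returning raw_grad unchanged.
        grad_log_probs = raw_grad * grad_out.view(1, -1, 1)
        return grad_log_probs, None, None

With the default reduction='mean', grad_out carries the 1/target_len and 1/batch factors, so leaving it out would inflate the gradient magnitude considerably, consistent with the huge grad_norm reported above.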
thiagocrepaldi pushed a commit to thiagocrepaldi/pytorch that referenced this issue Feb 4, 2020