Inconsistent gradient from different backend of CTCLoss #26797
Closed
@Alexander-H-Liu

Description

🐛 Bug

I'm switching to the cuDNN backend of CTCLoss since the native CUDA backend is not fully reproducible.
However, it turns out that the exact same model that used to work with PyTorch's native CUDA CTC loss now fails to train.
With a simple example, I found that there is a huge difference in gradient (both direction and magnitude) between the two backends.
I'm not sure whether the bug is in PyTorch or in cuDNN, but as far as I know, TensorFlow also uses the CTC from cuDNN and has no similar issue.
Thanks in advance.
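
For reference, my understanding (from the torch.nn.functional.ctc_loss docs; treat the exact rules as my reading rather than a guarantee) is that the cuDNN kernel is only selected when blank == 0, every input length equals the full time dimension T, every target length is at most 256, and the integer arguments are torch.int32 (in my repro below I also keep them on the CPU); otherwise the native implementation with torch.long targets runs. A hypothetical helper packaging that conversion:

import torch

def to_cudnn_ctc_args(y, x_len, y_len):
    # Convert targets and lengths to the int32/CPU form that (as I read the
    # docs) makes ctc_loss eligible for the cuDNN backend; blank must stay 0
    # and all input lengths must equal T.
    return (y.to('cpu', torch.int32),
            x_len.to('cpu', torch.int32),
            y_len.to('cpu', torch.int32))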

To Reproduce

import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

batch_size = 16
seq_len = 50
target_len = seq_len//2
latent_dim = 128
vocab_size = 50

x = torch.randn((seq_len, batch_size, latent_dim)).cuda()   # (T, N, C) activations
y = torch.randint(1, vocab_size, (batch_size * target_len,), dtype=torch.long).cuda()  # concatenated targets; 0 is reserved for blank
x_len = seq_len * torch.ones((batch_size,), dtype=torch.long).cuda()     # every input uses the full length T
y_len = target_len * torch.ones((batch_size,), dtype=torch.long).cuda()  # every target has length T//2
w = torch.nn.Linear(latent_dim, vocab_size).cuda()

def compute_ctc(x, y, x_len, y_len, use_cudnn):
    # Zero any existing gradients so the two runs don't accumulate.
    for p in w.parameters():
        if p.grad is not None:
            p.grad.zero_()
    # Forward
    output = w(x).log_softmax(dim=-1)
    if use_cudnn:
        # The cuDNN kernel requires int32 targets/lengths on the CPU.
        loss = torch.nn.functional.ctc_loss(output,
                                            y.to('cpu', torch.int32),
                                            x_len.to('cpu', torch.int32),
                                            y_len.to('cpu', torch.int32))
    else:
        # The native CUDA kernel takes the torch.long tensors directly.
        loss = torch.nn.functional.ctc_loss(output, y, x_len, y_len)
    # Backward
    loss.backward()
    m, b = w.parameters()
    print('loss = {}\ngrad_norm = {}'.format(loss.item(), m.grad.view(-1).norm().item()))
    return m.grad.clone()

print("===== Pytorch CTC =====")
torch_grad = compute_ctc(x,y,x_len,y_len,False)
print("===== Cudnn CTC =====")
cudnn_grad = compute_ctc(x,y,x_len,y_len,True)
print("===== Grad diff. =====")
print("Cos Sim. = ",torch.nn.functional.cosine_similarity(torch_grad.view(-1),cudnn_grad.view(-1),dim=0))
print("Magnitude  = ",cudnn_grad.view(-1).norm() / torch_grad.view(-1).norm())

Behavior

===== Pytorch CTC =====
loss = 6.073373794555664
grad_norm = 0.4746553599834442
===== Cudnn CTC =====
loss = 6.073373794555664
grad_norm = 10431.4375
===== Grad diff. =====
Cos Sim. =  tensor(0.6508, device='cuda:0')
Magnitude  =  tensor(21976.8672, device='cuda:0')
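
Note that both backends return the identical loss (6.073373794555664), so the forward pass agrees and the discrepancy is confined to the backward pass; the reported magnitude ratio also matches the two gradient norms (10431.4375 / 0.4747 ≈ 21977).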

Environment

PyTorch version: 1.2.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: 
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti

Nvidia driver version: 430.50
cuDNN version: /usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7

Versions of relevant libraries:
[pip3] numpy==1.17.2
[pip3] torch==1.2.0
[pip3] torchaudio==0.3.0
[pip3] torchvision==0.4.0
[conda] Could not collect

cc @ezyang @gchanan @zou3519

Labels: module: cudnn · module: loss · triaged