Use grad_out for cudnn CTC loss (#27039)
Summary:
Using grad_out for the CuDNN CTC loss fixes #26797 and #25833.

We also fix an incompatible CuDNN change that surfaced during testing: as of CuDNN 7.6, the semantics of the CTC loss gradients have changed.
This leads us to disable the CuDNN CTC loss for CuDNN < 7.6. To mitigate the impact on users, we convert the parameters and fall back to the native implementation when CuDNN isn't applicable (previously this raised an error).
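
To illustrate the user-visible semantics, here is a minimal sketch (not part of this PR; it assumes a CUDA build with CuDNN >= 7.6 and mirrors the new test below):

import torch
import torch.nn.functional as F

# Per-sample losses on the CuDNN path: float log-probs, 1-d int32 CPU
# targets, blank=0, and every input length equal to the time dimension.
log_probs = torch.randn(50, 16, 20, device='cuda').log_softmax(2).requires_grad_()
targets = torch.randint(1, 20, (16 * 30,), dtype=torch.int32)
loss = F.ctc_loss(log_probs, targets, [50] * 16, [30] * 16, reduction='none')

# Backward with a non-trivial grad_out: before this fix, the CuDNN path
# ignored grad_out and returned the raw gradient (#26797, #25833); now the
# raw gradient is scaled per sample by grad_out.
grad_out = torch.randn_like(loss)
grad, = torch.autograd.grad(loss, log_probs, grad_out)
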
Pull Request resolved: #27039

Differential Revision: D17910815

Pulled By: ngimel

fbshipit-source-id: 465b33612d3402f10c355aa7026a7e1ffaef3073
t-vi authored and facebook-github-bot committed Oct 15, 2019
1 parent 7e8420b commit f461184
Showing 8 changed files with 138 additions and 47 deletions.
9 changes: 9 additions & 0 deletions aten/src/ATen/cudnn/Descriptors.h
@@ -260,6 +260,15 @@ struct AT_CUDA_API CTCLossDescriptor
   void set(cudnnDataType_t datatype) {
     AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
   }
+#if CUDNN_VERSION >= 7600
+  void setEx(
+      cudnnDataType_t datatype,
+      cudnnLossNormalizationMode_t normMode,
+      cudnnNanPropagation_t gradMode) {
+    AT_CUDNN_CHECK(
+        cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
+  }
+#endif
 };
 
 union Constant
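
A note on why the new CUDNN_LOSS_NORMALIZATION_SOFTMAX mode (configured via setEx below) can consume PyTorch's log-probabilities directly: log_softmax is idempotent, so cuDNN's internal normalization is a no-op on inputs that are already normalized log-probs. A quick check (my illustration, not part of the diff):

import torch

x = torch.randn(50, 16, 20)
log_probs = x.log_softmax(2)
# exp(log_probs) already sums to one over the class dimension, so another
# (log_)softmax pass changes nothing beyond floating-point error.
assert torch.allclose(log_probs, log_probs.log_softmax(2), atol=1e-6)
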
38 changes: 14 additions & 24 deletions aten/src/ATen/native/LossCTC.cpp
@@ -339,36 +339,26 @@ Tensor ctc_loss_backward_cpu(const Tensor& grad, const Tensor& log_probs, ...)
 // the gradient is implemented for _cudnn_ctc_loss (just in derivatives.yaml) and _ctc_loss and this function has automatic gradients
 // it also handles the reduction if desired
 Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, int64_t reduction, bool zero_infinity) {
-  auto& ctx = at::globalContext();
-
   bool use_cudnn =
-      detail::getCUDAHooks().compiledWithCuDNN() &&
-      (detail::getCUDAHooks().versionCuDNN() >= 7000) &&
-      ctx.userEnabledCuDNN() &&
-      (BLANK == 0) && (targets.dim()==1) &&
-      (log_probs.scalar_type() == at::kFloat) &&
-      (targets.scalar_type() == at::kInt) &&
-      (log_probs.type().backend() == Backend::CUDA);
-
-  if (use_cudnn) {
-    // we don't know that input_lengths and target_lengths have the same size (they should, but we didn't check yet)
-    int64_t max_input_length = log_probs.size(0);
-    for (size_t b = 0; b < input_lengths.size(); b++) {
-      use_cudnn &= (input_lengths[b] == max_input_length);
-    }
-    for (size_t b = 0; b < target_lengths.size(); b++) {
-      // target length < 256 is documented, but we see illegal memory accesses when target lengths > input lengths for CuDNN
-      use_cudnn &= (target_lengths[b] <= 256) & (target_lengths[b] <= input_lengths[b]);
-    }
-  }
+      (log_probs.device().type() == at::kCUDA) &&
+      at::_use_cudnn_ctc_loss(
+          log_probs, targets, input_lengths, target_lengths, BLANK);
 
   Tensor res;
   if (use_cudnn) {
     // non-deterministic ctc loss on cudnn disabled due to inconsistent results
     // see: https://github.com/pytorch/pytorch/issues/21680
     res = std::get<0>(at::_cudnn_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK, /*deterministic=*/true, zero_infinity));
   } else {
-    res = std::get<0>(at::_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK, zero_infinity));
+    // if the targets are on CPU (as they must be for CuDNN), move them to
+    // the GPU as a service for the user
+    res = std::get<0>(at::_ctc_loss(
+        log_probs,
+        targets.to(log_probs.device(), kLong),
+        input_lengths,
+        target_lengths,
+        BLANK,
+        zero_infinity));
     if (zero_infinity) {
       res = at::where(res == Scalar(std::numeric_limits<double>::infinity()), at::zeros({}, res.options()), res);
     }
@@ -388,8 +378,8 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, const Tensor& input_lengths, ...)
   TORCH_CHECK(isIntegralType(input_lengths.scalar_type(), /*includeBool=*/false), "input_lengths must be integral");
   TORCH_CHECK(isIntegralType(target_lengths.scalar_type(), /*includeBool=*/false), "target_lengths must be integral");
 
-  Tensor ilc = input_lengths.toType(kLong).toBackend(Backend::CPU).contiguous();
-  Tensor tlc = target_lengths.toType(kLong).toBackend(Backend::CPU).contiguous();
+  Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
+  Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
   IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
   IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
   return at::native::ctc_loss(log_probs, targets, il, tl, BLANK, reduction, zero_infinity);
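
The fallback above means a call that previously errored now degrades gracefully. A hedged sketch of the intent (the exact prior error depended on the argument combination):

import torch
import torch.nn.functional as F

log_probs = torch.randn(50, 16, 20, device='cuda').log_softmax(2).requires_grad_()
targets = torch.randint(1, 20, (16 * 30,), dtype=torch.int32)  # CPU, int32
# Variable input lengths fail the CuDNN check (every length must equal
# log_probs.size(0)), so ctc_loss falls through to the native kernel; the
# wrapper now moves the CPU targets to the CUDA device as int64 for the
# user instead of erroring as before.
input_lengths = [50] * 8 + [40] * 8
loss = F.ctc_loss(log_probs, targets, input_lengths, [30] * 16, reduction='none')
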
91 changes: 71 additions & 20 deletions aten/src/ATen/native/cudnn/LossCTC.cpp
@@ -6,13 +6,21 @@
 #include <ATen/cudnn/Descriptors.h>
 #endif
 
-#if !AT_CUDNN_ENABLED()
+#if (!AT_CUDNN_ENABLED()) || (CUDNN_VERSION < 7600)
 
 namespace at { namespace native {
 
 // See Note [ATen preprocessor philosophy]
 
+bool _use_cudnn_ctc_loss(
+    const Tensor& log_probs,
+    const Tensor& targets,
+    IntArrayRef input_lengths,
+    IntArrayRef target_lengths,
+    int64_t BLANK) {
+  return false;
+}
+
 std::tuple<Tensor, Tensor> _cudnn_ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool deterministic, bool zero_infinity) {
   AT_ERROR("cudnn_ctc_loss: ATen not compiled with cuDNN >= 7 support");
 }
@@ -29,9 +37,35 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss(const Tensor& log_probs, const Tensor& targets, ...)
 
 namespace at { namespace native {
 
-namespace {
-
-} // namespace
+bool _use_cudnn_ctc_loss(
+    const Tensor& log_probs,
+    const Tensor& targets,
+    IntArrayRef input_lengths,
+    IntArrayRef target_lengths,
+    int64_t BLANK) {
+  auto& ctx = at::globalContext();
+
+  bool use_cudnn = ctx.userEnabledCuDNN() && (BLANK == 0) &&
+      (targets.dim() == 1) && (log_probs.scalar_type() == at::kFloat) &&
+      (targets.scalar_type() == at::kInt) &&
+      (log_probs.device().type() == at::kCUDA);
+
+  if (use_cudnn) {
+    // we don't know that input_lengths and target_lengths have the same size
+    // (they should, but we didn't check yet)
+    int64_t max_input_length = log_probs.size(0);
+    for (size_t b = 0; b < input_lengths.size(); b++) {
+      use_cudnn &= (input_lengths[b] == max_input_length);
+    }
+    for (size_t b = 0; b < target_lengths.size(); b++) {
+      // target length < 256 is documented, but we see illegal memory accesses
+      // when target lengths > input lengths for CuDNN
+      use_cudnn &=
+          (target_lengths[b] <= 256) & (target_lengths[b] <= input_lengths[b]);
+    }
+  }
+  return use_cudnn;
+}
 
 std::tuple<Tensor, Tensor> _cudnn_ctc_loss(const Tensor& log_probs_t, const Tensor& targets_t, IntArrayRef input_lengths_, IntArrayRef target_lengths_, int64_t BLANK, bool deterministic, bool zero_infinity) {
   (void)zero_infinity; // only used for backward
@@ -62,28 +96,45 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss(const Tensor& log_probs_t, const Tensor& targets_t, ...)
 
   cudnnCTCLossAlgo_t algo = (deterministic ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC);
 
-  Tensor probs = log_probs->softmax(2);
-  TensorDescriptor probs_desc{probs};
-  Tensor grad = at::empty_like(probs);
-  TensorDescriptor grad_desc{grad};
-
   CTCLossDescriptor ctc_loss_desc;
-  ctc_loss_desc.set(CUDNN_DATA_FLOAT);
 
-  size_t workspace_size;
-  AT_CUDNN_CHECK(cudnnGetCTCLossWorkspaceSize(handle, probs_desc.desc(), grad_desc.desc(),
-                                              targets->data_ptr<int>(), target_lengths.data(), input_lengths.data(),
-                                              algo, ctc_loss_desc.desc(), &workspace_size));
+  // The CuDNN gradient semantics changed between 7.1 and 7.6; this code is
+  // for CuDNN 7.6 only (see PyTorch 1.2 for older CuDNN).
+  ctc_loss_desc.setEx(
+      CUDNN_DATA_FLOAT, CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN);
+  TensorDescriptor log_probs_desc{log_probs_t};
+  Tensor grad = at::empty_like(log_probs_t);
+  TensorDescriptor grad_desc{grad};
+
+  size_t workspace_size;
+  AT_CUDNN_CHECK(cudnnGetCTCLossWorkspaceSize(
+      handle,
+      log_probs_desc.desc(),
+      grad_desc.desc(),
+      targets->data_ptr<int>(),
+      target_lengths.data(),
+      input_lengths.data(),
+      algo,
+      ctc_loss_desc.desc(),
+      &workspace_size));
 
   Tensor workspace = at::empty(workspace_size, log_probs->options().dtype(kByte));
   Tensor costs = at::empty({log_probs->size(1)}, log_probs->options());
 
-  AT_CUDNN_CHECK(cudnnCTCLoss(handle, probs_desc.desc(), probs.data_ptr(),
-                              targets->data_ptr<int>(), target_lengths.data(), input_lengths.data(),
-                              costs.data_ptr(), grad_desc.desc(), grad.data_ptr(), algo,
-                              ctc_loss_desc.desc(), workspace.data_ptr(), workspace_size));
-
+  AT_CUDNN_CHECK(cudnnCTCLoss(
+      handle,
+      log_probs_desc.desc(),
+      log_probs_t.data_ptr(),
+      targets->data_ptr<int>(),
+      target_lengths.data(),
+      input_lengths.data(),
+      costs.data_ptr(),
+      grad_desc.desc(),
+      grad.data_ptr(),
+      algo,
+      ctc_loss_desc.desc(),
+      workspace.data_ptr(),
+      workspace_size));
   return std::make_tuple(costs, grad);
 }
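
For reference, the gating in _use_cudnn_ctc_loss corresponds to roughly the following Python-level predicate (an illustrative mirror with made-up names, not a PyTorch API):

import torch

def would_use_cudnn(log_probs, targets, input_lengths, target_lengths, blank):
    ok = (torch.backends.cudnn.enabled
          and blank == 0
          and targets.dim() == 1
          and log_probs.dtype == torch.float32
          and targets.dtype == torch.int32
          and log_probs.is_cuda)
    if ok:
        max_input_length = log_probs.size(0)
        # every input must span the full time dimension; documented CuDNN
        # limits cap target lengths at 256, and target length <= input
        # length guards against the observed illegal memory accesses
        ok = (all(il == max_input_length for il in input_lengths)
              and all(tl <= 256 and tl <= il
                      for tl, il in zip(target_lengths, input_lengths)))
    return ok
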
6 changes: 6 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -94,6 +94,12 @@
   variants: method
   supports_named_tensor: True
 
+
+- func: _use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool
+  use_c10_dispatcher: unboxed_only
+  dispatch:
+    CUDA: _use_cudnn_ctc_loss
+
 - func: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
   use_c10_dispatcher: unboxed_only
   dispatch:
27 changes: 26 additions & 1 deletion test/test_autograd.py
                          S)
 from common_device_type import (instantiate_device_type_tests, skipCUDAIfRocm,
                                 onlyCUDA, dtypes, dtypesIfCUDA,
-                                deviceCountAtLeast)
+                                deviceCountAtLeast, skipCUDAIfCudnnVersionLessThan)
 
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -3732,6 +3732,31 @@ def ctc_after_softmax(x):
 
         gradcheck(ctc_after_softmax, [x], nondet_tol=1e-7)
 
+    @onlyCUDA
+    @skipCUDAIfRocm
+    @skipCUDAIfCudnnVersionLessThan(7600)
+    def test_ctc_loss_cudnn(self, device):
+        batch_size = 16
+        input_length = 30
+        num_labels = 101
+        target_length = 15
+        targets = torch.randint(1, num_labels, (batch_size * target_length,),
+                                device='cuda', dtype=torch.long)
+        log_probs = torch.log_softmax(torch.randn(input_length, batch_size, num_labels, device='cuda', dtype=torch.float), 2)
+        log_probs.requires_grad_()
+
+        input_lengths = batch_size * [input_length]
+        target_lengths = batch_size * [target_length]
+        grad_out = torch.randn(batch_size, device='cuda', dtype=torch.float)
+        with torch.backends.cudnn.flags(enabled=False):
+            loss_native = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
+            grad_native, = torch.autograd.grad(loss_native, log_probs, grad_out)
+        loss_cudnn = torch.nn.functional.ctc_loss(log_probs, targets.to('cpu', torch.int32),
+                                                  input_lengths, target_lengths, reduction='none')
+        self.assertTrue("Cudnn" in str(loss_cudnn.grad_fn))
+        grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
+        self.assertEqual(grad_cudnn, grad_native, prec=1e-4)
+
     @onlyCUDA
     def test_free_unneeded_tensor(self, device):
         x = torch.randn(2, 3, 10, 10, device=device, requires_grad=True)
2 changes: 1 addition & 1 deletion test/test_nn.py
@@ -8949,7 +8949,7 @@ def test_rnn_retain_variables(self, device, dtype):
         self._test_rnn_retain_variables(device, dtype)
 
     @onlyCUDA
-    @skipCUDAIfCudnnVersionLessThan(7000)
+    @skipCUDAIfCudnnVersionLessThan(7600)
     def test_CTCLoss_cudnn(self, device):
         target_lengths = [30, 25, 20]
         input_lengths = [50, 50, 50]
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
@@ -1382,7 +1382,7 @@
 
 # cudnn
 - name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
-  log_probs: "zero_infinity ? where(result0.unsqueeze(0).unsqueeze(2) == 0, zeros_like(result1), result1) : result1"
+  log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity)
 
 - name: cudnn_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
   self, weight, bias: cudnn_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask)
10 changes: 10 additions & 0 deletions tools/autograd/templates/Functions.cpp
@@ -2453,6 +2453,16 @@ Tensor index_backward(Tensor zeros_like_self, TensorList indices, const Tensor& grad)
   return at::_index_put_impl_(zeros_like_self, indices, grad, true, true);
 }
 
+Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) {
+  if (zero_infinity) {
+    return at::where(
+        loss.unsqueeze(0).unsqueeze(2) == 0,
+        at::zeros({}, raw_grad.options()),
+        raw_grad * grad_out.unsqueeze(0).unsqueeze(2));
+  } else {
+    return raw_grad * grad_out.unsqueeze(0).unsqueeze(2);
+  }
+}
 
 } // anonymous namespace
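
Shape-wise, the backward above broadcasts the incoming per-sample grad_out over the time and class dimensions of the raw [T, N, C] gradient returned by cuDNN. An equivalent Python sketch (not part of the diff):

import torch

def cudnn_ctc_loss_backward(grad_out, loss, raw_grad, zero_infinity):
    # raw_grad: [T, N, C]; grad_out and loss: [N]
    grad = raw_grad * grad_out.unsqueeze(0).unsqueeze(2)
    if zero_infinity:
        # samples whose loss was zeroed by zero_infinity get zero gradient too
        zero = torch.zeros((), dtype=raw_grad.dtype, device=raw_grad.device)
        grad = torch.where(loss.unsqueeze(0).unsqueeze(2) == 0, zero, grad)
    return grad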
