[CUDA][cuDNN] Fix handling of `CPU` side input and target length tensors in CTCLoss · pytorch/pytorch@cecfc7d · GitHub

Commit cecfc7d

eqy authored and pytorchmergebot committed
[CUDA][cuDNN] Fix handling of CPU side input and target length tensors in CTCLoss (#152745)
#128271 migrated to cuDNN V8 CTCLoss, which expects the input and target length tensors to be on `CUDA` rather than `CPU`, without adding logic to account for the edge case of them being on `CPU`. See also #152421.

Pull Request resolved: #152745
Approved by: https://github.com/Skylion007
1 parent 773a91c commit cecfc7d
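
For context, a minimal sketch of the edge case being fixed, assuming a CUDA build with cuDNN available (shapes are illustrative and mirror the new test added below): `log_probs` on the GPU selects the cuDNN V8 kernel, while the int32 length tensors remain on the CPU.

import torch
import torch.nn.functional as F

T, N, C, S = 100, 50, 80, 10  # input length, batch size, classes, max target length

# Log-probabilities on CUDA dispatch to the cuDNN CTCLoss path...
log_probs = torch.randn(T, N, C).log_softmax(2).cuda()

# ...but the int32 length tensors stay on the CPU, the case #128271 missed.
input_lengths = torch.full((N,), T, dtype=torch.int32)
target_lengths = torch.randint(1, S, (N,), dtype=torch.int32)
targets = torch.randint(0, C, (int(target_lengths.sum()),), dtype=torch.int32)

# Before this fix, the cuDNN V8 path read the CPU length pointers as if they
# were device pointers; with the fix the lengths are copied to CUDA first.
loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction="sum")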

File tree

2 files changed: +51 -5 lines changed

aten/src/ATen/native/cudnn/LossCTC.cpp
test/test_nn.py

aten/src/ATen/native/cudnn/LossCTC.cpp (+20 -4)
@@ -151,6 +151,13 @@ bool _use_cudnn_ctc_loss_tensor(
       }
     }
   } else {
+    if (target_lengths.device().type() != at::kCUDA ||
+        input_lengths.device().type() != at::kCUDA) {
+      TORCH_CHECK(
+          false,
+          "CTCLoss cannot be graph captured with CPU length tensors. "
+          "Move CPU length tensors to GPU memory to enable graph capture.");
+    }
     at::_assert_async(at::lt(input_lengths.max(), 256));
     at::_assert_async(at::le(target_lengths, input_lengths).all());
   }
@@ -253,9 +260,18 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
     bool deterministic,
     bool zero_infinity) {
   Tensor targets_t_ = targets_t;
+  Tensor input_lengths_ = input_lengths;
+  Tensor target_lengths_ = target_lengths;
   if (targets_t.device().type() == at::kCPU) {
     targets_t_ = targets_t.to(Device(at::kCUDA));
   }
+  if (input_lengths.device().type() == at::kCPU) {
+    input_lengths_ = input_lengths.to(Device(at::kCUDA));
+  }
+  if (target_lengths.device().type() == at::kCPU) {
+    target_lengths_ = target_lengths.to(Device(at::kCUDA));
+  }
+
   const CheckedFrom c = "cudnn_ctc_loss";
   const TensorArg log_probs{log_probs_t, "log_probs", 1};
   const TensorArg targets{targets_t_, "targets", 2};
@@ -268,9 +284,9 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
   checkBackend(c, {*targets}, Backend::CUDA);
   const auto batch_size = log_probs->size(1);
   int64_t input_lengths_size =
-      input_lengths.sizes().size() ? input_lengths.size(0) : 1;
+      input_lengths_.sizes().size() ? input_lengths_.size(0) : 1;
   int64_t target_lengths_size =
-      target_lengths.sizes().size() ? target_lengths.size(0) : 1;
+      target_lengths_.sizes().size() ? target_lengths_.size(0) : 1;
   TORCH_CHECK(
       input_lengths_size == batch_size,
       "input_lengths needs to have size to match batch_size");
@@ -319,8 +335,8 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
       log_probs_desc.desc(),
       log_probs_t.data_ptr(),
       targets_t_.data_ptr<int>(),
-      target_lengths.data_ptr<int>(),
-      input_lengths.data_ptr<int>(),
+      target_lengths_.data_ptr<int>(),
+      input_lengths_.data_ptr<int>(),
       costs.data_ptr(),
       grad_desc.desc(),
       grad.data_ptr(),
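
The TORCH_CHECK added in `_use_cudnn_ctc_loss_tensor` guards the CUDA-graph-capture branch: the implicit host-to-device copy that `_cudnn_ctc_loss_tensor` now performs cannot run inside a captured region, so the length tensors must already live on the GPU before capture begins. A sketch of the capture-safe pattern follows; it is illustrative only (whether the cuDNN kernel is actually selected also depends on the other eligibility checks in `_use_cudnn_ctc_loss_tensor`) and uses the standard warm-up-then-capture idiom from the PyTorch CUDA Graphs docs.

import torch
import torch.nn.functional as F

T, N, C, S = 100, 50, 80, 10

log_probs = torch.randn(T, N, C, device="cuda").log_softmax(2)

# Move the length tensors to the GPU before any capture starts; inside a
# captured region the implicit CPU->CUDA copy would be illegal, which is
# exactly what the TORCH_CHECK added above reports.
input_lengths = torch.full((N,), T, dtype=torch.int32, device="cuda")
target_lengths = torch.randint(1, S, (N,), dtype=torch.int32, device="cuda")
targets = torch.randint(
    0, C, (int(target_lengths.sum().item()),), dtype=torch.int32, device="cuda"
)

# Warm up on a side stream, then capture, per the usual CUDA Graphs idiom.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction="sum")
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction="sum")
g.replay()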

test/test_nn.py (+31 -1)
@@ -11523,7 +11523,7 @@ def test_ctc_loss_cudnn(self, device):

     @onlyCUDA
     @skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
-    def test_ctc_loss_cudnn_tensor(self, device):
+    def test_ctc_loss_cudnn_tensor_cuda(self):
         batch_size = 16
         input_length = 30
         num_labels = 101
@@ -11549,6 +11549,36 @@ def test_ctc_loss_cudnn_tensor(self, device):
         grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
         self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)

+    @onlyCUDA
+    @skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
+    def test_ctc_loss_cudnn_tensor_cpu_length_cuda(self):
+        # batch size
+        N = 50
+        # audio length
+        T = 100
+        # text dimension
+        C = 80
+        # max text length
+        S = 10
+
+        prob_device = torch.device("cuda")
+        other_device = torch.device("cpu")
+        other_dtype = torch.int32
+
+        log_probs = torch.randn(T, N, C).log_softmax(2).to(prob_device)
+
+        input_lengths = torch.full((N,), T, dtype=other_dtype).to(other_device)
+        target_lengths = torch.randint(low=1, high=S, size=(N,), dtype=other_dtype).to(other_device)
+        targets = torch.randint(low=0, high=C, size=(sum(target_lengths),), dtype=other_dtype).to(other_device)
+
+        ctc_loss = torch.nn.functional.ctc_loss(
+            log_probs=log_probs,
+            targets=targets,
+            input_lengths=input_lengths,
+            target_lengths=target_lengths,
+            reduction="sum",
+        )
+
     @expectedFailureMPS
     def test_ctc_loss_error(self, device):
         log_probs = torch.rand(0, 0, 4, device=device)
