[CUDA] Fix missing `__syncthreads` in MultiMarginLoss backward (#158994) · pytorch/pytorch@8573a2b · GitHub

Commit 8573a2b

eqy authored and malfet committed
[CUDA] Fix missing __syncthreads in MultiMarginLoss backward (#158994)
Turns out the issue in #158921 is detectable with a simple unit test, and adding the missing sync fixes it.

Pull Request resolved: #158994
Approved by: https://github.com/malfet, https://github.com/Skylion007

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
1 parent 13398da commit 8573a2b

File tree

2 files changed, +20 −0 lines changed

aten/src/ATen/native/cuda/MultiMarginLoss.cu

Lines changed: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ __global__ void MultiMarginLoss_backward_kernel(
       gradInput_k[target_k] = static_cast<scalar_t>(gradInput_target_k);
     }
 
+    __syncthreads();
     for (int i=i_start; i<i_end; i+= i_step) {
       gradInput_k[i] *= *gradOutput_k;
     }
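
The added barrier closes an ordering hazard within one thread block: a single thread writes gradInput_k[target_k], and then every thread in the block runs the strided loop that multiplies gradInput_k[i] by *gradOutput_k, so the thread whose slice contains target_k can read the element before that write has landed. Below is a minimal, self-contained CUDA sketch of the same two-phase pattern (this is not the PyTorch kernel; the kernel, buffer, and parameter names are made up for illustration), with __syncthreads() in the position this commit adds it.

// race_sketch.cu -- illustrative only; names do not come from PyTorch.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scale_after_single_writer(float* out, int n, int target,
                                          float target_value, float scale) {
  // Phase 1: one thread produces out[target] (in the real kernel this is the
  // reduced gradient for the target class, gradInput_k[target_k]).
  if (threadIdx.x == 0) {
    out[target] = target_value;
  }

  // The one-line fix from this commit: make the write visible to every thread
  // in the block before the read-modify-write loop below touches out[target].
  __syncthreads();

  // Phase 2: each thread scales a strided range of the same buffer, mirroring
  //   gradInput_k[i] *= *gradOutput_k;
  // Without the barrier, the thread whose range contains `target` may scale
  // the stale value instead of the one written in phase 1.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    out[i] *= scale;
  }
}

int main() {
  const int n = 128, target = 77;
  float* d_out = nullptr;
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemset(d_out, 0, n * sizeof(float));

  scale_after_single_writer<<<1, 128>>>(d_out, n, target, 2.0f, 0.5f);

  float h_out[n];
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("out[%d] = %f (expected 1.0 with the barrier)\n", target, h_out[target]);

  cudaFree(d_out);
  return 0;
}

Compiled with nvcc and run as-is, out[77] comes back as 1.0; without the barrier the result can depend on scheduling, which is the kind of intermittent gradient mismatch the new unit test below checks for.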

test/test_nn.py

Lines changed: 19 additions & 0 deletions
@@ -9291,6 +9291,25 @@ def test_MarginLoss_empty(self, device, dtype):
         y = torch.ones(10, 0, device=device).type(torch.long)
         mod(x, y)
 
+    @onlyCUDA
+    @dtypes(torch.float, torch.double)
+    def test_MarginLoss_race(self, device, dtype):
+        loss = torch.nn.MultiMarginLoss().to(device)
+        batch = 1
+        classes = 128
+        x = torch.randn(batch, classes, requires_grad=True, device=device, dtype=dtype)
+        y = torch.randint(low=0, high=classes, size=(batch,), device=device, dtype=torch.long)
+        x_cpu = x.detach().clone().cpu()
+        y_cpu = y.detach().clone().cpu()
+        out = loss(x, y)
+        out.backward()
+        x_cpu = x.detach().clone().cpu()
+        x_cpu.requires_grad = True
+        y_cpu = y.detach().clone().cpu()
+        out_cpu = loss.cpu()(x_cpu, y_cpu)
+        out_cpu.backward()
+        self.assertEqual(x_cpu.grad, x.grad.cpu())
+
     @onlyCUDA
     def test_MarginLoss_warnings(self, device):
         model = torch.nn.Linear(128, 22, device=device)