Update base for Update on "Remove det_singular OpInfo" · pytorch/pytorch@8dbd7b2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8dbd7b2

Browse files
committed
Update base for Update on "Remove det_singular OpInfo"
Fixes #93045 #93044 From previous discussion #93045 (comment) the resolution is that we're okay with removing this. Some older attempts: - #102581 - #109249 cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov [ghstack-poisoned]
1 parent 4a44a70 commit 8dbd7b2

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

test/test_autograd.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12466,7 +12466,10 @@ def test_disallow_nesting(self):

     def test_inplace_foreach(self):
         with torch.autograd.graph.allow_mutation_on_saved_tensors():
-            a = [torch.tensor(1., requires_grad=True), torch.tensor(1., requires_grad=True)]
+            a = [
+                torch.tensor(1., requires_grad=True),
+                torch.tensor(1., requires_grad=True)
+            ]
             b = torch._foreach_exp(a)
             torch._foreach_add_(b, 1)
             (b[0] + b[1]).backward()

torch/nested/_internal/ops.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -585,14 +585,8 @@ def linear_backward_default(func, *args, **kwargs):
         input_2d = inp._values.reshape(-1, weight.size(1))
         dw = torch.matmul(grad_2d.t(), input_2d)
     if output_mask[2]:
-        # Sum over all but the last dim to get a 1D bias grad. We cannot
-        # rely on the autograd engine to reduce for us, because returning a
-        # tensor aliasing the input would violate the aten signature annotation
-        reduce_dims = tuple(range(grad_output._values.ndim - 1))
-        if reduce_dims == ():
-            db = grad_output._values.clone()
-        else:
-            db = torch.sum(grad_output._values, reduce_dims, keepdim=False)
+        # NB: autograd engine will sum over all but the last dim to get a 1D bias grad.
+        db = grad_output._values.detach()
     return (ds, dw, db)

0 commit comments

Comments (0)