@@ -264,6 +264,10 @@ def backward(ctx, grad_output):
         self.assertExpected(x_grad_desc, "x_grad_desc")
         self.assertExpected(y_grad_desc, "y_grad_desc")

+        # Avoid leaking memory
+        x.grad = None
+        y.grad = None
+
     def test_once_differentiable(self):
         class MyFunction(Function):
             @staticmethod
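The hunks in this diff repeat one cleanup idea: once a test has asserted on `x.grad` and `y.grad`, dropping the accumulated gradients keeps the leaf tensors from pinning that memory for the rest of the test run. A minimal standalone sketch of the pattern (the test class name and tensor shapes here are illustrative, not taken from the diff):

import unittest

import torch


class GradCleanupSketch(unittest.TestCase):
    def test_backward_then_release(self):
        x = torch.randn(2, 2, requires_grad=True)
        y = torch.randn(2, 2, requires_grad=True)
        (x * y).sum().backward()
        torch.testing.assert_close(x.grad, y)  # d(sum(x*y))/dx == y
        torch.testing.assert_close(y.grad, x)
        # Clearing .grad releases the accumulated gradient tensors instead of
        # leaving them attached to the leaf tensors after the test finishes.
        x.grad = None
        y.grad = None


if __name__ == "__main__":
    unittest.main()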
@@ -293,6 +297,10 @@ def backward(ctx, grad_output):
             "CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))",
         )

+        # Avoid leaking memory
+        x.grad = None
+        y.grad = None
+
     def test_function_returns_input(self):
         class MyFunction(Function):
             @staticmethod
@@ -640,8 +648,8 @@ def fn(x):
             for g in should_not_execute:
                 self.assertFalse(torch._C._will_engine_execute_node(g))

-        b.register_hook(fn)
-        c.register_hook(fn)
+        h1 = b.register_hook(fn)
+        h2 = c.register_hook(fn)

         # .backward(inputs=) is OK
         out = c.sum()
@@ -668,7 +676,7 @@ def fn(x):
             counter[0] += 1
             self.assertTrue(torch._C._will_engine_execute_node(b.grad_fn))

-        b.register_hook(fn)
+        h3 = b.register_hook(fn)
         counter[0] = 0
         torch.autograd.grad(b.sum(), (a,))
         self.assertEqual(counter[0], 1)
@@ -680,6 +688,11 @@ def fn(x):
         with self.assertRaisesRegex(RuntimeError, "expects an grad_fn"):
             torch._C._will_engine_execute_node(out)

+        # Ensure we don't leak memory
+        h1.remove()
+        h2.remove()
+        h3.remove()
+
     def test_custom_function_vmap_defaults(self):
         class MySquare(Function):
             @staticmethod
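For the hook-based tests the fix is the same idea applied to hooks: `Tensor.register_hook` returns a `torch.utils.hooks.RemovableHandle`, and calling `remove()` drops the hook closure so it no longer keeps the captured tensors and graph alive. A small standalone sketch (the tensor names and the `log_grad` hook are illustrative, not from the diff):

import torch

a = torch.randn(3, requires_grad=True)
b = (a * 2).sum()


def log_grad(grad):
    # Called with the gradient flowing into `b` during backward.
    print("grad wrt b:", grad)


handle = b.register_hook(log_grad)  # RemovableHandle
try:
    b.backward()
finally:
    # Remove the hook whether or not backward raised, so the closure
    # (and anything it captures) can be freed.
    handle.remove()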
@@ -899,6 +912,10 @@ def test_hessian_vector(self):
         self.assertEqual(x.grad, x_grad + x_hv)
         self.assertEqual(y.grad, y_grad + y_hv)

+        # Avoid leaking memory
+        x.grad = None
+        y.grad = None
+
     def test_grad(self):
         x = torch.randn(2, 2, requires_grad=True)
         y = torch.randn(2, 2, requires_grad=True)
@@ -924,6 +941,10 @@ def test_grad(self):
         self.assertEqual(x.grad, x_grad)
         self.assertEqual(y.grad, y_grad)

+        # Avoid leaking memory
+        x.grad = None
+        y.grad = None
+
         # Test that grad_outputs and outputs have the same shape
         grad_out = torch.ones(2)
         try:
@@ -1071,6 +1092,7 @@ def test_grad_fn_input_metadata(self):
             layout=torch.jagged,
             requires_grad=True,
         )
+
         nt_metadata = nt.clone().grad_fn._input_metadata[0]

         self.assertIsInstance(nt_metadata.shape[1], torch.SymInt)
@@ -2209,16 +2231,21 @@ def fn2(grad):
2209
2231
2210
2232
b = torch.rand(3, 3, requires_grad=True)
2211
2233
out1, out2 = fn(b)
2212
- out1.register_hook(fn0)
2213
- out2.register_hook(fn1)
2234
+ h1 = out1.register_hook(fn0)
2235
+ h2 = out2.register_hook(fn1)
2214
2236
# node refers to two hook dicts
2215
2237
# out1 no longer no longer points to its old hook dict
2216
2238
out1.mul_(2)
2217
2239
# fn2 is registered to out1's new hook dict
2218
- out1.register_hook(fn2)
2240
+ h3 = out1.register_hook(fn2)
2219
2241
(out1 + out2 * 3).sum().backward()
2220
2242
self.assertEqual(counts, [1, 1, 1])
2221
2243
2244
+ # Avoid leaking memory
2245
+ h1.remove()
2246
+ h2.remove()
2247
+ h3.remove()
2248
+
2222
2249
def test_tensor_hooks_inplace_over_view(self):
2223
2250
# There might be a better UX here, but this is the way it is now
2224
2251
count = [0]
@@ -2484,6 +2511,11 @@ def test_backward_with_nonleaf_inputs(self):
2484
2511
)
2485
2512
self.assertIsNone(z.grad)
2486
2513
2514
+ # Avoid leaking memory
2515
+ x.grad = None
2516
+ y.grad = None
2517
+ x_nonleaf.grad = None
2518
+
2487
2519
def test_dependent_backward(self):
2488
2520
x = torch.randn(10, requires_grad=True)
2489
2521
y = x**2
@@ -4445,6 +4477,7 @@ def hook(_):

     def test_current_graph_task_execution_order(self):
         predicted = [None]
+        all_hooks = []

         def hook(_):
             predicted[0] = torch._C._current_graph_task_execution_order()
@@ -4473,11 +4506,11 @@ def hook(t_):
4473
4506
return hook
4474
4507
4475
4508
for i, t in enumerate(tensors):
4476
- t.register_hook(get_hook(i))
4509
+ all_hooks.append( t.register_hook(get_hook(i) ))
4477
4510
4478
4511
# Basic example: single path
4479
4512
t = torch.tensor(1.0, requires_grad=True).clone().sin().exp()
4480
- t.register_hook(hook)
4513
+ all_hooks.append( t.register_hook(hook) )
4481
4514
with torch.autograd.set_multithreading_enabled(False):
4482
4515
t.backward()
4483
4516
self.assertExpectedInline(
@@ -4494,7 +4527,7 @@ def hook(t_):
         d = a.cos()
         out = c * d
         register_logging_hooks(a, b, c, d, out)
-        out.register_hook(hook)
+        all_hooks.append(out.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             out.backward()
         self.assertEqual(predicted[0], grad_fns(*actual))
@@ -4506,7 +4539,7 @@ def hook(t_):
         c = a.cos()
         out = b * c
         register_logging_hooks(a, b, c, out)
-        out.register_hook(hook)
+        all_hooks.append(out.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             out.backward()
         self.assertEqual(predicted[0], grad_fns(*actual))
@@ -4519,7 +4552,7 @@ def hook(t_):
         out2 = b.cos()
         out3 = b.cos()
         register_logging_hooks(a, b, out, out2, out3)
-        out3.register_hook(hook)
+        all_hooks.append(out3.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             torch.autograd.grad((out, out3, out2), inputs=(a,))
         self.assertExpectedInline(
@@ -4537,7 +4570,7 @@ def hook(t_):
         b = a * 2
         out = b.sin()
         register_logging_hooks(a, b, out)
-        out.register_hook(hook)
+        all_hooks.append(out.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             out.backward()
         self.assertEqual(predicted[0], grad_fns(*actual))
@@ -4548,7 +4581,7 @@ def hook(t_):
         b = a * 2
         out = b.sin()
         register_logging_hooks(a, b, out)
-        out.register_hook(hook)
+        all_hooks.append(out.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             torch.autograd.grad((out,), inputs=(a, b))
         self.assertEqual(
@@ -4567,7 +4600,7 @@ def hook(t_):
         c = a * b
         out = c.sin()
         register_logging_hooks(a, b, c, out)
-        out.register_hook(hook)
+        all_hooks.append(out.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             torch.autograd.grad((out,), inputs=(a,))
         self.assertEqual(
@@ -4588,13 +4621,17 @@ def hook(t_):

         # Errors when context manager not enabled
         t = torch.tensor(1.0, requires_grad=True).clone().sin().exp()
-        t.register_hook(hook)
+        all_hooks.append(t.register_hook(hook))
         with self.assertRaisesRegex(
             RuntimeError,
             "expects the current backward to be executed with multithreading disabled",
         ):
             t.backward()

+        # Avoid leaking memory
+        for h in all_hooks:
+            h.remove()
+
     @skipIfWindows(msg="node name demangling inconsistent on windows")
     def test_backward_hook_relative_ordering(self):
         order = []
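When a test registers many hooks, as test_current_graph_task_execution_order does, collecting the handles in a list and removing them at the end works. An equivalent formulation (a sketch with a hypothetical helper, not part of this change) uses `contextlib.ExitStack`, so the handles are removed even if an assertion fails partway through:

import contextlib

import torch


def backward_with_temporary_hooks(tensors, hook):
    # Register `hook` on every tensor, run one backward pass, and guarantee
    # that every handle is removed on exit, whether or not backward raised.
    with contextlib.ExitStack() as stack:
        for t in tensors:
            stack.callback(t.register_hook(hook).remove)
        out = sum((t * t).sum() for t in tensors)
        out.backward()


ts = [torch.randn(2, requires_grad=True) for _ in range(3)]
backward_with_temporary_hooks(ts, lambda grad: grad)  # identity hook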
@@ -12927,7 +12964,7 @@ def hook(grads):
12927
12964
else:
12928
12965
self.assertEqual(res, grad_is_none)
12929
12966
12930
- torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)
12967
+ handle = torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)
12931
12968
12932
12969
out = (t2 * t3).sum()
12933
12970
@@ -12976,6 +13013,8 @@ def backward_retain_graph(out, t2, t3):
         self.assertEqual(err_count[0], 1)
         self.assertEqual(res, [False, True, True, False])

+        handle.remove()
+
     def test_multi_grad_any_hooks(self):
         # Multihooks should behave independently per execution of backward
         # Test that the hook fired the number of times we ran backward
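`register_multi_grad_hook` likewise returns a handle, which is why the last hunk can simply call `handle.remove()` after the assertions. A self-contained sketch of that usage (tensors and the hook body are illustrative; it assumes a PyTorch version that provides torch.autograd.graph.register_multi_grad_hook):

import torch

t1 = torch.rand(2, requires_grad=True)
t2 = torch.rand(2, requires_grad=True)


def hook(grads):
    # One entry per registered tensor; None where no gradient was computed.
    print([g is not None for g in grads])


handle = torch.autograd.graph.register_multi_grad_hook((t1, t2), hook)
(t1 * t2).sum().backward()
handle.remove()  # drop the hook so it does not outlive the test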
0 commit comments