Various fix for memory leak in test autograd and dataloader (#143323) · pytorch/pytorch@80a4239 · GitHub

Commit 80a4239

albanD authored and pytorchmergebot committed
Various fix for memory leak in test autograd and dataloader (#143323)
Pull Request resolved: #143323
Approved by: https://github.com/andrewkho, https://github.com/soulitzer
ghstack dependencies: #143225
1 parent 84b91ce commit 80a4239
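
Both files apply the same cleanup pattern: `Tensor.register_hook` returns a `torch.utils.hooks.RemovableHandle`, and a hook that is never removed stays reachable from the autograd graph, so anything it closes over survives the test; similarly, a leaf tensor's accumulated `.grad` keeps memory pinned after the assertions are done. A minimal sketch of the pattern, with illustrative tensor names not taken from the diff:

```python
import torch

x = torch.randn(2, 2, requires_grad=True)

# register_hook returns a RemovableHandle; keep it so the hook
# can be detached once the assertion work is done.
handle = x.register_hook(lambda grad: grad * 2)

x.sum().backward()

# Cleanup: detach the hook and drop the accumulated gradient so
# nothing from this test outlives it.
handle.remove()
x.grad = None
```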

2 files changed: +64 −30 lines

test/test_autograd.py

Lines changed: 55 additions & 16 deletions

```diff
@@ -264,6 +264,10 @@ def backward(ctx, grad_output):
         self.assertExpected(x_grad_desc, "x_grad_desc")
         self.assertExpected(y_grad_desc, "y_grad_desc")
 
+        # Avoid leaking memory
+        x.grad = None
+        y.grad = None
+
     def test_once_differentiable(self):
         class MyFunction(Function):
             @staticmethod
@@ -293,6 +297,10 @@ def backward(ctx, grad_output):
             "CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))",
         )
 
+        # Avoid leaking memory
+        x.grad = None
+        y.grad = None
+
     def test_function_returns_input(self):
         class MyFunction(Function):
             @staticmethod
@@ -640,8 +648,8 @@ def fn(x):
             for g in should_not_execute:
                 self.assertFalse(torch._C._will_engine_execute_node(g))
 
-        b.register_hook(fn)
-        c.register_hook(fn)
+        h1 = b.register_hook(fn)
+        h2 = c.register_hook(fn)
 
         # .backward(inputs=) is OK
         out = c.sum()
@@ -668,7 +676,7 @@ def fn(x):
             counter[0] += 1
             self.assertTrue(torch._C._will_engine_execute_node(b.grad_fn))
 
-        b.register_hook(fn)
+        h3 = b.register_hook(fn)
         counter[0] = 0
         torch.autograd.grad(b.sum(), (a,))
         self.assertEqual(counter[0], 1)
@@ -680,6 +688,11 @@ def fn(x):
         with self.assertRaisesRegex(RuntimeError, "expects an grad_fn"):
             torch._C._will_engine_execute_node(out)
 
+        # Ensure we don't leak memory
+        h1.remove()
+        h2.remove()
+        h3.remove()
+
     def test_custom_function_vmap_defaults(self):
         class MySquare(Function):
             @staticmethod
@@ -899,6 +912,10 @@ def test_hessian_vector(self):
         self.assertEqual(x.grad, x_grad + x_hv)
         self.assertEqual(y.grad, y_grad + y_hv)
 
+        # Avoid leaking memory
+        x.grad = None
+        y.grad = None
+
     def test_grad(self):
         x = torch.randn(2, 2, requires_grad=True)
         y = torch.randn(2, 2, requires_grad=True)
@@ -924,6 +941,10 @@ def test_grad(self):
         self.assertEqual(x.grad, x_grad)
         self.assertEqual(y.grad, y_grad)
 
+        # Avoid leaking memory
+        x.grad = None
+        y.grad = None
+
         # Test that grad_outputs and outputs have the same shape
         grad_out = torch.ones(2)
         try:
@@ -1071,6 +1092,7 @@ def test_grad_fn_input_metadata(self):
             layout=torch.jagged,
             requires_grad=True,
         )
+
         nt_metadata = nt.clone().grad_fn._input_metadata[0]
 
         self.assertIsInstance(nt_metadata.shape[1], torch.SymInt)
@@ -2209,16 +2231,21 @@ def fn2(grad):
 
         b = torch.rand(3, 3, requires_grad=True)
         out1, out2 = fn(b)
-        out1.register_hook(fn0)
-        out2.register_hook(fn1)
+        h1 = out1.register_hook(fn0)
+        h2 = out2.register_hook(fn1)
         # node refers to two hook dicts
         # out1 no longer no longer points to its old hook dict
         out1.mul_(2)
         # fn2 is registered to out1's new hook dict
-        out1.register_hook(fn2)
+        h3 = out1.register_hook(fn2)
         (out1 + out2 * 3).sum().backward()
         self.assertEqual(counts, [1, 1, 1])
 
+        # Avoid leaking memory
+        h1.remove()
+        h2.remove()
+        h3.remove()
+
     def test_tensor_hooks_inplace_over_view(self):
         # There might be a better UX here, but this is the way it is now
         count = [0]
@@ -2484,6 +2511,11 @@ def test_backward_with_nonleaf_inputs(self):
         )
         self.assertIsNone(z.grad)
 
+        # Avoid leaking memory
+        x.grad = None
+        y.grad = None
+        x_nonleaf.grad = None
+
     def test_dependent_backward(self):
         x = torch.randn(10, requires_grad=True)
         y = x**2
@@ -4445,6 +4477,7 @@ def hook(_):
 
     def test_current_graph_task_execution_order(self):
         predicted = [None]
+        all_hooks = []
 
         def hook(_):
             predicted[0] = torch._C._current_graph_task_execution_order()
@@ -4473,11 +4506,11 @@ def hook(t_):
             return hook
 
         for i, t in enumerate(tensors):
-            t.register_hook(get_hook(i))
+            all_hooks.append(t.register_hook(get_hook(i)))
 
         # Basic example: single path
         t = torch.tensor(1.0, requires_grad=True).clone().sin().exp()
-        t.register_hook(hook)
+        all_hooks.append(t.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             t.backward()
         self.assertExpectedInline(
@@ -4494,7 +4527,7 @@ def hook(t_):
         d = a.cos()
         out = c * d
         register_logging_hooks(a, b, c, d, out)
-        out.register_hook(hook)
+        all_hooks.append(out.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             out.backward()
         self.assertEqual(predicted[0], grad_fns(*actual))
@@ -4506,7 +4539,7 @@ def hook(t_):
         c = a.cos()
         out = b * c
         register_logging_hooks(a, b, c, out)
-        out.register_hook(hook)
+        all_hooks.append(out.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             out.backward()
         self.assertEqual(predicted[0], grad_fns(*actual))
@@ -4519,7 +4552,7 @@ def hook(t_):
         out2 = b.cos()
         out3 = b.cos()
         register_logging_hooks(a, b, out, out2, out3)
-        out3.register_hook(hook)
+        all_hooks.append(out3.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             torch.autograd.grad((out, out3, out2), inputs=(a,))
         self.assertExpectedInline(
@@ -4537,7 +4570,7 @@ def hook(t_):
         b = a * 2
         out = b.sin()
         register_logging_hooks(a, b, out)
-        out.register_hook(hook)
+        all_hooks.append(out.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             out.backward()
         self.assertEqual(predicted[0], grad_fns(*actual))
@@ -4548,7 +4581,7 @@ def hook(t_):
         b = a * 2
         out = b.sin()
         register_logging_hooks(a, b, out)
-        out.register_hook(hook)
+        all_hooks.append(out.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             torch.autograd.grad((out,), inputs=(a, b))
         self.assertEqual(
@@ -4567,7 +4600,7 @@ def hook(t_):
         c = a * b
         out = c.sin()
         register_logging_hooks(a, b, c, out)
-        out.register_hook(hook)
+        all_hooks.append(out.register_hook(hook))
         with torch.autograd.set_multithreading_enabled(False):
             torch.autograd.grad((out,), inputs=(a,))
         self.assertEqual(
@@ -4588,13 +4621,17 @@ def hook(t_):
 
         # Errors when context manager not enabled
         t = torch.tensor(1.0, requires_grad=True).clone().sin().exp()
-        t.register_hook(hook)
+        all_hooks.append(t.register_hook(hook))
         with self.assertRaisesRegex(
             RuntimeError,
             "expects the current backward to be executed with multithreading disabled",
         ):
             t.backward()
 
+        # Avoid leaking memory
+        for h in all_hooks:
+            h.remove()
+
     @skipIfWindows(msg="node name demangling inconsistent on windows")
     def test_backward_hook_relative_ordering(self):
         order = []
@@ -12927,7 +12964,7 @@ def hook(grads):
             else:
                 self.assertEqual(res, grad_is_none)
 
-        torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)
+        handle = torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)
 
         out = (t2 * t3).sum()
 
@@ -12976,6 +13013,8 @@ def backward_retain_graph(out, t2, t3):
         self.assertEqual(err_count[0], 1)
         self.assertEqual(res, [False, True, True, False])
 
+        handle.remove()
+
     def test_multi_grad_any_hooks(self):
         # Multihooks should behave independently per execution of backward
         # Test that the hook fired the number of times we ran backward
```
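
Where hooks are registered in a loop, the diff above accumulates the handles in `all_hooks` and removes them all at the end; `torch.autograd.graph.register_multi_grad_hook` likewise returns a removable handle. A hedged sketch of both cleanups (tensor and variable names are illustrative):

```python
import torch

t1 = torch.randn(3, requires_grad=True)
t2 = torch.randn(3, requires_grad=True)

# Per-tensor hooks: collect every handle so they can all be removed.
all_hooks = [t.register_hook(lambda g: g) for t in (t1, t2)]

# Multi-grad hooks also return a removable handle.
handle = torch.autograd.graph.register_multi_grad_hook(
    (t1, t2), lambda grads: None
)

(t1 * t2).sum().backward()

# Cleanup mirrors the pattern in the diff above.
for h in all_hooks:
    h.remove()
handle.remove()
t1.grad = None
t2.grad = None
```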

test/test_utils.py

Lines changed: 9 additions & 14 deletions

```diff
@@ -561,11 +561,6 @@ def test_infer_device_state_recursive_multi_cuda(self):
 class TestDataLoaderUtils(TestCase):
     MAX_TIMEOUT_IN_SECOND = 300
 
-    def setUp(self):
-        super().setUp()
-        self.dataset = torch.randn(5, 3, 3, 2)
-        self.batch_size = 3
-
     def test_random_seed(self):
         def run():
             dataloader = torch.utils.data.DataLoader(
@@ -584,12 +579,12 @@ def run():
         self.assertEqual(x1, x2)
 
     def test_single_keep(self):
-        # self.dataset is a Tensor here; technically not a valid input because
+        # torch.rand(5, 3, 3, 2) is a Tensor here; technically not a valid input because
         # not a Dataset subclass, but needs to stay working so add ignore's
         # for type checking with mypy
         dataloader: DataLoader = DataLoader(
-            self.dataset,  # type: ignore[arg-type]
-            batch_size=self.batch_size,
+            torch.rand(5, 3, 3, 2),  # type: ignore[arg-type]
+            batch_size=3,
             num_workers=0,
             drop_last=False,
         )
@@ -598,8 +593,8 @@ def test_single_keep(self):
 
     def test_single_drop(self):
         dataloader: DataLoader = DataLoader(
-            self.dataset,  # type: ignore[arg-type]
-            batch_size=self.batch_size,
+            torch.rand(5, 3, 3, 2),  # type: ignore[arg-type]
+            batch_size=3,
             num_workers=0,
             drop_last=True,
         )
@@ -611,8 +606,8 @@ def test_single_drop(self):
     )
     def test_multi_keep(self):
         dataloader: DataLoader = DataLoader(
-            self.dataset,  # type: ignore[arg-type]
-            batch_size=self.batch_size,
+            torch.rand(5, 3, 3, 2),  # type: ignore[arg-type]
+            batch_size=3,
             num_workers=2,
             drop_last=False,
             timeout=self.MAX_TIMEOUT_IN_SECOND,
@@ -622,8 +617,8 @@ def test_multi_keep(self):
 
     def test_multi_drop(self):
         dataloader: DataLoader = DataLoader(
-            self.dataset,  # type: ignore[arg-type]
-            batch_size=self.batch_size,
+            torch.rand(5, 3, 3, 2),  # type: ignore[arg-type]
+            batch_size=3,
             num_workers=2,
             drop_last=True,
             timeout=self.MAX_TIMEOUT_IN_SECOND,
```
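
The `test_utils.py` change takes a different route: the stand-in dataset used to be stored on `self` in `setUp`, and the test framework can keep `TestCase` instances, and therefore their attributes, alive well past an individual test, so the tensor lingered for the whole run. Creating it inline ties its lifetime to the test body. A sketch of the resulting shape of the tests, under the same Tensor-as-dataset convention they use (the function name is illustrative):

```python
import torch
from torch.utils.data import DataLoader

def count_samples_kept() -> int:
    # The Tensor stands in for a Dataset (it has __len__/__getitem__),
    # exactly as in the tests; it is created here rather than stored
    # on the TestCase, so it is freed when this function returns.
    loader: DataLoader = DataLoader(
        torch.rand(5, 3, 3, 2),  # type: ignore[arg-type]
        batch_size=3,
        num_workers=0,
        drop_last=False,
    )
    return sum(len(batch) for batch in loader)

assert count_samples_kept() == 5  # batches of 3 and 2; nothing dropped
```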
