8000 Update on "torch.utils.checkpoint preserves torch function mode stack… · pytorch/pytorch@9d0b4a7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9d0b4a7

Browse files
committed
Update on "torch.utils.checkpoint preserves torch function mode stack during recompute"
Fixes #147995 [ghstack-poisoned]
1 parent db7ffdc commit 9d0b4a7

File tree

2 files changed

+21
-17
lines changed

2 files changed

+21
-17
lines changed

test/test_autograd.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7518,33 +7518,37 @@ def forward(self, dict_input):
75187518
def test_checkpointing_preserves_torch_function_mode_stack(self):
75197519
log = []
75207520

7521-
class Func1(TorchFunctionMode):
7522-
def __torch_function__(self, func, types, args, kwargs=None):
7523-
kwargs = {} if kwargs is None else kwargs
7524-
log.append("func1")
7525-
return func(*args, **kwargs)
7521+
def get_mode_class(n):
7522+
class Func(TorchFunctionMode):
7523+
def __torch_function__(self, func, types, args, kwargs=None):
7524+
kwargs = {} if kwargs is None else kwargs
7525+
log.append(f"mode{n}")
7526+
return func(*args, **kwargs)
75267527

7527-
class Func2(TorchFunctionMode):
7528-
def __torch_function__(self, func, types, args, kwargs=None):
7529-
kwargs = {} if kwargs is None else kwargs
7530-
log.append("func2")
7531-
return func(*args, **kwargs)
7528+
return Func
7529+
7530+
Mode1 = get_mode_class(1)
7531+
Mode2 = get_mode_class(2)
7532+
Mode3 = get_mode_class(3)
75327533

75337534
def func(x):
75347535
return x.sin().cos()
75357536

7536-
with Func1():
7537-
with Func2():
7537+
def context_fn():
7538+
return Mode3(), Mode3()
7539+
7540+
with Mode1():
7541+
with Mode2():
75387542
a = torch.tensor(1., requires_grad=True)
75397543

75407544
log = []
7541-
out = checkpoint(func, a, use_reentrant=False)
7542-
self.assertTrue(log[0] == "func2" and log[1] == "func1")
7545+
out = checkpoint(func, a, use_reentrant=False, context_fn=context_fn)
7546+
self.assertTrue(log[-3] == "mode3" and log[-2] == "mode2" and log[-1] == "mode1")
7547+
75437548

75447549
log = []
75457550
out.backward()
7546-
7547-
self.assertTrue(log[0] == "func2" and log[1] == "func1")
7551+
self.assertTrue(log[-3] == "mode3" and log[-2] == "mode2" and log[1] == "mode1")
75487552

75497553
def test_callback_adds_callback(self):
75507554
called = [0]

torch/utils/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1521,8 +1521,8 @@ def recompute_fn(*inputs):
15211521
with (
15221522
device_autocast_ctx, # type: ignore[attr-defined]
15231523
torch.amp.autocast("cpu", **cpu_autocast_kwargs),
1524+
_apply_torch_function_mode_stack(torch_function_mode_stack),
15241525
recompute_context,
1525-
_apply_torch_function_mode_stack(torch_function_mode_stack)
15261526
):
15271527
fn(*args, **kwargs)
15281528

0 commit comments

Comments
 (0)
0