torch.utils.checkpoint preserves torch function mode stack during recompute · pytorch/pytorch@31e31f7

Commit 31e31f7

torch.utils.checkpoint preserves torch function mode stack during recompute
ghstack-source-id: cbb3741
Pull Request resolved: #148023
1 parent 9035fd1 commit 31e31f7

File tree

3 files changed (+63, -1 lines)
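For context, here is a minimal sketch of the behavior this change targets (the names `LoggingMode` and `fn` are illustrative, not part of the commit): a `TorchFunctionMode` that is active when `checkpoint` runs the forward pass should also observe the operators that are replayed during recomputation inside `backward()`.

```python
import torch
from torch.overrides import TorchFunctionMode
from torch.utils.checkpoint import checkpoint

class LoggingMode(TorchFunctionMode):
    """Records the name of every torch function it intercepts."""
    def __init__(self, log):
        super().__init__()
        self.log = log

    def __torch_function__(self, func, types, args=(), kwargs=None):
        self.log.append(getattr(func, "__name__", str(func)))
        return func(*args, **(kwargs or {}))

def fn(x):
    return x.sin().cos()

log = []
x = torch.ones(3, requires_grad=True)
with LoggingMode(log):
    out = checkpoint(fn, x, use_reentrant=False)

n_forward = len(log)
out.sum().backward()
# With this change, the mode stack captured at checkpoint time is re-applied
# while `fn` is recomputed, so the replayed sin/cos calls show up here too.
print(log[n_forward:])
```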

test/test_autograd.py  (+41)

@@ -80,6 +80,7 @@
     TestCase,
     xfailIfTorchDynamo,
 )
+from torch.overrides import TorchFunctionMode
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils.checkpoint import (

@@ -7514,6 +7515,46 @@ def forward(self, dict_input):
         ):
             self.assertEqual(param.grad, checkpoint_param.grad)

+    @xfailIfTorchDynamo
+    def test_checkpointing_preserves_torch_function_mode_stack(self):
+        log = []
+
+        def get_mode_class(n):
+            class Func(TorchFunctionMode):
+                def __torch_function__(self, func, types, args, kwargs=None):
+                    kwargs = {} if kwargs is None else kwargs
+                    log.append(f"mode{n}")
+                    return func(*args, **kwargs)
+
+            return Func
+
+        Mode1 = get_mode_class(1)
+        Mode2 = get_mode_class(2)
+        Mode3 = get_mode_class(3)
+
+        def func(x):
+            return x.sin().cos()
+
+        def context_fn():
+            return Mode3(), Mode3()
+
+        with Mode1():
+            with Mode2():
+                a = torch.tensor(1.0, requires_grad=True)
+
+                log = []
+                out = checkpoint(func, a, use_reentrant=False, context_fn=context_fn)
+                self.assertTrue(
+                    log[-3] == "mode3" and log[-2] == "mode2" and log[-1] == "mode1"
+                )
+
+        log = []
+        out.backward()
+        self.assertTrue(
+            log[-3] == "mode3" and log[-2] == "mode2" and log[-1] == "mode1"
+        )
+
     def test_callback_adds_callback(self):
         called = [0]

torch/overrides.py  (+12)

@@ -2073,6 +2073,18 @@ def _pop_mode():
     return old


+@contextlib.contextmanager
+def _apply_torch_function_mode_stack(mode_stack):
+    stack_len = len(mode_stack)
+    try:
+        for mode in mode_stack:
+            _push_mode(mode)
+        yield
+    finally:
+        for _ in range(stack_len):
+            _pop_mode()
+
+
 @contextlib.contextmanager
 def _pop_mode_temporarily():
     old = _pop_mode()
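Illustrative only (both helpers are private and may change): the new `_apply_torch_function_mode_stack` context manager is the re-apply half of a capture/replay pair with the existing `_get_current_function_mode_stack`. A sketch of the intended composition, using a hypothetical `Noisy` mode:

```python
import torch
from torch.overrides import (
    TorchFunctionMode,
    _apply_torch_function_mode_stack,
    _get_current_function_mode_stack,
)

class Noisy(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        print("saw", getattr(func, "__name__", func))
        return func(*args, **(kwargs or {}))

x = torch.ones(2)
with Noisy():
    # Capture whatever modes are active at this point (here: one Noisy instance).
    saved = _get_current_function_mode_stack()

# Later, outside the original `with` block, re-push the captured modes for the
# duration of some replayed computation -- this is what recompute_fn now does.
with _apply_torch_function_mode_stack(saved):
    x.sin()  # Noisy intercepts this call again
```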

torch/utils/checkpoint.py  (+10, -1)

@@ -15,6 +15,8 @@
 from torch.utils._pytree import tree_map
 from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
 from torch.utils._python_dispatch import TorchDispatchMode
+from torch.overrides import _get_current_function_mode_stack, _apply_torch_function_mode_stack
+

 __all__ = [
     "checkpoint",

@@ -1496,6 +1498,8 @@ def _checkpoint_without_reentrant_generator(
             had_device_in_fwd = True
             fwd_devices, fwd_device_states = get_device_states(*args)

+    torch_function_mode_stack = _get_current_function_mode_stack()
+
     def recompute_fn(*inputs):
         kwargs, *args = inputs
         # This will be called later during recomputation. This wrapping enables

@@ -1514,7 +1518,12 @@ def recompute_fn(*inputs):
             device_autocast_ctx = torch.amp.autocast(
                 device_type=device_type, **device_autocast_kwargs
             ) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext()
-            with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
+            with (
+                device_autocast_ctx,  # type: ignore[attr-defined]
+                torch.amp.autocast("cpu", **cpu_autocast_kwargs),
+                _apply_torch_function_mode_stack(torch_function_mode_stack),
+                recompute_context,
+            ):
                 fn(*args, **kwargs)

         new_frame = _CheckpointFrame(
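One detail worth noting, inferred from the ordering in the `with` statement above and asserted by the new test: the saved ambient stack is pushed before `recompute_context` is entered, so the modes returned by `context_fn` end up innermost during recomputation, and the innermost mode sees each call first. A small standalone sketch of that nesting rule (the `Named` mode is illustrative):

```python
import torch
from torch.overrides import TorchFunctionMode

class Named(TorchFunctionMode):
    def __init__(self, name, log):
        super().__init__()
        self.name, self.log = name, log

    def __torch_function__(self, func, types, args=(), kwargs=None):
        self.log.append(self.name)
        return func(*args, **(kwargs or {}))

log = []
x = torch.ones(1)
with Named("ambient", log):         # stands in for the re-applied forward-time stack
    with Named("context_fn", log):  # stands in for recompute_context
        x.sin()

# The innermost mode runs first and forwards outward, so:
assert log == ["context_fn", "ambient"]
```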
