Checkpoint doesn't work with torch_function if torch_function changes tensor metadata #147995
Comments
cc @soulitzer - it looks like the TorchFunctionMode ends up violating some AC conditions because it changes the metadata of the inputs to the AC region before autograd sees them. Does this look like something that is fixable / workaround-able? Or is it just fundamentally not allowed?
Sorta unrelated, but NJT handles SDPA via torch_function and clearly has problems when checkpointing is involved:

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def func(x, y) -> None:
    q = torch.matmul(x, y).unflatten(-1, [2, 5]).transpose(1, 2)
    return F.scaled_dot_product_attention(q, q, q)

nt = torch.nested.nested_tensor([
    torch.randn(2, 5),
    torch.randn(3, 5),
    torch.randn(4, 5),
], layout=torch.jagged, device="cuda", requires_grad=True)
dense = torch.randn(5, 10, device="cuda")

out = checkpoint(func, nt, dense, use_reentrant=False)
out.sum().backward()
```
The error message indicates that the NST logic is being exercised, which it shouldn't be. cc @soulitzer
@bdhirsh @fegin It sounds like the issue is that during the initial forward, checkpoint runs with the TorchFunctionMode (which does a transpose), but during the recompute we are not running with the TorchFunctionMode. This causes the transpose to no longer be done during recompute, leading to the shape mismatch. One workaround here is to activate the TorchFunctionMode manually inside the function you are checkpointing:

```python
def func(x, y) -> None:
    with DistributeFunction():
        return torch.matmul(x, y)
```

The right fix is probably for checkpoint to detect the currently enabled modes and re-enable those modes during recompute, though.

@jbschlosser Yeah, unfortunately the default determinism logic assumes .sizes() is supported on the tensor. You should be able to bypass this by setting the
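A self-contained sketch of this workaround idea using checkpoint's `context_fn` argument, which lets the caller supply one context manager for the original forward and one for the recompute (only supported with `use_reentrant=False`). The `DistributeFunction` below is an illustrative stand-in, not the mode from the thread (the real one converts SDPA inputs to DTensor); it is assumed to be cheap and safe to construct and enter more than once. If the truncated suggestion above refers to checkpoint's `determinism_check` argument, passing `determinism_check="none"` would additionally skip the metadata comparison.

```python
import torch
from torch.overrides import TorchFunctionMode
from torch.utils.checkpoint import checkpoint

class DistributeFunction(TorchFunctionMode):
    """Illustrative stand-in for the mode discussed above: it transposes
    matmul outputs so that the metadata change is visible to checkpoint."""
    def __torch_function__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        if func is torch.matmul:
            out = out.transpose(0, 1)
        return out

def func(x, y):
    return torch.matmul(x, y)

x = torch.randn(4, 5, requires_grad=True)
y = torch.randn(5, 6)

# context_fn returns (forward_ctx, recompute_ctx); re-entering the mode for the
# recompute keeps the recomputed tensors consistent with the original forward.
out = checkpoint(func, x, y, use_reentrant=False,
                 context_fn=lambda: (DistributeFunction(), DistributeFunction()))
out.sum().backward()
```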
Could it be that we're going into the math path somehow and doing the conversion from NJT to NST?
@soulitzer I guess this is a fundamental limitation then. We won't be able to insert the TorchFunctionMode.

@drisspg Do you think there are other ways to let us install hooks into SDPA without using monkey patching?
If this were true, do you think there would still be a problem?
If checkpoint can detect the current modes and appropriately update the saved tensors, then it should work well. Not sure how hard it would be to achieve this, though.
Shouldn't be too hard
… during recompute" Fixes #147995

TorchFunctionModeTLS is part of the autograd TLS, but because .backward() itself is a leaf for TorchFunctionMode, the mode is disabled before we enter the engine. Conversely, since TorchDispatchMode traces through the .backward() Python call, we don't actually need to manually stash/restore if the user keeps the same mode enabled. We should still fix TorchDispatchMode though, because even if the user doesn't keep the same mode enabled on the .backward() call, checkpoint should not fail.
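A tiny sketch of the "leaf" behavior described here, using a made-up `LoggingMode` for illustration: a TorchFunctionMode sees the `.backward()` call itself, but not the ops the autograd engine (and hence a checkpoint recompute) runs underneath it.

```python
import torch
from torch.overrides import TorchFunctionMode

class LoggingMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        print("torch_function saw:", func)
        return func(*args, **(kwargs or {}))

x = torch.randn(3, requires_grad=True)
with LoggingMode():
    y = (x * 2).sum()
    # .backward() is logged as a single leaf call; the backward graph executed
    # inside the engine does not re-enter the mode.
    y.backward()
```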
While we prefer not to use monkey patching to dispatch SDPA, TorchFunctionMode is currently not compatible with selective activation checkpointing (#147995). This PR adds `TorchFunctionMode` to the CP code and makes it configurable. Pull Request resolved: #147902 Approved by: https://github.com/XilunWu
🐛 Describe the bug
We are trying to use `TorchFunctionMode` to convert the input tensors of SDPA to DTensor (if they are not already). Unfortunately this approach fails. Digging into the details, this seems to be a fundamental limitation of checkpoint, as checkpoint is not aware of `__torch_function__`. Below is a minimal repro which utilizes `__torch_function__` to reshape the input tensors. Checkpoint complains about a metadata mismatch:
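The original repro did not survive here, so below is a hypothetical sketch of the failure mode being described: `ReshapeSDPAMode`, the shapes, and the reshape itself are illustrative stand-ins (the real mode converts SDPA inputs to DTensor), not the code from the report.

```python
import torch
import torch.nn.functional as F
from torch.overrides import TorchFunctionMode
from torch.utils.checkpoint import checkpoint

class ReshapeSDPAMode(TorchFunctionMode):
    """Hypothetical mode that changes the metadata of SDPA inputs, standing in
    for the DTensor conversion described in this issue."""
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.scaled_dot_product_attention:
            # Reshape (B, S, E) -> (B, H, S, E // H) before the real SDPA call.
            q, k, v, *rest = args
            q, k, v = (t.unflatten(-1, [2, t.shape[-1] // 2]).transpose(1, 2)
                       for t in (q, k, v))
            return func(q, k, v, *rest, **kwargs)
        return func(*args, **kwargs)

def fn(x, y):
    q = torch.matmul(x, y)
    return F.scaled_dot_product_attention(q, q, q)

x = torch.randn(2, 3, 4, requires_grad=True)
y = torch.randn(4, 10)

with ReshapeSDPAMode():
    out = checkpoint(fn, x, y, use_reentrant=False)
    # The recompute triggered during backward does not run under the mode, so
    # the recomputed tensors' metadata no longer matches what was saved and
    # checkpoint reports a mismatch.
    out.sum().backward()
```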
Is there any way to make this `__torch_function__` work with checkpoint?

Versions
nightly
cc @soulitzer @hameerabbasi @rgommers @ezyang