torch.utils.checkpoint preserves torch function mode stack during recompute #148023
base: gh/soulitzer/353/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148023
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 active SEV: if your PR is affected, please review it.
❌ 12 new failures, 1 cancelled job as of commit 9652102 with merge base e57cdb8.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
… during recompute" [ghstack-poisoned]
… during recompute" Fixes #147995 [ghstack-poisoned]
… during recompute" Fixes #147995 TorchFunctionModeTLS is part of the autograd tls, but because .backward() itself is a leaf for TorchFunctionMode, the mode is disabled before we enter into the engine. Conversely, since TorchDispatchMode traces through the .backward() python call, we don't actually need to manually stash/restore if the user keeps the same mode enabled. We should still fix TorchDispatchMode though, because even if the user doesn't keep the same mode enabled on the .backward() call, checkpoint should not fail. [ghstack-poisoned]
torch/utils/checkpoint.py
with (
    device_autocast_ctx,  # type: ignore[attr-defined]
    torch.amp.autocast("cpu", **cpu_autocast_kwargs),
    _apply_torch_function_mode_stack(torch_function_mode_stack),
Noob question: during the recompute phase, are we effectively rerunning the user's forward function (and all of their torch.* ops)? Or is the AC code here doing something different, like capturing all of the ATen ops witnessed during the forward and replaying those? If it's the latter, I'm not sure it will work as cleanly with a TorchFunctionMode.
In eager, it's the former: we're just rerunning the user's forward function.
For compile, I'm not sure what the status of TorchFunctionMode support is generally (cc @mlazos), but assuming that the TorchFunctionMode is inlined through by dynamo, I'd guess the subgraph that the HOP applies the is_recompute annotations to should already have the TorchFunctionMode logic baked in.
Yep, I think this is right.
Yes, this is correct @soulitzer
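For concreteness, here is a minimal eager sketch of the behavior being discussed. LoggingMode and fn are made up for illustration, and the comments describe the intended behavior of this PR rather than verified output:

```python
import torch
from torch.overrides import TorchFunctionMode
from torch.utils.checkpoint import checkpoint


class LoggingMode(TorchFunctionMode):
    # Illustrative mode that prints every torch.* call it intercepts.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"intercepted {func}")
        return func(*args, **kwargs)


def fn(x):
    # The user's forward; with use_reentrant=False checkpoint, this exact
    # function is rerun during backward to recompute the activations.
    return x.sin().cos()


x = torch.randn(4, requires_grad=True)
with LoggingMode():
    out = checkpoint(fn, x, use_reentrant=False)

# With this PR, the recompute triggered by backward should also run under the
# mode stack captured at forward time, so the sin/cos calls are intercepted
# again during recompute (previously they were not; see #147995).
out.sum().backward()
```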
… during recompute" Fixes #147995 TorchFunctionModeTLS is part of the autograd tls, but because .backward() itself is a leaf for TorchFunctionMode, the mode is disabled before we enter into the engine. Conversely, since TorchDispatchMode traces through the .backward() python call, we don't actually need to manually stash/restore if the user keeps the same mode enabled. We should still fix TorchDispatchMode though, because even if the user doesn't keep the same mode enabled on the .backward() call, checkpoint should not fail. [ghstack-poisoned]
… during recompute" Fixes #147995 TorchFunctionModeTLS is part of the autograd tls, but because .backward() itself is a leaf for TorchFunctionMode, the mode is disabled before we enter into the engine. Conversely, since TorchDispatchMode traces through the .backward() python call, we don't actually need to manually stash/restore if the user keeps the same mode enabled. We should still fix TorchDispatchMode though, because even if the user doesn't keep the same mode enabled on the .backward() call, checkpoint should not fail. [ghstack-poisoned]
Can you describe why this is the right thing to do?
In a world where AC's semantics are to "rerun the forward", it's possible to break the invariant that the recomputed tensors match the originally saved ones. The responsibility for preserving that invariant is split between the user and AC. On one hand, the user should not do random control flow depending on global state. On the other hand, AC will handle composing built-in features like RNG and autocast. TorchFunctionMode is an interesting case: it's a built-in feature, but it's also an extension point where the user can add custom logic, so there's no guarantee whether they want the stack re-enabled or not, depending on their use case. But two reasons for re-enabling the stack by default are:
Discussed this with @albanD, and we may want to support this more generally by just stashing/restoring the TLS state. The blast radius of this PR becomes a bit larger, but it would also allow us to support TorchDispatchMode and clean up the currently manual handling around autocast + no_grad.
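Roughly, the more general version would snapshot the ambient state at forward/pack time and re-enter it around the recompute. A hypothetical sketch of that shape, assuming the captured TorchFunctionMode objects can simply be re-entered as context managers (the helper name is made up and this is not the actual implementation):

```python
import contextlib


@contextlib.contextmanager
def _reapply_torch_function_modes(modes):
    # Hypothetical helper: re-enter TorchFunctionMode objects captured during
    # the original forward (outermost first) so the recompute runs under the
    # same mode stack as the original forward did.
    with contextlib.ExitStack() as stack:
        for mode in modes:
            stack.enter_context(mode)
        yield
```

A fully general TLS snapshot covering TorchDispatchMode, autocast, and grad mode would presumably live alongside the existing autograd TLS handling rather than in Python.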
… during recompute" Fixes #147995 TorchFunctionModeTLS is part of the autograd tls, but because .backward() itself is a leaf for TorchFunctionMode, the mode is disabled before we enter into the engine. Conversely, since TorchDispatchMode traces through the .backward() python call, we don't actually need to manually stash/restore if the user keeps the same mode enabled. [ghstack-poisoned]
Stack from ghstack (oldest at bottom):
Fixes #147995
TorchFunctionModeTLS is part of the autograd TLS, but because .backward() itself is a leaf for TorchFunctionMode, the mode is disabled before we enter the engine. Conversely, since TorchDispatchMode traces through the .backward() Python call, we don't actually need to manually stash/restore it if the user keeps the same mode enabled.
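To make the "leaf" point concrete, here is a rough sketch of where the mode stack gets dropped without this fix. RecordingMode is made up for illustration, and the comments reflect my reading of the description above rather than verified behavior:

```python
import torch
from torch.overrides import TorchFunctionMode
from torch.utils.checkpoint import checkpoint


class RecordingMode(TorchFunctionMode):
    # Illustrative mode that records every func it intercepts.
    def __init__(self):
        super().__init__()
        self.funcs = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.funcs.append(func)
        # func runs with this mode popped off the stack, so for
        # Tensor.backward the entire engine call (including checkpoint's
        # recompute) executes without the mode, unless checkpoint re-applies
        # the stack it stashed at forward time.
        return func(*args, **kwargs)


x = torch.randn(4, requires_grad=True)
with RecordingMode():
    out = checkpoint(torch.sin, x, use_reentrant=False)
    # backward() is intercepted here once, as a single "leaf" call; the
    # recompute of torch.sin happens inside the engine it launches.
    out.sum().backward()
```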