torch.utils.checkpoint preserves torch function mode stack during recompute #148023


Open · wants to merge 7 commits into base: gh/soulitzer/353/base

Conversation

@soulitzer soulitzer (Contributor) commented Feb 26, 2025

Stack from ghstack (oldest at bottom):

Fixes #147995

TorchFunctionModeTLS is part of the autograd tls, but because .backward() itself is a leaf for TorchFunctionMode, the mode is disabled before we enter into the engine. Conversely, since TorchDispatchMode traces through the .backward() python call, we don't actually need to manually stash/restore if the user keeps the same mode enabled.
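For context, here is a minimal sketch of the kind of setup this change affects; the mode, function, and shapes below are illustrative and not taken from the linked issue. A TorchFunctionMode is active while the checkpointed forward runs, and without this fix the recompute triggered during backward would not see that mode on its stack.

```python
import torch
from torch.overrides import TorchFunctionMode
from torch.utils.checkpoint import checkpoint

class LoggingMode(TorchFunctionMode):
    """Illustrative mode that records which torch.* calls it intercepts."""
    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.seen.append(func)
        return func(*args, **kwargs)

def fn(x):
    return x.sin().cos()

x = torch.randn(4, requires_grad=True)
with LoggingMode() as mode:
    out = checkpoint(fn, x, use_reentrant=False)

# .backward() is a leaf for TorchFunctionMode, so the mode is not active inside
# the engine; with this PR, checkpoint stashes the mode stack at forward time
# and re-applies it around the recompute, so the rerun of fn sees the same modes.
out.sum().backward()
```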

pytorch-bot bot commented Feb 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148023

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 12 New Failures, 1 Cancelled Job

As of commit 9652102 with merge base e57cdb8:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

soulitzer added a commit that referenced this pull request Feb 26, 2025
@soulitzer soulitzer added the release notes: autograd and topic: bug fixes labels Feb 26, 2025
@pytorch pytorch deleted a comment from github-actions bot Feb 26, 2025
soulitzer added a commit that referenced this pull request Feb 26, 2025
soulitzer added a commit that referenced this pull request Feb 26, 2025
… during recompute"


Fixes #147995

TorchFunctionModeTLS is part of the autograd tls, but because .backward() itself is a leaf for TorchFunctionMode, the mode is disabled before we enter into the engine.  Conversely, since TorchDispatchMode traces through the .backward() python call, we don't actually need to manually stash/restore if the user keeps the same mode enabled.

We should still fix TorchDispatchMode though, because even if the user doesn't keep the same mode enabled on the .backward() call, checkpoint should not fail.

[ghstack-poisoned]
soulitzer added a commit that referenced this pull request Feb 27, 2025
with (
    device_autocast_ctx,  # type: ignore[attr-defined]
    torch.amp.autocast("cpu", **cpu_autocast_kwargs),
    _apply_torch_function_mode_stack(torch_function_mode_stack),
Contributor

Noob question: during the recompute phase, are we effectively rerunning the user's forward function (and all of their torch.* ops)? Or is the AC code here doing something different, like capturing all of the ATen ops witnessed during the forward and replaying those? If it's the latter, I'm not sure it will work as cleanly with a TorchFunctionMode.

@soulitzer soulitzer (Contributor Author) commented Feb 27, 2025

In eager, it's the former: we're just rerunning the user's forward function.

For compile, I'm not sure what the status of TorchFunctionMode is with compile generally (cc @mlazos), but assuming the TorchFunctionMode is inlined through by dynamo, I guess the subgraph that the HOP applies the is_recompute annotations to should include the TorchFunctionMode logic baked in.

Contributor

Yep, I think this is right.

Contributor

Yes, this is correct @soulitzer

soulitzer added a commit that referenced this pull request Feb 27, 2025
soulitzer added a commit that referenced this pull request Mar 3, 2025
@albanD albanD (Collaborator) left a comment

Can you describe why this is the right thing to do?

@soulitzer soulitzer (Contributor Author) commented Mar 3, 2025

In a world where AC's semantics are to "rerun the forward", it's possible to break the invariant that the saved activations must be the same between forward and recompute, because control flow that depends on global state can change which ops are run the second time around (sketched below).

The responsibility to preserve the invariant is split between the user and AC. On one hand, the user should not do arbitrary control flow that depends on global state. On the other hand, AC will handle composing with built-in features like RNG and autocast.

TorchFunctionMode is an interesting case: it's a built-in feature, but it's also an extension point where the user can add custom logic, so there's no guarantee about whether they want the stack re-enabled; it depends on their use case.

But two reasons for re-enabling the stack by default are:

  • Most TorchFunctionModes are probably designed to apply some kind of program transform (applying a subclass like in the issue, or decomposing a torch op into different ones to achieve a different memory/speed trade-off). In this case it is quite important to make sure the stack is re-enabled, and annoying to require the user to do something extra.
  • If the user is doing some kind of side effect in the TorchFunctionMode and did not expect that side effect to execute a second time, I think they should just adjust their expectations, because we've already defined AC's semantics to be "execute the forward function again".
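As a concrete illustration of the invariant described above, here is a sketch (the global flag and function are hypothetical) of user-side control flow on global state that changes which ops run the second time around; non-reentrant checkpoint is expected to flag the resulting mismatch at backward time rather than silently compute wrong gradients.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical global flag, for illustration only.
USE_FAST_PATH = True

def fn(x):
    # Control flow depends on global state, so the recompute may take a
    # different branch than the original forward did.
    if USE_FAST_PATH:
        return x.sin()
    return x.sin().cos()

x = torch.randn(4, requires_grad=True)
out = checkpoint(fn, x, use_reentrant=False)

USE_FAST_PATH = False  # global state changes before backward runs...
# ...so the recompute runs different ops and saves different tensors than the
# original forward did; checkpoint is expected to raise an error about the
# mismatch during this backward call.
out.sum().backward()
```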

@soulitzer soulitzer (Contributor Author)

Discussed this with @albanD and we may want to support this more generally by just stashing/restoring the TLS state. The blast radius of this PR becomes a bit larger, but would allow us to also support TorchDispatchMode and clean up the currently manual handling around autocast + no_grad.
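For reference, a rough sketch of what stashing and re-applying the TorchFunctionMode stack could look like, assuming the private helper torch.overrides._get_current_function_mode_stack exists in this form; the actual _apply_torch_function_mode_stack in this PR (and any more general TLS stash/restore) may be implemented quite differently.

```python
import contextlib
from torch.overrides import _get_current_function_mode_stack

@contextlib.contextmanager
def _reapply_function_mode_stack(stack):
    # Re-enter a previously stashed list of TorchFunctionModes (outermost
    # first) so code run inside this context sees the same mode stack.
    with contextlib.ExitStack() as exit_stack:
        for mode in stack:
            exit_stack.enter_context(mode)
        yield

# At forward time, stash the current stack alongside the other TLS state:
stashed = _get_current_function_mode_stack()

# Later, during recompute, wrap the rerun of the user's forward, e.g.:
# with _reapply_function_mode_stack(stashed):
#     recompute_fn(*args)   # hypothetical recompute entry point
```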

soulitzer added a commit that referenced this pull request Mar 10, 2025
@github-actions github-actions bot added the Stale label May 10, 2025
@pytorch pytorch deleted a comment from github-actions bot May 12, 2025
@soulitzer soulitzer added no-stale and removed Stale labels May 12, 2025
Labels
no-stale · release notes: autograd · topic: bug fixes