Checkpoint doesn't work with torch_function if torch_function changes tensor metadata · Issue #147995 · pytorch/pytorch · GitHub

Checkpoint doesn't work with torch_function if torch_function changes tensor metadata #147995

Open
fegin opened this issue Feb 26, 2025 · 9 comments
Assignees
Labels
module: activation checkpointing (Related to activation checkpointing) · module: __torch_function__ · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@fegin
Contributor
fegin commented Feb 26, 2025

🐛 Describe the bug

We are trying to use TorchFunctionMode to convert the input tensors of SDPA to DTensor (if they are not already). Unfortunately, this approach fails. Digging into the details, this seems to be a fundamental limitation of checkpoint, since checkpoint is not aware of __torch_function__. Below is a minimal repro that uses __torch_function__ to reshape the input tensors.

import torch
from torch.utils.checkpoint import checkpoint
from torch.overrides import TorchFunctionMode


def func(x, y) -> torch.Tensor:
    return torch.matmul(x, y)


class DistributeFunction(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        if kwargs is None:
            kwargs = {}

        if func != torch.matmul:
            return func(*args, **kwargs)

        a0 = args[0].reshape((-1, 128))
        a1 = args[1].reshape((128, -1))
        return func(a0, a1)


with DistributeFunction():
    a = torch.randn(64, 64)
    a.requires_grad = True
    out = checkpoint(func, a, a, use_reentrant=False)
    out.sum().backward()

Checkpoint complains about a metadata mismatch:

  File "/data/users/chienchin/mywork/pytorch/test.py", line 16, in __torch_function__
    return func(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/data/users/chienchin/mywork/pytorch/torch/autograd/__init__.py", line 353, in backward
    _engine_run_backward(
  File "/data/users/chienchin/mywork/pytorch/torch/autograd/graph.py", line 824, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/users/chienchin/mywork/pytorch/torch/utils/checkpoint.py", line 1129, in unpack_hook
    frame.check_recomputed_tensors_match(gid)
  File "/data/users/chienchin/mywork/pytorch/torch/utils/checkpoint.py", line 903, in check_recomputed_tensors_match
    raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
tensor at position 0:
saved metadata: {'shape': torch.Size([128, 32]), 'dtype': torch.float32, 'device': device(type='cpu')}
recomputed metadata: {'shape': torch.Size([64, 64]), 'dtype': torch.float32, 'device': device(type='cpu')}
tensor at position 1:
saved metadata: {'shape': torch.Size([32, 128]), 'dtype': torch.float32, 'device': device(type='cpu')}
recomputed metadata: {'shape': torch.Size([64, 64]), 'dtype': torch.float32, 'device': device(type='cpu')}

Is there any way to make this __torch_function__ work with Checkpoint?

Versions

nightly

cc @soulitzer @hameerabbasi @rgommers @ezyang

@fegin
Contributor Author
fegin commented Feb 26, 2025

cc @bdhirsh, @drisspg, @XilunWu

@bdhirsh
Contributor
bdhirsh commented Feb 26, 2025

cc @soulitzer - it looks like the TorchFunctionMode ends up violating some AC conditions because it changes the metadata of the inputs to the AC region before autograd sees them.

Does this look like something that is fixable / workaround-able, or is it just fundamentally not allowed?

@jbschlosser
Contributor
jbschlosser commented Feb 26, 2025

Sorta unrelated, but NJT handles SDPA via torch_function and clearly has problems when checkpointing is involved:

import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def func(x, y) -> torch.Tensor:
    q = torch.matmul(x, y).unflatten(-1, [2, 5]).transpose(1, 2)
    return F.scaled_dot_product_attention(q, q, q)

nt = torch.nested.nested_tensor([
    torch.randn(2, 5),
    torch.randn(3, 5),
    torch.randn(4, 5),
], layout=torch.jagged, device="cuda", requires_grad=True)

dense = torch.randn(5, 10, device="cuda")

out = checkpoint(func, nt, dense, use_reentrant=False)
out.sum().backward()
  ...
  File ".../torch/nested/_internal/sdpa.py", line 843, in jagged_scaled_dot_product_attention
    attn_out = torch._scaled_dot_product_attention_math(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../torch/utils/checkpoint.py", line 1107, in pack_hook
    frame.x_metadatas.append(frame.metadata_fn(x))
                             ^^^^^^^^^^^^^^^^^^^^
  File ".../torch/utils/checkpoint.py", line 1035, in _default_meta_extractor
    "shape": x.shape,
             ^^^^^^^
RuntimeError: Internal error: NestedTensorImpl doesn't support sizes. Please file an issue.

The error message indicates that the NST logic is being exercised, which it shouldn't be.

cc @soulitzer

@soulitzer
Contributor
soulitzer commented Feb 26, 2025

@bdhirsh @fegin It sounds like the issue is that the initial forward runs under the TorchFunctionMode (which reshapes the inputs), but the recompute does not run under the TorchFunctionMode - so the reshape is no longer done during recompute, which leads to the shape mismatch.

One workaround here is to activate the TorchFunctionMode manually inside the function you are checkpointing (a full sketch follows the snippet below):

def func(x, y) -> torch.Tensor:
    with DistributeFunction():
        return torch.matmul(x, y)
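
Putting it together with the original repro, a minimal end-to-end sketch of this workaround (untested) would be:

import torch
from torch.utils.checkpoint import checkpoint
from torch.overrides import TorchFunctionMode


class DistributeFunction(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func != torch.matmul:
            return func(*args, **kwargs)
        # Change the metadata of the matmul inputs before autograd saves them.
        a0 = args[0].reshape((-1, 128))
        a1 = args[1].reshape((128, -1))
        return func(a0, a1)


def func(x, y) -> torch.Tensor:
    # Enter the mode inside the checkpointed function so it is also active
    # when checkpoint re-runs func during the backward pass.
    with DistributeFunction():
        return torch.matmul(x, y)


a = torch.randn(64, 64, requires_grad=True)
out = checkpoint(func, a, a, use_reentrant=False)
out.sum().backward()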

The right fix is probably for checkpoint to detect the currently enabled modes and re-enable them during the recompute, though.
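
Until something like that lands, one possible user-side approximation is the context_fn argument of non-reentrant checkpoint, which lets you pass a context manager that is entered around the recomputation. This is an untested sketch (the helper name recompute_under_mode is just illustrative, and it reuses DistributeFunction and func from the original repro); whether it is enough to make the saved and recomputed metadata match here is untested:

import contextlib

import torch
from torch.utils.checkpoint import checkpoint


def recompute_under_mode():
    # checkpoint enters the first context manager around the forward pass and
    # the second one around the recomputation that happens during backward.
    return contextlib.nullcontext(), DistributeFunction()


with DistributeFunction():
    a = torch.randn(64, 64, requires_grad=True)
    out = checkpoint(
        func, a, a, use_reentrant=False, context_fn=recompute_under_mode
    )
    out.sum().backward()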

@jbschlosser Yeah, unfortunately the default determinism logic assumes .sizes() is supported on the tensor. You should be able to bypass this by setting the determinism_check flag, e.g. checkpoint(func, nt, dense, use_reentrant=False, determinism_check="none").

@soulitzer
Contributor

The error message indicates that the NST logic is being exercised, which it shouldn't be.

Could it be that we're going into the math path somehow and doing the conversion from NJT to NST?

@fegin
Contributor Author
fegin commented Feb 26, 2025

@soulitzer I guess this is a fundamental limitation then. We won't be able to insert the TorchFunctionMode inside what is checkpointed, because that is user model code. We can only wrap the model with TorchFunctionMode.

@drisspg Do you think there are other ways to let us install hooks into SDPA without monkey patching?

@drisspg
Contributor
drisspg commented Feb 26, 2025

@fegin

The right fix is probably for checkpoint to detect the currently enabled modes and re-enable them during the recompute, though.

If this were the case, do you think there would still be a problem?

@fegin
Contributor Author
fegin commented Feb 26, 2025

If checkpoint can detect the current modes and appropriately update the saved tensors, then it should work well. Not sure how hard it is to achieve this, though.

@soulitzer
Contributor

Shouldn't be too hard

@soulitzer soulitzer self-assigned this Feb 26, 2025
soulitzer added a commit that referenced this issue Feb 26, 2025
soulitzer added a commit that referenced this issue Feb 27, 2025
… during recompute"


Fixes #147995

TorchFunctionModeTLS is part of the autograd tls, but because .backward() itself is a leaf for TorchFunctionMode, the mode is disabled before we enter into the engine.  Conversely, since TorchDispatchMode traces through the .backward() python call, we don't actually need to manually stash/restore if the user keeps the same mode enabled.

We should still fix TorchDispatchMode though, because even if the user doesn't keep the same mode enabled on the .backward() call, checkpoint should not fail.

[ghstack-poisoned]
@zou3519 zou3519 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Feb 27, 2025
soulitzer added a commit that referenced this issue Feb 27, 2025
soulitzer added a commit that referenced this issue Mar 3, 2025
soulitzer added a commit that referenced this issue Mar 10, 2025
pytorchmergebot pushed a commit that referenced this issue Apr 25, 2025
While we prefer not to use monkey patching to dispatch SDPA, TorchFunctionMode is currently not compatible with selective activation checkpointing (#147995). This PR adds `TorchFunctionMode` to the CP code and makes it configurable.

Pull Request resolved: #147902
Approved by: https://github.com/XilunWu