Checkpoint doesn't work with torch_function if torch_function change tensor metadata · Issue #147995 · pytorch/pytorch · GitHub
Checkpoint doesn't work with torch_function if torch_function change tensor metadata #147995
Open
@fegin

Description

🐛 Describe the bug

We are trying to use TorchFunctionMode to convert the input tensors of SDPA to DTensor (if they are not already). Unfortunately, this approach fails. Digging into the details, this appears to be a fundamental limitation of checkpoint, since checkpoint is not aware of __torch_function__. Below is a minimal repro that uses __torch_function__ to reshape the input tensors.

import torch
from torch.utils.checkpoint import checkpoint
from torch.overrides import TorchFunctionMode


def func(x, y) -> torch.Tensor:
    return torch.matmul(x, y)


class DistributeFunction(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        if kwargs is None:
            kwargs = {}

        if func != torch.matmul:
            return func(*args, **kwargs)

        a0 = args[0].reshape((-1, 128))
        a1 = args[1].reshape((128, -1))
        return func(a0, a1)


with DistributeFunction():
    a = torch.randn(64, 64)
    a.requires_grad = True
    out = checkpoint(func, a, a, use_reentrant=False)
    out.sum().backward()

Checkpoint complains about a metadata mismatch:

  File "/data/users/chienchin/mywork/pytorch/test.py", line 16, in __torch_function__
    return func(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/data/users/chienchin/mywork/pytorch/torch/autograd/__init__.py", line 353, in backward
    _engine_run_backward(
  File "/data/users/chienchin/mywork/pytorch/torch/autograd/graph.py", line 824, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/users/chienchin/mywork/pytorch/torch/utils/checkpoint.py", line 1129, in unpack_hook
    frame.check_recomputed_tensors_match(gid)
  File "/data/users/chienchin/mywork/pytorch/torch/utils/checkpoint.py", line 903, in check_recomputed_tensors_match
    raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
tensor at position 0:
saved metadata: {'shape': torch.Size([128, 32]), 'dtype': torch.float32, 'device': device(type='cpu')}
recomputed metadata: {'shape': torch.Size([64, 64]), 'dtype': torch.float32, 'device': device(type='cpu')}
tensor at position 1:
saved metadata: {'shape': torch.Size([32, 128]), 'dtype': torch.float32, 'device': device(type='cpu')}
recomputed metadata: {'shape': torch.Size([64, 64]), 'dtype': torch.float32, 'device': device(type='cpu')}

Is there any way to make this __torch_function__ work with Checkpoint?
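One possible direction, sketched here as an assumption rather than a confirmed fix: the traceback shows backward() itself being dispatched through __torch_function__ (line 16 of the repro), and a TorchFunctionMode is not active inside its own handler, so the recomputation may be running with the mode switched off. If that reading is right, entering the mode inside the checkpointed function should make the forward pass and the recomputation see the same reshaped inputs. The wrapper name func_with_mode is hypothetical and not part of the original repro.

```python
import torch
from torch.overrides import TorchFunctionMode
from torch.utils.checkpoint import checkpoint


class DistributeFunction(TorchFunctionMode):
    # Same mode as in the repro: reshape the operands of torch.matmul.
    def __torch_function__(self, func, types, args, kwargs=None):
        if kwargs is None:
            kwargs = {}

        if func != torch.matmul:
            return func(*args, **kwargs)

        a0 = args[0].reshape((-1, 128))
        a1 = args[1].reshape((128, -1))
        return func(a0, a1)


def func_with_mode(x, y):
    # Hypothetical wrapper: the mode is entered inside the checkpointed
    # function, so checkpoint re-enters it when it reruns this function
    # during backward.
    with DistributeFunction():
        return torch.matmul(x, y)


a = torch.randn(64, 64, requires_grad=True)
out = checkpoint(func_with_mode, a, a, use_reentrant=False)
out.sum().backward()
```

This only sidesteps the problem for modes that can be scoped to the checkpointed region; it does not make checkpoint aware of a mode that is entered outside of it.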

Versions

nightly

cc @soulitzer @hameerabbasi @rgommers @ezyang

Labels

module: __torch_function__, module: activation checkpointing, triaged
