### 🐛 Describe the bug
We are trying to use `TorchFunctionMode` to convert the input tensors of SDPA to `DTensor` (if they are not already). Unfortunately, this approach fails. Digging into the details, this seems to be a fundamental limitation of checkpoint, as checkpoint is not aware of `__torch_function__`. Below is a minimal repro which uses `__torch_function__` to reshape the input tensors:
```python
import torch
from torch.utils.checkpoint import checkpoint
from torch.overrides import TorchFunctionMode


def func(x, y):
    return torch.matmul(x, y)


class DistributeFunction(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func != torch.matmul:
            return func(*args, **kwargs)
        # Reshape both matmul inputs before dispatching the original op.
        a0 = args[0].reshape((-1, 128))
        a1 = args[1].reshape((128, -1))
        return func(a0, a1)


with DistributeFunction():
    a = torch.randn(64, 64)
    a.requires_grad = True
    out = checkpoint(func, a, a, use_reentrant=False)
    out.sum().backward()
```
Checkpoint complains about a metadata mismatch:
File "/data/users/chienchin/mywork/pytorch/test.py", line 16, in __torch_function__
return func(*args, **kwargs)
File "/data/users/chienchin/mywork/pytorch/torch/_tensor.py", line 648, in backward
torch.autograd.backward(
File "/data/users/chienchin/mywork/pytorch/torch/autograd/__init__.py", line 353, in backward
_engine_run_backward(
File "/data/users/chienchin/mywork/pytorch/torch/autograd/graph.py", line 824, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass File "/data/users/chienchin/mywork/pytorch/torch/utils/checkpoint.py", line 1129, in unpack_hook
frame.check_recomputed_tensors_match(gid)
File "/data/users/chienchin/mywork/pytorch/torch/utils/checkpoint.py", line 903, in check_recomputed_tensors_match
raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
tensor at position 0:
saved metadata: {'shape': torch.Size([128, 32]), 'dtype': torch.float32, 'device': device(type='cpu')}
recomputed metadata: {'shape': torch.Size([64, 64]), 'dtype': torch.float32, 'device': device(type
6942
='cpu')}
tensor at position 1:
saved metadata: {'shape': torch.Size([32, 128]), 'dtype': torch.float32, 'device': device(type='cpu')}
recomputed metadata: {'shape': torch.Size([64, 64]), 'dtype': torch.float32, 'device': device(type='cpu')}
Is there any way to make this `__torch_function__` approach work with checkpoint?
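For context, one direction we have been considering (an untested sketch on our side, not a confirmed fix) is checkpoint's `context_fn` argument, which takes a callable returning two context managers, one for the original forward and one for the recomputation, and is only supported with `use_reentrant=False`. The idea would be to enter the mode in both passes so the reshapes are applied consistently; we have not verified whether this actually avoids the metadata mismatch:

```python
# Untested sketch: reuses DistributeFunction and func from the repro above.
# Assumption: entering a fresh DistributeFunction instance for the forward and
# another one for the recomputation is acceptable here.
import torch
from torch.utils.checkpoint import checkpoint


def distribute_context_fn():
    # The first context manager wraps the original forward pass, the second
    # wraps the recomputation triggered during backward.
    return DistributeFunction(), DistributeFunction()


a = torch.randn(64, 64, requires_grad=True)
out = checkpoint(func, a, a, use_reentrant=False, context_fn=distribute_context_fn)
out.sum().backward()
```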
### Versions
nightly