[dynamo] Activation checkpointing tests erroring at runtime #127115

Closed · williamwen42 opened this issue May 24, 2024 · 0 comments
Labels: module: activation checkpointing, module: dynamo, oncall: pt2, triaged

Comments

williamwen42 (Member) commented May 24, 2024

Discovered with #126341.

Problematic tests:
PYTORCH_TEST_WITH_DYNAMO=1 pytest test/test_autograd.py::TestAutograd::test_checkpointing_without_reentrant_arbitrary_input_output
PYTORCH_TEST_WITH_DYNAMO=1 pytest test/test_autograd.py::TestAutograd::test_checkpointing_without_reentrant_dataparallel
(The latter fails on CI; I cannot reproduce it locally.)
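
For context, a minimal sketch of the pattern the first test exercises (illustrative names and shapes, not the test's actual code), matching the traceback below where the checkpointed forward takes a dict and runs a bias-free linear:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(5, 5, bias=False)

    def forward(self, dict_input):
        # Arbitrary (non-tensor-only) input/output: a dict in, a dict out.
        return {"result": self.layer(dict_input["tensor"])}

model = MyModel()
compiled = torch.compile(model, backend="eager")  # stand-in for the PYTORCH_TEST_WITH_DYNAMO wrapper

x = torch.randn(5, 5, requires_grad=True)
out = checkpoint(compiled, {"tensor": x}, use_reentrant=False)
out["result"].sum().backward()  # backward triggers recomputation of the compiled forward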

Logs:

Traceback (most recent call last):
  File "/data/users/williamwen/pytorch2/torch/_dynamo/backends/debugging.py", line 33, in inner
    return gm(*args)
  File "/data/users/williamwen/pytorch2/torch/fx/graph_module.py", line 736, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/fx/graph_module.py", line 315, in __call__
    raise e
  File "/data/users/williamwen/pytorch2/torch/fx/graph_module.py", line 302, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/data/users/williamwen/pytorch2/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.1", line 7, in forward
    linear = torch._C._nn.linear(l_dict_input_tensor_, l_self_layer_weight, None);  l_dict_input_tensor_ = l_self_layer_weight = None
  File "/data/users/williamwen/pytorch2/torch/utils/checkpoint.py", line 1077, in pack_hook
    raise _StopRecomputationError
torch.utils.checkpoint._StopRecomputationError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/williamwen/py310-env/lib/python3.10/unittest/case.py", line 59, in testPartExecutor
    yield
  File "/data/users/williamwen/py310-env/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/data/users/williamwen/py310-env/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/data/users/williamwen/pytorch2/torch/testing/_internal/common_utils.py", line 2756, in wrapper
    method(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/test/test_autograd.py", line 7069, in test_checkpointing_without_reentrant_arbitrary_input_output
    out_checkpoint.backward()
  File "/data/users/williamwen/pytorch2/torch/_tensor.py", line 523, in backward
    torch.autograd.backward(
  File "/data/users/williamwen/pytorch2/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/data/users/williamwen/pytorch2/torch/autograd/graph.py", line 767, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/users/williamwen/pytorch2/torch/utils/checkpoint.py", line 1115, in unpack_hook
    frame.recompute_fn(*args)
  File "/data/users/williamwen/pytorch2/torch/utils/checkpoint.py", line 1401, in recompute_fn
    fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/test/test_autograd.py", line 7051, in forward
    def forward(self, dict_input):
  File "/data/users/williamwen/pytorch2/torch/_dynamo/eval_frame.py", line 548, in _fn
    return fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/backends/debugging.py", line 35, in inner
    raise torch._dynamo.exc.TorchDynamoException(
torch._dynamo.exc.TorchDynamoException: Unexpected exception when running generated GraphModule
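
Background on the mechanism (a sketch, not the checkpoint implementation): non-reentrant checkpointing is built on saved-tensor pack/unpack hooks, and during recomputation the pack hook deliberately raises _StopRecomputationError once every needed activation has been re-stashed; torch.utils.checkpoint normally catches it. In the failing run, the exception instead escapes the Dynamo-generated GraphModule, so the debugging backend rewraps it as a TorchDynamoException. The hook API itself:

import torch

def pack_hook(t):
    # Called when autograd saves a tensor for backward.
    print("pack", tuple(t.shape))
    return t

def unpack_hook(saved):
    # Called when backward needs the saved tensor back.
    print("unpack", tuple(saved.shape))
    return saved

x = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (x * x).sum()
y.backward()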

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @ezyang @msaroufim @bdhirsh @anijain2305

williamwen42 added the triaged, oncall: pt2, module: dynamo, and activation-checkpointing labels on May 24, 2024
soulitzer added the module: activation checkpointing label and removed the activation-checkpointing label on Mar 3, 2025
pytorchmergebot pushed a commit that referenced this issue May 16, 2025
…ger (#153300)

I slap disable on the recomputation hook; otherwise the partitioner may save fewer or more activations than eager checkpointing expects, and the counts mismatch. See the code comment `Note: [compiled autograd and checkpoint unpack hook]`.
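
A hedged sketch of that idea (illustrative; make_recompute_fn is a hypothetical helper, not the PR's actual diff):

import torch

def make_recompute_fn(fn):
    # Keep Dynamo / compiled autograd from tracing into the recomputation
    # hook, so the number of saved activations matches what eager
    # checkpointing expects.
    @torch._dynamo.disable
    def recompute_fn(*args, **kwargs):
        return fn(*args, **kwargs)  # re-run the forward to repopulate saved tensors
    return recompute_fn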

This fixes all non-nested checkpointing tests. I also wrap nested checkpointing tests, and a few of them still fail.

This also seems to fix all PYTORCH_TEST_WITH_DYNAMO checkpointing tests except `TestAutograd.test_checkpointing_without_reentrant_custom_function_works`. For those tests, it looks like we fail to HOPify the checkpointed region, and when the backward executes the unpack hooks, Dynamo tries to trace them. This messes up checkpointing's internal state tracking: some tests raise _StopRecomputationError and others raise the same count-mismatch error as compiled autograd.

FIXES #127115

Pull Request resolved: #153300
Approved by: https://github.com/jansel