Selective Activation Checkpointing on custom autograd.Function · Issue #153334 · pytorch/pytorch · GitHub

Selective Activation Checkpointing on custom autograd.Function #153334

Open
lvhoaa opened this issue May 10, 2025 · 14 comments
Labels
module: activation checkpointing Related to activation checkpointing triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@lvhoaa
lvhoaa commented May 10, 2025

🚀 The feature, motivation and pitch

Selective Activation Checkpointing on custom autograd.Function?

It seems that a custom autograd.Function is not seen as an op, so SAC does not pick it up. How can SAC be performed on a custom autograd.Function?

Alternatives

No response

Additional context

No response

cc @soulitzer

@AbhiLegend

@lvhoaa I can suggest two options.
Option 1: Use torch.utils.checkpoint.checkpoint manually
Wrap the call to your autograd.Function in a checkpointed function, roughly as sketched below.
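
A rough sketch of this (MyFunc here is just a made-up stand-in for your custom autograd.Function):

import torch
from torch.utils.checkpoint import checkpoint

class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sin()

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * x.cos()

def block(x):
    # During backward, checkpoint re-runs this function to rebuild its
    # activations (including what MyFunc saves) instead of keeping them
    # around from the original forward.
    return MyFunc.apply(x) * 2

x = torch.randn(4, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()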

Option 2: Refactor into nn.Module and use selective checkpointing
If possible, refactor your logic into an nn.Module instead of an autograd.Function, then apply torch.utils.checkpoint selectively to submodules, for example:
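
A rough sketch (the module structure and sizes are made up):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.expensive = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
        self.cheap = nn.Linear(16, 16)

    def forward(self, x):
        # Only the expensive submodule is checkpointed; its activations are
        # recomputed during backward, while the cheap part is left as-is.
        x = checkpoint(self.expensive, x, use_reentrant=False)
        return self.cheap(x)

model = Model()
out = model(torch.randn(2, 16))
out.sum().backward()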

See whether either of these fits your situation. I can expand on either option if you need more detail.

@colesbury colesbury added module: activation checkpointing Related to activation checkpointing triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels May 13, 2025
@soulitzer
Contributor

Some more options:

  • Write a custom op instead. Custom ops are currently not as flexible with respect to the types supported in inputs and outputs, but SAC would then be aware of your op (a rough sketch of this is after the snippet below).

  • Try https://github.com/soulitzer/ac-experimental. This newer version of AC supports tagging specific tensors to save or not save:

with apply_ac_policy("recompute_all"):
    y = ...
    z = y @ y
    # Save the output of this matmul instead of recomputing it
    tag_with_policy(z, CheckpointPolicy.MUST_SAVE)
    ...
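
For the first option, a rough sketch of what that could look like with the torch.library.custom_op API (the op name mylib::scaled_matmul, the toy function fn, and the policy are made up for illustration; this route assumes PyTorch 2.4+):

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Register a custom op so it shows up as a single op at dispatch time.
@torch.library.custom_op("mylib::scaled_matmul", mutates_args=())
def scaled_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return (a @ b) * 0.5

@scaled_matmul.register_fake
def _(a, b):
    return a.new_empty(a.shape[0], b.shape[1])

def _setup_context(ctx, inputs, output):
    a, b = inputs
    ctx.save_for_backward(a, b)

def _backward(ctx, grad):
    a, b = ctx.saved_tensors
    return (grad @ b.t()) * 0.5, (a.t() @ grad) * 0.5

scaled_matmul.register_autograd(_backward, setup_context=_setup_context)

# SAC policy: always save the output of the custom op, prefer recomputing the rest.
def policy_fn(ctx, op, *args, **kwargs):
    if op == torch.ops.mylib.scaled_matmul.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def fn(x, w1, w2):
    return scaled_matmul(torch.relu(x @ w1), w2).sum()

x = torch.randn(4, 8, requires_grad=True)
w1 = torch.randn(8, 8, requires_grad=True)
w2 = torch.randn(8, 8, requires_grad=True)
out = checkpoint(
    fn, x, w1, w2,
    use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(policy_fn),
)
out.backward()

Since the policy function only sees ops at dispatch level, wrapping the computation as a real op is what lets SAC match it and mark it MUST_SAVE; a plain autograd.Function does not give you that.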

@lvhoaa
Author
lvhoaa commented May 14, 2025

@soulitzer I'm using the ac-experimental repo.

Originally, in my code flow, the whole thing is recomputed.

For the part of the code where I don't want AC, I changed it to:

 with apply_ac_policy("must_save_all"):
            out = self.attend(
                q,
                k,
                v,
                attn_bias=attn_bias,
                mask=mask,
                windowed_mask=windowed_mask,
                memory_kv=self.memory_kv,
                **kwargs,
            )
            tag_with_policy(out, CheckpointPolicy.MUST_SAVE)

This doesn't work (I put a print statement in my torch.autograd.Function and it is still hit during the model's backward pass).
Removing either the with apply_ac_policy("must_save_all") or the tag_with_policy(out, CheckpointPolicy.MUST_SAVE) does not fix it either.

When I tried

with apply_ac_policy("recompute_all"):
        out = self.attend(
            q,
            k,
            v,
            attn_bias=attn_bias,
            mask=mask,
            windowed_mask=windowed_mask,
            memory_kv=self.memory_kv,
            **kwargs,
        )
        tag_with_policy(out, CheckpointPolicy.MUST_SAVE)

I have this error:

[rank0]: RuntimeError: aten::native_dropout() Expected a value of type 'Tensor' for argument 'input' but instead found type 'NodeOutput (aka NamedTuple(node, idx))'.
[rank0]: Position: 0
[rank0]: Value: NodeOutput(node=<ac_experimental.ac.Node object at 0x7f0264ee56d0>, idx=0)
[rank0]: Declaration: aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
[rank0]: Cast error details: Unable to cast NodeOutput(node=<ac_experimental.ac.Node object at 0x7f0264ee56d0>, idx=0) to Tensor

@soulitzer
Contributor

I put a print statement in my torch.autograd.Function and it is still hit during the model's backward pass

Just to double check, are you using torch.utils.checkpoint with ac-experimental APIs? (They are not compatible with one another). In order to use tag_with_policy(out, CheckpointPolicy.MUST_SAVE), you must also replace torch.utils.checkpoint with apply_ac_policy("recompute_all").

e.g.,

class Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ...):
        out = ...
        tag_with_policy(out, CheckpointPolicy.MUST_SAVE)
        return out


# Replaces torch.utils.checkpoint
with apply_ac_policy("recompute_all"):
    ...
    Func.apply(x)

I have this error:

Do you have a longer stack trace or repro instructions?

@lvhoaa
Author
lvhoaa commented May 14, 2025

I just replaced torch.utils.checkpoint with apply_ac_policy("recompute_all").

Now I have this error:

File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
    ~~~~~~~~~~~~~~~~~~~~^
        tensors,
        ^^^^^^^^
    ...<5 lines>...
        accumulate_grad=True,
        ^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        t_outputs, *args, **kwargs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
    )  # Calls into the C++ engine to run the backward pass
    ^
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 284, in unpack_hook
    x = unflatten_fn()
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 254, in <lambda>
    return lambda: unpack(packed)
                   ~~~~~~^^^^^^^^
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 277, in _unpack
    out = realize_and_decref(node_output)
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 108, in realize_and_decref
    return node_output.node.realize_and_decref(node_output.idx)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 144, in realize_and_decref
    raw_out = self.func(*new_args)
  File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ~~~~~~~~^^^^^^^^^^^^^^^^^
RuntimeError: aten::t() Expected a value of type 'Tensor' for argument 'self' but instead found type 'NodeOutput (aka NamedTuple(node, idx))'.
Position: 0
Value: NodeOutput(node=<ac_experimental.ac.Node object at 0x7ff64d839050>, idx=0)
Declaration: aten::t(Tensor(a) self) -> Tensor(a)
Cast error details: Unable to cast NodeOutput(node=<ac_experimental.ac.Node object at 0x7ff64d839050>, idx=0) to Tensor

@soulitzer
Contributor
soulitzer commented May 14, 2025

Thanks for the stack trace. I think what is happening is that the current code only works on PyTorch >= 2.7 (due to a change in pytree behavior).
Let me update it so that it also supports older versions of PyTorch.

The relevant code:

import torch
from typing import NamedTuple
from torch.utils._pytree import tree_map_only

print(torch.__version__)

class NodeOutput(NamedTuple):
    node: 'Node'
    idx: int

# Remove the `_asdict` method to make it an opaque leaf
# (pytree checks for `_fields`, `_make`, `_asdict`).
# Is there a better way to do this?
del NodeOutput._asdict

class Node:
    pass

convert = lambda node_output: "hello"

no = NodeOutput(node=Node(), idx=1)
nos = (no,)
out = tree_map_only(NodeOutput, convert, nos)
print(out)

On versions of PyTorch older than 2.7:

(NodeOutput(node=<__main__.Node object at 0x7e1791a46a50>, idx=1),)

On newer versions of PyTorch:

('hello',)

@soulitzer
Contributor
soulitzer commented May 14, 2025

This should be updated in the repo now: soulitzer/ac-experimental@51cfc67.
Let me know if that works.

@lvhoaa
Author
lvhoaa commented May 14, 2025

Thanks for the speed.

Now I have this new error:

[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]:     _engine_run_backward(
[rank0]:     ~~~~~~~~~~~~~~~~~~~~^
[rank0]:         tensors,
[rank0]:         ^^^^^^^^
[rank0]:     ...<5 lines>...
[rank0]:         accumulate_grad=True,
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^
[rank0]:     )
[rank0]:     ^
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:         t_outputs, *args, **kwargs
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:     )  # Calls into the C++ engine to run the backward pass
[rank0]:     ^
[rank0]:   File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 288, in unpack_hook
[rank0]:     x = unflatten_fn()
[rank0]:   File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 258, in <lambda>
[rank0]:     return lambda: unpack(packed)
[rank0]:                    ~~~~~~^^^^^^^^
[rank0]:   File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 281, in _unpack
[rank0]:     out = realize_and_decref(node_output)
[rank0]:   File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 112, in realize_and_decref
[rank0]:     return node_output.node.realize_and_decref(node_output.idx)
[rank0]:            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 147, in realize_and_decref
[rank0]:     new_args = tree_map_only(NodeOutput, realize_and_decref, self.args)
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 1163, in tree_map_only
[rank0]:     return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 991, in tree_map
[rank0]:     return treespec.unflatten(map(func, *flat_args))
[rank0]:            ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 830, in unflatten
[rank0]:     leaves = list(leaves)
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 1109, in wrapped
[rank0]:     return func(x)
[rank0]:   File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 112, in realize_and_decref
[rank0]:     return node_output.node.realize_and_decref(node_output.idx)
[rank0]:            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 147, in realize_and_decref
[rank0]:     new_args = tree_map_only(NodeOutput, realize_and_decref, self.args)
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 1163, in tree_map_only
[rank0]:     return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 991, in tree_map
[rank0]:     return treespec.unflatten(map(func, *flat_args))
[rank0]:            ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 830, in unflatten
[rank0]:     leaves = list(leaves)
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 1109, in wrapped
[rank0]:     return func(x)
[rank0]:   File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 112, in realize_and_decref
[rank0]:     return node_output.node.realize_and_decref(node_output.idx)
[rank0]:            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 147, in realize_and_decref
[rank0]:     new_args = tree_map_only(NodeOutput, realize_and_decref, self.args)
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 1163, in tree_map_only
[rank0]:     return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 991, in tree_map
[rank0]:     return treespec.unflatten(map(func, *flat_args))
[rank0]:            ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 830, in unflatten
[rank0]:     leaves = list(leaves)
[rank0]:   File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 1109, in wrapped
[rank0]:     return func(x)
[rank0]:   File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 112, in realize_and_decref
[rank0]:     return node_output.node.realize_and_decref(node_output.idx)
[rank0]:            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:   File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 150, in realize_and_decref
[rank0]:     out = self.out[idx]
[rank0]:           ~~~~~~~~^^^^^
[rank0]: IndexError: list index out of range

@soulitzer
Contributor

Sorry for the rough edges, this is relatively new. I made another update just now: soulitzer/ac-experimental@4f5cfa2

Let me know if that fixes things

@lvhoaa
Author
lvhoaa commented May 14, 2025

Thanks! That error goes away. Now another one comes up:

 File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
    ~~~~~~~~~~~~~~~~~~~~^
        tensors,
        ^^^^^^^^
    ...<5 lines>...
        accumulate_grad=True,
        ^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        t_outputs, *args, **kwargs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
    )  # Calls into the C++ engine to run the backward pass
    ^
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 288, in unpack_hook
    x = unflatten_fn()
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 258, in <lambda>
    return lambda: unpack(packed)
                   ~~~~~~^^^^^^^^
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 281, in _unpack
    out = realize_and_decref(node_output)
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 112, in realize_and_decref
    return node_output.node.realize_and_decref(node_output.idx)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 147, in realize_and_decref
    new_args = tree_map_only(NodeOutput, realize_and_decref, self.args)
  File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 1163, in tree_map_only
    return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
  File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 991, in tree_map
    return treespec.unflatten(map(func, *flat_args))
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
  File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 830, in unflatten
    leaves = list(leaves)
  File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 1109, in wrapped
    return func(x)
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 112, in realize_and_decref
    return node_output.node.realize_and_decref(node_output.idx)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 147, in realize_and_decref
    new_args = tree_map_only(NodeOutput, realize_and_decref, self.args)
  File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 1163, in tree_map_only
    return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
  File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 991, in tree_map
    return treespec.unflatten(map(func, *flat_args))
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
  File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 830, in unflatten
    leaves = list(leaves)
  File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/utils/_pytree.py", line 1109, in wrapped
    return func(x)
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 112, in realize_and_decref
    return node_output.node.realize_and_decref(node_output.idx)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/scratch/bcjw/hla/ac-experimental/ac_experimental/ac.py", line 148, in realize_and_decref
    raw_out = self.func(*new_args)
  File "/u/hla/.conda/envs/newenv/lib/python3.13/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ~~~~~~~~^^^^^^^^^^^^^^^^^
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

@soulitzer
Contributor

Hmm, I'm not too sure what is happening here, unfortunately. If you have some instructions to repro this (and/or more information about your setup), that would be very helpful.
Thanks for your patience trying out the new AC so far!

@soulitzer
Contributor

I made another update btw @lvhoaa; it's possible that this issue is addressed by it. Let me know if you get the chance to try it out again.

soulitzer/ac-experimental@8003f65

@lvhoaa
Author
lvhoaa commented May 26, 2025

Sorry for the late reply. It is fixed. Thanks!

Aside from this, I think it would be helpful to explain the behavior of this new API step by step: what will be stored, what will be recomputed, and from which inputs.

For example:

with apply_ac_policy("recompute_all"):
    outA = functionA(input)
    with apply_ac_policy("must_save_all"):
        outB = functionB(outA)
    outC = functionC(outB)

In this case, will input be stored? Will outA (which acts as the input to the saved functionB) be stored? Given that outB is stored, will the recomputation for functionC "restart" from this stored outB, or will it just "restart" from input?

Similar explanations are needed for tag_with_policy.

Documenting this step-by-step behavior clearly would let users track down how it works, debug speed and memory, and also make things clearer on the development side.

Best!

@soulitzer
Contributor

Glad it works now! Thanks for the feedback, I will definitely need to take some time to add more documentation.
