Selective Activation Checkpointing on custom autograd.Function #153334
Comments
@lvhoaa I can suggest two options. Option 2: Refactor into nn.Module and use selective checkpointing. See whether either of these fits your situation. I can elaborate on these options if you need.
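As a minimal sketch (not from the original comment), Option 2 could look something like this with the stock torch.utils.checkpoint selective-checkpointing API (PyTorch 2.4+); the `Block` module and the choice to save mm outputs are illustrative assumptions, not details from this thread:

```python
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


class Block(nn.Module):
    """Hypothetical module wrapping the region to checkpoint."""

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.matmul(self.linear(x), x.t())


def policy_fn(ctx, op, *args, **kwargs):
    # Save matmul outputs; recompute everything else during backward.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


block = Block(16)
x = torch.randn(4, 16, requires_grad=True)
out = checkpoint(
    block,
    x,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)
out.sum().backward()
```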
Some more options:
```python
with apply_ac_policy("recompute_all"):
    y = ...
    z = y @ y
    # Save the output of this matmul instead of recomputing it
    tag_with_policy(z, CheckpointPolicy.MUST_SAVE)
    ...
```
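A sketch of how that pattern might apply around a custom autograd.Function; the `ac_experimental` import path and `MyFn` are assumptions for illustration, and only `apply_ac_policy`, `tag_with_policy`, and `CheckpointPolicy` are names that appear in this thread:

```python
import torch

# Assumed import path for the ac-experimental repo; adjust to your install.
from ac_experimental import CheckpointPolicy, apply_ac_policy, tag_with_policy


class MyFn(torch.autograd.Function):
    """Hypothetical stand-in for the custom Function discussed here."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


def forward_with_ac(model_part, x):
    with apply_ac_policy("recompute_all"):
        h = model_part(x)
        out = MyFn.apply(h)
        # Keep the custom Function's output instead of recomputing it
        # when the backward pass replays this region.
        tag_with_policy(out, CheckpointPolicy.MUST_SAVE)
        return out
```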
@soulitzer I'm using the ac-experimental repo. Originally, in my code flow, the whole thing is recomputed. For the part of the code where I don't want AC, I changed it to
This doesn't work (I put a print statement in my torch.autograd.Function and it still hits it during the model's backward pass). When I tried
I have this error:
Just to double check, are you using torch.utils.checkpoint together with the ac-experimental APIs? (They are not compatible with one another.) In order to use e.g.,
Do you have a longer stack trace or repro instructions?
Just replaced torch.utils.checkpoint with apply_ac_policy("recompute_all"). Now I have this error:
Thanks for the stack trace. I think what is happening is that the current code only works on versions of PyTorch >= 2.7 (due to pytree behavior changing). The relevant code:

```python
import torch
from typing import NamedTuple
from torch.utils._pytree import tree_map_only

torch.__version__  # which PyTorch version this repro runs on


class NodeOutput(NamedTuple):
    node: 'Node'
    idx: int

# Remove the `_asdict` method to make NodeOutput an opaque pytree leaf
# (pytree's namedtuple check looks for `_fields`, `_make`, `_asdict`).
# Is there a better way to do this?
del NodeOutput._asdict


class Node():
    pass


convert = lambda node_output: "hello"

no = NodeOutput(node=Node(), idx=1)
nos = (no,)
out = tree_map_only(NodeOutput, convert, nos)
print(out)
```

On versions of PyTorch older than 2.7:
On newer versions of PyTorch:
Should be updated on the repo now: soulitzer/ac-experimental@51cfc67.
Thanks for the quick turnaround. Now I have this new error:
Sorry for the rough edges, this is relatively new. I made another update just now: soulitzer/ac-experimental@4f5cfa2. Let me know if that fixes things.
Thanks! That error goes away. Now another one comes up:
Hmm, I'm not too sure what is happening here, unfortunately. If you have some instructions to repro this (and/or more information about your setup), that would be very helpful.
I made another update btw @lvhoaa; it's possible that this issue is addressed by it. Let me know if you get the chance to try it out again.
Sorry for the late reply. It is fixed. Thanks! Aside from this, I think it would be helpful to explain the behavior of this new API step by step: what will be stored, what will be recomputed, and using which inputs. For example:
In this case, will … ? Similar explanations are needed for … . Documenting such clear step-by-step behavior will let users track down how it works, debug speed and memory, and also make things clearer on the development side. Best!
Glad it works now! Thanks for the feedback, I will definitely need to take some time to add more documentation.
🚀 The feature, motivation and pitch
Selective Activation Checkpointing on a custom autograd.Function?
It seems that a custom autograd.Function is not seen as an op. How can SAC be performed on a custom autograd.Function?
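If I understand the issue correctly, the friction is that a selective-checkpointing policy function is invoked per dispatched op, so it sees the aten ops executed inside the Function's forward but nothing that corresponds to the Function as a whole. A minimal sketch of that situation, with illustrative (assumed) names:

```python
import torch
from torch.utils.checkpoint import CheckpointPolicy


class ScaledMatmul(torch.autograd.Function):
    """Hypothetical custom Function; the names here are illustrative."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return 2.0 * torch.mm(a, b)

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        return 2.0 * grad_out @ b.t(), 2.0 * a.t() @ grad_out


def policy_fn(ctx, op, *args, **kwargs):
    # This sees dispatched ops such as torch.ops.aten.mm.default from inside
    # ScaledMatmul.forward, but there is no op representing ScaledMatmul
    # itself, which is what this issue asks about.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE
```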
Alternatives
No response
Additional context
No response
cc @soulitzer