[scan] Autograd with partial gradient support #146285
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146285
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Cancelled Job: As of commit ebe8c22 with merge base 4273e5d, the following job was cancelled. Please retry.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
We're heading in the right direction. A big concern is readability; we should figure out a way to significantly reduce the complexity around all kinds of masking.
torch/_higher_order_ops/scan.py
Outdated
combine_fn,
False,
(
    *fw_init,
Do fw_init, fw_xs, and fw_additional_inputs always have requires_grad = True? Can we add a check before this?
No, not necessarily. All of fw_init, fw_xs, and fw_additional_inputs are required even if they don't require a gradient, because the joint_graph wouldn't work otherwise. In the revised version we return torch.zeros_like() for the elements that don't require gradients (see the sketch below).
Moreover, is there anything from #142518 to consider here, or is there anything blocking CUDA graph support? I hope I replaced all the Nones with torch.zeros_like(), but I may have overlooked something.
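To make the zeros_like point concrete, here is a minimal, hypothetical sketch (the helper `_fill_none_grads` is illustrative only and not the actual scan.py implementation):

```python
import torch

def _fill_none_grads(grads, inputs):
    # Illustrative helper: the joint graph must return one gradient per
    # input, so gradients of inputs that do not require grad are replaced
    # with zero tensors instead of None, keeping the graph structure fixed.
    return [
        g if g is not None else torch.zeros_like(inp)
        for g, inp in zip(grads, inputs)
    ]

# Example: only the first input requires a gradient.
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=False)
out = (a * 2 + b).sum()
(grad_a,) = torch.autograd.grad(out, [a])
# Pretend the joint graph produced (grad_a, None) for (a, b):
full_grads = _fill_none_grads([grad_a, None], [a, b])
print([g.shape for g in full_grads])  # both entries are tensors now
```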
torch/_higher_order_ops/scan.py
Outdated
carried_g_additional_input = args[:num_additional_inputs]

g_c, g_xs = _extract_carry_and_out(
    joint_graph(*args[num_additional_inputs:]), num_init
IIUC, the overall plan is that 1. we trace the fw_bw of combine_fn and get fw_bw_graph, and 2. we trace fw_bw_graph + gradient accumulation logic? Can you write this down somewhere upfront and justify why it is necessary?
A big concern about this plan is that it's really hard to understand what's going on. I'm afraid no one is going to be able to maintain this function in a few months. We should think about ways to improve this.
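For readers of this thread, a hedged sketch of what "joint graph + gradient accumulation" means in a scan: the backward of combine_fn runs once per step in reverse, chaining the carry gradient and accumulating gradients of loop-invariant additional inputs. Names and the manual backward are illustrative, not the actual scan.py code.

```python
import torch

def combine_fn(carry, x, weight):
    return carry * weight + x, carry  # (next_carry, output)

def joint_step(carry, x, weight, g_next_carry, g_out):
    # Manual backward of one combine_fn step (for illustration only).
    g_carry = g_next_carry * weight + g_out
    g_x = g_next_carry.clone()
    g_weight = g_next_carry * carry
    return g_carry, g_x, g_weight

weight = torch.randn(2)
carries = [torch.randn(2) for _ in range(3)]   # carries entering each step
xs = [torch.randn(2) for _ in range(3)]
g_carry = torch.ones(2)                        # gradient w.r.t. the last carry
g_outs = [torch.zeros(2) for _ in range(3)]    # gradients w.r.t. the outputs

g_xs, g_weight_acc = [], torch.zeros_like(weight)
for t in reversed(range(3)):
    g_carry, g_x, g_w = joint_step(carries[t], xs[t], weight, g_carry, g_outs[t])
    g_xs.append(g_x)
    g_weight_acc += g_w   # gradient accumulation for the additional input

g_xs = list(reversed(g_xs))  # flip back into forward iteration order
```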
I tried to simplify this in a new version. The new flow is as follows (a sketch illustrating it follows below):
- Create the forward graph and the joint graph of the combine_fn
- Retrace the wrapper of the forward graph so that it returns the carries from all iterations, not just from the last iteration
- Obtain the gradients from the joint_graph and compute the gradient masks
- Retrace the wrapper for the joint_graph using the masks
Is this clearer? I have also added some more comments.
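A minimal sketch of steps 2 and 3 of the flow above, under the assumption of a simple two-output combine_fn; the function and variable names are illustrative and do not mirror the actual scan.py implementation:

```python
import torch

def combine_fn(carry, x):
    return carry + x, carry * x  # (next_carry, output)

def forward_all_carries(init, xs):
    # Step 2: wrapper around the forward graph that returns the carry from
    # every iteration (needed later by the joint graph), not just the last one.
    carry, carries, outs = init, [], []
    for x in xs.unbind(0):
        carry, out = combine_fn(carry, x)
        carries.append(carry)
        outs.append(out)
    return carry, torch.stack(carries), torch.stack(outs)

init = torch.randn(2, requires_grad=True)
xs = torch.randn(4, 2)  # does not require grad

last_carry, all_carries, ys = forward_all_carries(init, xs)

# Step 3: gradient masks record which inputs actually require gradients, so
# the retraced joint-graph wrapper (step 4) can zero-fill the rest.
grad_mask = [t.requires_grad for t in (init, xs)]
print(grad_mask)  # [True, False]
```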
torch/_higher_order_ops/scan.py
Outdated
outs = [
    torch.zeros(
        [num_elems] + list(e.size()),
        dtype=e.dtype,
        device=e.device,
    )
    for i, e in enumerate(dummy_out)
]
idxs = [
    torch.ones_like(e, dtype=torch.int64).unsqueeze(0)
    for i, e in enumerate(dummy_out)
]
Is this change necessary for this PR?
I'd think so. This is for storing the temporary outputs of scan; a sketch of how such buffers could be filled follows below.
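A hypothetical illustration of the buffer usage discussed here (not the actual lowering): one buffer of shape [num_elems, *out.shape] is preallocated per output, and each iteration writes its result into its slot.

```python
import torch

num_elems = 4
dummy_out = [torch.randn(2, 3)]  # example per-iteration output shapes

# Preallocate one buffer per output, as in the snippet above.
outs = [
    torch.zeros([num_elems] + list(e.size()), dtype=e.dtype, device=e.device)
    for e in dummy_out
]

for i in range(num_elems):
    step_result = [torch.full((2, 3), float(i))]  # stand-in per-step output
    for buf, res in zip(outs, step_result):
        buf[i].copy_(res)

print(outs[0][0, 0, 0], outs[0][3, 0, 0])  # tensor(0.) tensor(3.)
```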
@ydwu4 I tried to incorporate all your requests and I think the PR is ready for another round.
Mainly have some concerns about the organization of the documentation.
torch/_higher_order_ops/scan.py
Outdated
# The flipping back along the scan dimension is required to get the gradients in the right order for ``xs``
g_xs = [torch.flip(elem, [0]) for elem in g_xs]

# The gradients for additional inputs that are not tensors are replaced with None.
Mention the partial grad handling notes.
I referred to that note.
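For context on the partial-grad handling being discussed, a minimal hypothetical sketch (the helper name is illustrative, not the actual scan.py code): non-tensor additional inputs get None as their gradient, and tensors that do not require grad get zeros.

```python
import torch

def mask_additional_input_grads(grads, additional_inputs):
    # Illustrative partial-gradient masking for additional inputs.
    masked = []
    for g, inp in zip(grads, additional_inputs):
        if not isinstance(inp, torch.Tensor):
            masked.append(None)            # non-tensor closure value
        elif not inp.requires_grad:
            masked.append(torch.zeros_like(inp))
        else:
            masked.append(g)
    return masked

weight = torch.randn(3, requires_grad=True)
bias = torch.randn(3, requires_grad=False)
scale = 2  # non-tensor value captured by combine_fn
grads = [torch.ones(3), torch.ones(3), None]
print(mask_additional_input_grads(grads, [weight, bias, scale]))
```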
test/functorch/test_control_flow.py
Outdated
scan_fct = compile_mode_helper(scan, compile_mode)

x = torch.randn(3, 1, 2)
@parametrize("reverse", [False, True])
No need to parametrize device here, I feel.
test/functorch/test_control_flow.py
Outdated
def test_scan_init_scanned_0(self, compile_mode):
    scan_fct = compile_mode_helper(scan, compile_mode)

@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
No need to parametrize device here, I feel.
):
    dim = 1
    x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
    h1 = torch.randn(3, 7, device=device, requires_grad=autograd)
"init_carries_unequal_grad" meaning init and carry have different require_grad? or something else?
Looks good! Left some comments for more clarity of the doc. Also need to fix the test failure.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed: trunk / linux-focal-rocm-py3.10 / test (distributed, 1, 1, linux.rocm.gpu.4). Details for the Dev Infra team: raised by workflow job.
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 1 check: trunk / linux-focal-rocm-py3.10 / test (distributed, 1, 1, linux.rocm.gpu.4). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR introduces the Autograd feature for scan with partial gradient support. It is a combination of the already opened PRs: #135631 and bohnstingl#4
Pull Request resolved: #146285
Approved by: https://github.com/ydwu4
Co-authored-by: Yidi Wu <yidi@meta.com>
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @ydwu4