[scan] Autograd with partial gradient support by bohnstingl · Pull Request #146285 · pytorch/pytorch · GitHub

[scan] Autograd with partial gradient support #146285


Closed
wants to merge 54 commits

Conversation

bohnstingl
Collaborator
@bohnstingl bohnstingl commented Feb 3, 2025

This PR introduces the Autograd feature for scan with partial gradient support. It is a combination of the already opened PRs: #135631 and bohnstingl#4
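
For context, a minimal eager-mode sketch of the semantics this PR differentiates through. The helper name, shapes, and combine function below are illustrative only and are not the HOP's actual API:

```python
import torch

def scan_reference(combine_fn, init, xs):
    # Eager reference for scan semantics: the carry is threaded through the
    # leading dimension of xs, and autograd flows through ordinary tensor ops.
    carry, ys = init, []
    for t in range(xs.shape[0]):
        carry, y = combine_fn(carry, xs[t])
        ys.append(y)
    return carry, torch.stack(ys)

def combine_fn(carry, x):
    next_carry = torch.tanh(carry + x)
    return next_carry, 2 * next_carry

# Partial gradient support: only some inputs require a gradient.
init = torch.randn(4, requires_grad=True)
xs = torch.randn(10, 4)  # does not require grad

carry, ys = scan_reference(combine_fn, init, xs)
(carry.sum() + ys.sum()).backward()
print(init.grad.shape)  # torch.Size([4]); xs gets no gradient
```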

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @ydwu4

@bohnstingl bohnstingl requested a review from zou3519 as a code owner February 3, 2025 00:36
pytorch-bot bot commented Feb 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146285

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job

As of commit ebe8c22 with merge base 4273e5d:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@bohnstingl
Collaborator Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Feb 3, 2025
@soulitzer soulitzer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 5, 2025
@ydwu4 ydwu4 self-requested a review February 7, 2025 00:51
Contributor
@ydwu4 ydwu4 left a comment


We're headed in the right direction. A big concern is readability; we should figure out a way to significantly reduce the complexity around all kinds of masking.

combine_fn,
False,
(
*fw_init,
Contributor


Do fw_init, fw_xs, fw_additional_inputs always have requires_grad = True? Can we add a check for that?

Collaborator Author


No, not necessarily. All of fw_init, fw_xs and fw_additional_inputs are required even if they don't require a gradient, because the joint_graph wouldn't work otherwise. In the revised version we return torch.zeros_like() for the elements that don't require gradients.

Moreover, is there anything from #142518 to consider here, or is there anything blocking CUDA graph support? I hope I replaced all the Nones with torch.zeros_like(), but I may have overlooked something.
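
As an illustration of the replacement described above, a small sketch of materializing `None` gradients as zeros. The helper name is made up for this example; `torch.autograd.grad(..., allow_unused=True)` returns `None` for inputs that do not participate in the output:

```python
import torch

def materialize_grads(grads, inputs):
    # Replace None gradients (unused inputs) with zero tensors so that
    # downstream graphs see tensors everywhere instead of None.
    return [
        g if g is not None else torch.zeros_like(inp)
        for g, inp in zip(grads, inputs)
    ]

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)  # requires grad but is unused in `out`
out = (a * 2).sum()
grads = torch.autograd.grad(out, [a, b], allow_unused=True)  # (tensor, None)
grads = materialize_grads(grads, [a, b])  # None -> torch.zeros_like(b)
```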

carried_g_additional_input = args[:num_additional_inputs]

g_c, g_xs = _extract_carry_and_out(
joint_graph(*args[num_additional_inputs:]), num_init
Contributor


IIUC, the overall plan is: 1. we trace the fw_bw of combine_fn and get fw_bw_graph; 2. we trace fw_bw_graph plus the gradient accumulation logic? Can we write this down somewhere upfront and justify why it is necessary?

A big concern about this plan is that it's really hard to understand what's going on. I'm afraid no one will be able to maintain this function in a few months. We should think about ways to improve this.

Collaborator Author


I tried to simplify this in a new version. The new flow is as follows:

  1. Create the forward and the joint graph of the combine_fn.
  2. Retrace the wrapper of the forward graph that returns carries from all iterations, not just from the last iteration.
  3. Obtain the gradients from the joint_graph and compute the gradient masks.
  4. Retrace the wrapper for the joint_graph using the masks.

Is this clearer? I have also added some more comments.
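
As a rough eager-mode analogue of that flow (not the traced implementation; the function name and mask handling are illustrative only):

```python
import torch

def scan_fw_bw_sketch(combine_fn, init, xs, grad_last_carry, grad_ys):
    # Steps 1/2: forward pass that keeps the carries from *all* iterations,
    # because the backward pass needs the carry entering each step.
    carries = [init]
    for t in range(xs.shape[0]):
        c, _y = combine_fn(carries[-1], xs[t])
        carries.append(c)

    # Step 3: gradient mask, i.e. which inputs actually need gradients.
    xs_needs_grad = xs.requires_grad

    # Step 4: backward scan from the last iteration to the first, threading
    # the carry gradient and collecting one xs gradient per iteration.
    g_carry, g_xs_rev = grad_last_carry, []
    for t in reversed(range(xs.shape[0])):
        carry_in = carries[t].detach().requires_grad_(True)
        x_in = xs[t].detach().requires_grad_(True)
        c, y = combine_fn(carry_in, x_in)
        g_c, g_x = torch.autograd.grad(
            (c, y), (carry_in, x_in),
            grad_outputs=(g_carry, grad_ys[t]),
            allow_unused=True,
        )
        g_carry = g_c if g_c is not None else torch.zeros_like(carry_in)
        g_x = g_x if g_x is not None else torch.zeros_like(x_in)
        # Masked inputs get zeros instead of None, as discussed above.
        g_xs_rev.append(g_x if xs_needs_grad else torch.zeros_like(g_x))

    # Gradients were collected last-to-first; flip back into xs order.
    return g_carry, torch.stack(g_xs_rev[::-1])
```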

Comment on lines 471 to 482
outs = [
torch.zeros(
[num_elems] + list(e.size()),
dtype=e.dtype,
device=e.device,
)
for i, e in enumerate(dummy_out)
]
idxs = [
torch.ones_like(e, dtype=torch.int64).unsqueeze(0)
for i, e in enumerate(dummy_out)
]
Contributor


Is this change necessary for this PR?

Collaborator Author


I'd think so. This is for storing the temporary outputs of scan.
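
For illustration, one plausible way such preallocated buffers get filled during the scan loop (the fill with `copy_` is a stand-in; in the traced graph the companion `idxs` tensors presumably drive an index/scatter-style write):

```python
import torch

num_elems = 5                      # number of scan iterations
dummy_out = [torch.randn(2, 3)]    # per-iteration output prototypes

# Preallocate one stacked buffer per output, as in the snippet above.
outs = [
    torch.zeros([num_elems] + list(e.size()), dtype=e.dtype, device=e.device)
    for e in dummy_out
]

# Illustrative fill: write iteration i's result into slot i of each buffer.
for i in range(num_elems):
    step_out = [torch.randn(2, 3)]       # stand-in for one scan step's output
    for buf, val in zip(outs, step_out):
        buf[i].copy_(val)
```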

@zou3519 zou3519 removed their request for review February 12, 2025 15:32
@bohnstingl bohnstingl requested a review from ydwu4 February 19, 2025 07:57
@bohnstingl
Collaborator Author

@ydwu4 I tried to incorporate all your requests and I think the PR would be ready for another round.

@bohnstingl bohnstingl requested a review from ydwu4 April 3, 2025 15:20
Contributor
@ydwu4 ydwu4 left a comment


I mainly have some concerns about the organization of the documentation.

# The flipping back along the scan dimension is required to get the gradients in the right order for ``xs``
g_xs = [torch.flip(elem, [0]) for elem in g_xs]

# The gradients for additional inputs that are not tensors are replaced with None.
Contributor


Mention the partial grad handling notes.

Collaborator Author


I referred to that note.
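
A small illustration of why that flip is needed, assuming the backward scan visits iterations last-to-first (values are synthetic):

```python
import torch

T = 4
# The backward scan produces one xs-gradient per iteration, visiting
# t = T-1 .. 0, so the stacked result is in reverse time order.
g_xs_rev = torch.stack([torch.full((2,), float(t)) for t in reversed(range(T))])

# Flip along the scan dimension so gradient t lines up with xs[t].
g_xs = torch.flip(g_xs_rev, [0])
assert torch.equal(g_xs[0], torch.full((2,), 0.0))
```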

scan_fct = compile_mode_helper(scan, compile_mode)

x = torch.randn(3, 1, 2)
@parametrize("reverse", [False, True])
Contributor


No need to parametrize device, I feel.

def test_scan_init_scanned_0(self, compile_mode):
scan_fct = compile_mode_helper(scan, compile_mode)

@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
Contributor


No need to parametrize device, I feel.

):
dim = 1
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
h1 = torch.randn(3, 7, device=device, requires_grad=autograd)
Contributor


"init_carries_unequal_grad" meaning init and carry have different require_grad? or something else?

@bohnstingl bohnstingl requested a review from ydwu4 April 4, 2025 14:14
Contributor
@ydwu4 ydwu4 left a comment


Looks good! Left some comments for more clarity of the doc. Also need to fix the test failure.

@bohnstingl bohnstingl requested a review from ydwu4 April 11, 2025 07:43
@ydwu4
Contributor
ydwu4 commented Apr 11, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 11, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-focal-rocm-py3.10 / test (distributed, 1, 1, linux.rocm.gpu.4)

Details for Dev Infra team: raised by workflow job.

@ydwu4
Contributor
ydwu4 commented Apr 14, 2025

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: trunk / linux-focal-rocm-py3.10 / test (distributed, 1, 1, linux.rocm.gpu.4)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

timocafe pushed a commit to timocafe/pytorch that referenced this pull request Apr 16, 2025
This PR introduces the Autograd feature for scan with partial gradient support. It is a combination of the already opened PRs: pytorch#135631 and bohnstingl#4

Pull Request resolved: pytorch#146285
Approved by: https://github.com/ydwu4

Co-authored-by: Yidi Wu <yidi@meta.com>
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
This PR introduces the Autograd feature for scan with partial gradient support. It is a combination of the already opened PRs: pytorch#135631 and bohnstingl#4

Pull Request resolved: pytorch#146285
Approved by: https://github.com/ydwu4

Co-authored-by: Yidi Wu <yidi@meta.com>
Labels

ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · module: dynamo · module: inductor · open source · topic: not user facing (topic category) · triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
5 participants