8000 [cudagraphs] Fix issue in collecting static_input_idxs by anijain2305 · Pull Request #152287 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[cudagraphs] Fix issue in collecting static_input_idxs #152287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from

Conversation

Copy link
pytorch-bot bot commented Apr 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152287

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8ccbcc7 with merge base 8e2e06b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

anijain2305 added a commit that referenced this pull request Apr 28, 2025
@anijain2305 anijain2305 added ciflow/trunk Trigger trunk jobs on your pull request topic: not user facing topic category labels Apr 28, 2025
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Apr 28, 2025
Copy link
Contributor
@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking into this

related to #152275

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Apr 28, 2025
Copy link
Contributor
@bdhirsh bdhirsh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sgtm

@@ -2135,9 +2135,9 @@ def inner_compile(
extern_node_serializer: Optional[Callable[[list[Any]], Any]] = None,
):
if dynamic:
self.assertEqual(static_input_idxs, [0, 1, 2, 3, 4])
self.assertEqual(static_input_idxs, [2, 3, 4])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test looks strictly more correct now than previously. For sanity, this is the signature of the AOT graph in this test:

def forward(self, arg0_1: "Sym(s25)", arg1_1: "f32[s25][1]cpu", arg2_1: "f32[s25][1]cpu", arg3_1: "f32[s25][1]cpu", arg4_1: "Sym(s25)", arg5_1: "f32[s25][1]cpu"):

Where indices [2,3] correspond to the two static tensor inputs that mapped to the static TwoTensor subclass.

One thing that is wrong in this test though is that:

(1) in the dynamic shapes variant of this test, we have extra SymInt graph args that correspond to the symbolic sizes of the subclass

(2) we are marking those inputs as static indices as well, which is happening here: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/subclass_utils.py#L308

This seems wrong. It might turn out not cause too many problems, if inductor has logic to properly filter out SymInts from the "static input indices" list later (given that integers have no memory address and get burned into cudagraphs anyway). But we should probably fix it either way. cc @mlazos

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes sense, I can take a look at the issue.

@@ -1054,7 +1055,11 @@ def _try_get_metadata_from_dynamo(
static_inputs_log.debug(
"Adding static input pos %s for source %s", pos, source_name
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we update this log call as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point

related to #152275

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Apr 28, 2025
related to #152275

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Apr 28, 2025
@anijain2305 anijain2305 requested review from eellison and mlazos April 28, 2025 20:58
@anijain2305
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
8000
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Copy link
Collaborator

@anijain2305 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Apr 29, 2025
…)"

This reverts commit 75a5646.

Reverted #152287 on behalf of https://github.com/wdvr due to causing ao failures - discussed with author ([comment](#152287 (comment)))
@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Apr 29, 2025
@gante
Copy link
gante commented Apr 29, 2025

@anijain2305 thank you for the quick bugfix!

I've applied the changes in torch/_functorch/aot_autograd.py and torch/_inductor/compile_fx.py over the base torch 2.7, and I can confirm it solves the issue ✅

related to #152275

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Apr 29, 2025
related to #152275

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Apr 29, 2025
# add on non param inputs
preserved_arg_indices.extend(range(len(flat_params), len(params)))
# is this necessary ?
fw_metadata.static_input_indices = static_indices_new
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are the new changes to update the static_input_indices list under freezing. As @eellison pointed out, it might be a good idea to stash this info on the graph placeholders directly in the future, so we don't need to worry about updating this list after ever calling convention change in inductor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks these should be easy to merge in 2.7

@anijain2305
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@anijain2305
Copy link
Contributor Author

@pytorchbot merge -f "stuck merge"

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@anijain2305
Copy link
Contributor Author

@pytorchbot cherry-pick --onto release/2.7 -c critical

pytorchbot pushed a commit that referenced this pull request May 4, 2025
related to #152275

Pull Request resolved: #152287
Approved by: https://github.com/bdhirsh, https://github.com/eellison

Co-authored-by: Brian Hirsh <hirsheybar@fb.com>
(cherry picked from commit 4a63cab)
@pytorchbot
Copy link
Collaborator

Cherry picking #152287

The cherry pick PR is at #152768 and it is recommended to link a critical cherry pick PR with an issue. The following tracker issues are updated:

Details for Dev Infra team Raised by workflow job

@malfet malfet added this to the 2.7.1 milestone May 6, 2025
atalman pushed a commit that referenced this pull request May 6, 2025
[cudagraphs] Fix issue in collecting static_input_idxs (#152287)

related to #152275

Pull Request resolved: #152287
Approved by: https://github.com/bdhirsh, https://github.com/eellison


(cherry picked from commit 4a63cab)

Co-authored-by: Brian Hirsh <hirsheybar@fb.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

10 participants
0