[HOP] Rework Autograd DispatchKey for scan and map by bohnstingl · Pull Request #153336 · pytorch/pytorch
Closed · wants to merge 9 commits

Conversation

bohnstingl (Collaborator) commented May 10, 2025

This PR introduces `py_autograd_impl` in place of `DispatchKey.Autograd` for some HOPs.
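
For intuition, here is a minimal, purely hypothetical sketch of the registration pattern this change is about: a toy stand-in for a higher-order operator with a generic per-dispatch-key hook (`py_impl`) and a dedicated autograd hook (`py_autograd_impl`). The class and its methods below are illustrative only and are not the actual torch HigherOrderOperator API.

class MiniHOP:
    """Toy stand-in for a higher-order operator with per-key Python impls (not the real torch API)."""

    def __init__(self, name):
        self.name = name
        self._impls = {}

    def py_impl(self, key):
        # Old style: register an implementation under an explicit dispatch key,
        # e.g. an "Autograd" key.
        def register(fn):
            self._impls[key] = fn
            return fn
        return register

    def py_autograd_impl(self, fn):
        # New style: a dedicated hook for the autograd rule.
        self._impls["Autograd"] = fn
        return fn

    def __call__(self, key, *args, **kwargs):
        return self._impls[key](*args, **kwargs)


map_like = MiniHOP("map_like")

@map_like.py_autograd_impl
def map_like_autograd(f, xs):
    # A real autograd rule would set up forward/backward graphs; the toy just applies f.
    return [f(x) for x in xs]

print(map_like("Autograd", lambda x: x * 2, [1, 2, 3]))  # [2, 4, 6]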

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

pytorch-bot bot commented May 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153336

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit e03e9ec with merge base 69a57d9:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bohnstingl (Collaborator, Author):

@pytorchbot label "topic: not user facing"

pytorch-bot bot added the "topic: not user facing" label · May 10, 2025
bohnstingl changed the title from "[HOP] Reworke Autograd DispatchKey for HOPs" to "[HOP] Rework Autograd DispatchKey for HOPs" · May 11, 2025
bohnstingl changed the title from "[HOP] Rework Autograd DispatchKey for HOPs" to "[HOP] Rework Autograd DispatchKey for scan and map" · May 12, 2025
Comment on lines 184 to 191
# outs = f(xs, *args)
# if pytree.tree_any(lambda elem: not isinstance(elem, torch.Tensor) if elem is not None else False, outs):
# raise RuntimeError(
# "Expect outputs of map only contains tensors or None. "
# f"Got types {[type(out) for out in pytree.tree_leaves(outs)]}."
# )
# return outs

bohnstingl (Collaborator, Author):

When using the new key `@map_impl.py_autograd_impl`, `map_dense` gets triggered directly and thus `create_fw_bw_graph` is not used. Therefore, the check of the outputs in `create_fw_bw_graph` is no longer performed; it can be done either here or below (see comment).


Comment on lines 269 to 277
outs_flatten = pytree.tree_leaves(pytrees)
if any(
# pytree.tree_map(lambda elem: not isinstance(elem, torch.Tensor), outs)
not isinstance(out, torch.Tensor) for out in outs_flatten if out is not None
):
raise RuntimeError(
"Expect outputs of map only contains tensors or None. "
f"Got types {[type(out) for out in outs_flatten]}."
)
bohnstingl (Collaborator, Author) · May 12, 2025:

or here.
Which option would be preferable?
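
For reference, the check quoted above reads, as a standalone helper, roughly as follows (the function name is illustrative; `pytree` here is `torch.utils._pytree`):

import torch
import torch.utils._pytree as pytree


def check_map_outputs(outs):
    # Flatten the (possibly nested) outputs and require every non-None leaf
    # to be a torch.Tensor, mirroring the check quoted above.
    outs_flatten = pytree.tree_leaves(outs)
    if any(
        not isinstance(out, torch.Tensor)
        for out in outs_flatten
        if out is not None
    ):
        raise RuntimeError(
            "Expect outputs of map only contains tensors or None. "
            f"Got types {[type(out) for out in outs_flatten]}."
        )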

bohnstingl marked this pull request as ready for review · May 12, 2025 18:37
bohnstingl requested a review from zou3519 as a code owner · May 12, 2025 18:37
colesbury added the "triaged" label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) · May 13, 2025
@@ -175,6 +175,14 @@ def run_flattened_map(f, flat_xs, flat_args):
def wrapped_fn(*flat_args, f, xs_tree_spec, args_tree_spec, num_xs):
xs = pytree.tree_unflatten(flat_args[:num_xs], xs_tree_spec)
args = pytree.tree_unflatten(flat_args[num_xs:], args_tree_spec)
# outs = f(xs, *args)
Contributor:

what's going on here?

bohnstingl (Collaborator, Author):

Well, this is what I was mentioning above. In map, we need to move the check of the output somewhere, and I can see two places where it would work. Currently I have put it on lines L264-L274, but lines L178-L184 would be an alternative. Which one do you prefer?

Contributor:

Can we do it in dynamo, in `MapHigherOrderVariable`, to keep it consistent with e.g. scan?

bohnstingl (Collaborator, Author):

I moved the check into dynamo and reused `_check_all_tensorvariable`. However, map allows None values as outputs, so I had to take this into account. WDYT?

Contributor:

looks good
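
For context, here is a rough sketch of the kind of dynamo-level output check this thread converges on. The helper name and its None handling are assumptions for illustration only, not the actual `_check_all_tensorvariable` or `MapHigherOrderVariable` code.

from torch._dynamo.variables import ConstantVariable, TensorVariable


def check_tensorvariable_or_none(op_name, flat_vars):
    # Accept TensorVariables and None constants; reject anything else,
    # since map (unlike scan) allows None leaves in its outputs.
    bad = [
        v
        for v in flat_vars
        if not isinstance(v, TensorVariable)
        and not (isinstance(v, ConstantVariable) and v.value is None)
    ]
    if bad:
        # A real dynamo check would raise dynamo's own unsupported/unimplemented
        # error; RuntimeError keeps this sketch self-contained.
        raise RuntimeError(
            f"Expected outputs of {op_name} to contain only tensors or None, "
            f"got {[type(v).__name__ for v in bad]}."
        )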

bohnstingl requested a review from ydwu4 · June 3, 2025 23:22
ydwu4 (Contributor) commented Jun 4, 2025

@pytorchbot merge

pytorch-bot bot added the "ciflow/trunk" label (trigger trunk jobs on your pull request) · Jun 4, 2025
pytorch-bot bot commented Jun 4, 2025

To add the ciflow label ciflow/trunk please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

pytorch-bot bot removed the "ciflow/trunk" label (trigger trunk jobs on your pull request) · Jun 4, 2025
pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

angelayi pushed a commit to angelayi/pytorch that referenced this pull request Jun 5, 2025
This PR introduces the `py_autograd_impl` instead of the `DispatchKey.Autograd` for some HOPs.

Pull Request resolved: pytorch#153336
Approved by: https://github.com/ydwu4
vijayabhaskar-ev pushed a commit to vijayabhaskar-ev/pytorch that referenced this pull request Jun 22, 2025
This PR introduces the `py_autograd_impl` instead of the `DispatchKey.Autograd` for some HOPs.

Pull Request resolved: pytorch#153336
Approved by: https://github.com/ydwu4
vijayabhaskar-ev pushed a commit to vijayabhaskar-ev/pytorch that referenced this pull request Jul 14, 2025
This PR introduces the `py_autograd_impl` instead of the `DispatchKey.Autograd` for some HOPs.

Pull Request resolved: pytorch#153336
Approved by: https://github.com/ydwu4
Labels
Merged · module: dynamo · open source · topic: not user facing · triaged
5 participants