[HOP] Rework Autograd DispatchKey for scan and map #153336
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153336

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit e03e9ec with merge base 69a57d9:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
torch/_higher_order_ops/map.py
Outdated
# outs = f(xs, *args)
# if pytree.tree_any(lambda elem: not isinstance(elem, torch.Tensor) if elem is not None else False, outs):
#     raise RuntimeError(
#         "Expect outputs of map only contains tensors or None. "
#         f"Got types {[type(out) for out in pytree.tree_leaves(outs)]}."
#     )
# return outs
When using the new key `@map_impl.py_autograd_impl`, `map_dense` gets triggered directly and thus `create_fw_bw_graph` is not used. Therefore, the output check in `create_fw_bw_graph` is also not performed, and the check can either be done here or below (see comment).
torch/_higher_order_ops/map.py
Outdated
outs_flatten = pytree.tree_leaves(pytrees)
if any(
    not isinstance(out, torch.Tensor) for out in outs_flatten if out is not None
):
    raise RuntimeError(
        "Expect outputs of map only contains tensors or None. "
        f"Got types {[type(out) for out in outs_flatten]}."
    )
or here.
Which option would be preferable?
torch/_higher_order_ops/map.py
Outdated
@@ -175,6 +175,14 @@ def run_flattened_map(f, flat_xs, flat_args):
def wrapped_fn(*flat_args, f, xs_tree_spec, args_tree_spec, num_xs):
    xs = pytree.tree_unflatten(flat_args[:num_xs], xs_tree_spec)
    args = pytree.tree_unflatten(flat_args[num_xs:], args_tree_spec)
    # outs = f(xs, *args)
what's going on here?
Can we do it in dynamo, in `MapHigherOrderVariable`, to keep it consistent with e.g. `scan`?
I moved the check into dynamo and reused `_check_all_tensorvariable`. However, map allows None values as outputs, so I had to take this into account. WDYT?
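For illustration, here is a minimal standalone sketch of the kind of tensor-or-None output check being discussed. The helper name `_check_outputs_tensor_or_none` is a hypothetical stand-in, and this is plain pytree-level code, not the actual dynamo-side reuse of `_check_all_tensorvariable`.

```python
import torch
import torch.utils._pytree as pytree


def _check_outputs_tensor_or_none(outs):
    # Hypothetical helper: flatten the (possibly nested) outputs and verify that
    # every leaf is either a torch.Tensor or None, mirroring the constraint that
    # map places on the outputs of its body function.
    leaves = pytree.tree_leaves(outs)
    if any(leaf is not None and not isinstance(leaf, torch.Tensor) for leaf in leaves):
        raise RuntimeError(
            "Expect outputs of map only contains tensors or None. "
            f"Got types {[type(leaf) for leaf in leaves]}."
        )


# Usage sketch: a body returning a tensor and a None passes; a plain int would raise.
_check_outputs_tensor_or_none((torch.ones(3), None))
```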
looks good
@pytorchbot merge
To add the ciflow label … This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR introduces the `py_autograd_impl` instead of the `DispatchKey.Autograd` for some HOPs.

Pull Request resolved: pytorch#153336
Approved by: https://github.com/ydwu4

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames
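For readers unfamiliar with the distinction the PR title refers to, the sketch below illustrates the general idea of handling autograd for a map-like operation at the Python level via `torch.autograd.Function`, rather than through a `DispatchKey.Autograd` registration. This is a self-contained toy under stated assumptions, not the PR's actual `py_autograd_impl` machinery; the names `toy_map_dense` and `ToyMapAutogradOp` are illustrative.

```python
import torch


def toy_map_dense(f, xs):
    # Dense (eager) implementation: apply f to each slice along dim 0 and stack.
    return torch.stack([f(x) for x in xs.unbind(0)])


class ToyMapAutogradOp(torch.autograd.Function):
    """Toy Python-level autograd wrapper around the dense implementation."""

    @staticmethod
    def forward(ctx, xs):
        ctx.save_for_backward(xs)
        # Forward simply calls the dense implementation with a fixed body (sin).
        return toy_map_dense(torch.sin, xs)

    @staticmethod
    def backward(ctx, grad_out):
        (xs,) = ctx.saved_tensors
        # Backward of the slice-wise sin: cos(xs) * grad_out.
        return torch.cos(xs) * grad_out


xs = torch.randn(4, 3, requires_grad=True)
out = ToyMapAutogradOp.apply(xs)
out.sum().backward()
print(torch.allclose(xs.grad, torch.cos(xs)))  # True
```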