-
Notifications
You must be signed in to change notification settings - Fork 24.2k
[dynamo][pytree][2/N] make CXX pytree traceable: tree_flatten
/ tree_unflatten
/ tree_structure
#137398
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137398
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 17d9823 with merge base 7435f57 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…ee_unflatten` / `tree_structure` ghstack-source-id: 7deed05 Pull Request resolved: pytorch#137398
…ee_unflatten` / `tree_structure` ghstack-source-id: ad8c86d Pull Request resolved: pytorch#137398
…ee_unflatten` / `tree_structure` ghstack-source-id: a47eb32 Pull Request resolved: pytorch#137398
…ee_unflatten` / `tree_structure` ghstack-source-id: d441e86 Pull Request resolved: pytorch#137398
…ee_unflatten` / `tree_structure` ghstack-source-id: 1025e14 Pull Request resolved: pytorch#137398
The existing C++ pytree and Python pytree share the same API and function signature. There are tests in our CI to ensure this. More, the pytree tests and dynamo tests are both tested against the Python pytree, C++ pytree, and the new public pytree API (a thin wrapper around the underlying pytree implementation). |
Hi @jansel, could you take a look at this PR and the next PR in the stack? Then the follow-ups can be reviewed and merged separately. |
…ap_` (#137399) Pull Request resolved: #137399 Approved by: https://github.com/jansel ghstack dependencies: #137398
…ee_unflatten` / `tree_structure` (pytorch#137398) Pull Request resolved: pytorch#137398 Approved by: https://github.com/jansel
…ap_` (pytorch#137399) Pull Request resolved: pytorch#137399 Approved by: https://github.com/jansel ghstack dependencies: pytorch#137398
…ee_unflatten` / `tree_structure` (pytorch#137398) Pull Request resolved: pytorch#137398 Approved by: https://github.com/jansel
…ap_` (pytorch#137399) Pull Request resolved: pytorch#137399 Approved by: https://github.com/jansel ghstack dependencies: pytorch#137398
…ee_unflatten` / `tree_structure` (pytorch#137398) Pull Request resolved: pytorch#137398 Approved by: https://github.com/jansel
…ap_` (pytorch#137399) Pull Request resolved: pytorch#137399 Approved by: https://github.com/jansel ghstack dependencies: pytorch#137398
…ee_unflatten` / `tree_structure` (pytorch#137398) Pull Request resolved: pytorch#137398 Approved by: https://github.com/jansel
…ap_` (pytorch#137399) Pull Request resolved: pytorch#137399 Approved by: https://github.com/jansel ghstack dependencies: pytorch#137398
…ee_unflatten` / `tree_structure` (pytorch#137398) Pull Request resolved: pytorch#137398 Approved by: https://github.com/jansel
…ap_` (pytorch#137399) Pull Request resolved: pytorch#137399 Approved by: https://github.com/jansel ghstack dependencies: pytorch#137398
…ee_unflatten` / `tree_structure` (pytorch#137398) Pull Request resolved: pytorch#137398 Approved by: https://github.com/jansel
…ap_` (pytorch#137399) Pull Request resolved: pytorch#137399 Approved by: https://github.com/jansel ghstack dependencies: pytorch#137398
…ee_unflatten` / `tree_structure` (pytorch#137398) Pull Request resolved: pytorch#137398 Approved by: https://github.com/jansel
…ap_` (pytorch#137399) Pull Request resolved: pytorch#137399 Approved by: https://github.com/jansel ghstack dependencies: pytorch#137398
…ee_unflatten` / `tree_structure` (pytorch#137398) Pull Request resolved: pytorch#137398 Approved by: https://github.com/jansel
…ap_` (pytorch#137399) Pull Request resolved: pytorch#137399 Approved by: https://github.com/jansel ghstack dependencies: pytorch#137398
…ee_unflatten` / `tree_structure` (pytorch#137398) Pull Request resolved: pytorch#137398 Approved by: https://github.com/jansel
…ap_` (pytorch#137399) Pull Request resolved: pytorch#137399 Approved by: https://github.com/jansel ghstack dependencies: pytorch#137398
@dataclass(frozen=True) | F438||
class PyTreeSpec: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@XuehaiPan, if I'm understanding this correctly, you're adding a polyfill for cxx_pytree.tree_flatten such that it will return an instance of this PyTreeSpec class. I'm not sure this works if we are trying to return a PyTreeSpec from a torch.compile'd function: does it create an instance of this "class PyTreeSpec" object, or does it create an instance of torch.utils._cxx_pytree.TreeSpec?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In a compiled function, it returns torch._dynamo.polyfills.pytree.PyTreeSpec
. This class provides the exactly same interfaces with torch.utils._cxx_pytree.TreeSpec
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah this is wrong, see the following:
import torch
import torch.utils._cxx_pytree as pytree
@torch.compile(backend="eager", fullgraph=True)
def f(x, y):
vals, spec = pytree.tree_flatten(x)
return vals, spec, y.sin()
y = torch.randn(3)
x = [1, [2, [3, 4]]]
vals, spec, _ = f(x, y)
this_doesnt_work = pytree.tree_unflatten(vals, spec)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. It may cause problems if we only compile part of the program.
There is an in-progress polyfill infra for C++ classes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Opened a PR to fix this:
Stack from ghstack (oldest at bottom):
torch.utils.pytree
by default #138056torch.func
#137884torch.utils.pytree
#137400dict
keys in insertion order in CXX pytree #130140tree_map
/tree_map_
#137399tree_flatten
/tree_unflatten
/tree_structure
#137398cc @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec