8000 [dynamo][pytree][2/N] make CXX pytree traceable: `tree_flatten` / `tree_unflatten` / `tree_structure` by XuehaiPan · Pull Request #137398 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[dynamo][pytree][2/N] make CXX pytree traceable: tree_flatten / tree_unflatten / tree_structure #137398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 64 commits into from
Closed
Changes from 1 commit
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
1b1d91a
Update
XuehaiPan Oct 5, 2024
2ee4eb2
Update
XuehaiPan Oct 5, 2024
a090e3c
Update
XuehaiPan Oct 5, 2024
5bbc747
Update
XuehaiPan Oct 5, 2024
ccf7344
Update
XuehaiPan Oct 5, 2024
cd1f888
Update
XuehaiPan Oct 5, 2024
6a5a8a5
Update
XuehaiPan Oct 5, 2024
c498d86
Update
XuehaiPan Oct 5, 2024
9e2f0fa
Update
XuehaiPan Oct 6, 2024
91651fe
Update
XuehaiPan Oct 6, 2024
82ed62c
Update
XuehaiPan Oct 6, 2024
ab204c7
Update
XuehaiPan Oct 6, 2024
d9a23dc
Update
XuehaiPan Oct 6, 2024
1c045eb
Update
XuehaiPan Oct 6, 2024
d7dad64
Update
XuehaiPan Oct 6, 2024
2065856
Update
XuehaiPan Oct 6, 2024
45d3400
Update
XuehaiPan Oct 6, 2024
0e41f97
Update
XuehaiPan Oct 6, 2024
5315c2f
Update
XuehaiPan Oct 7, 2024
42c519b
Update
XuehaiPan Oct 7, 2024
2b9f79e
Update
XuehaiPan Oct 13, 2024
118e901
Update
XuehaiPan Oct 13, 2024
3422498
Update
XuehaiPan Oct 16, 2024
732057f
Update
XuehaiPan Oct 16, 2024
5e9c01c
Update
XuehaiPan Oct 16, 2024
16602a5
Update
XuehaiPan Oct 16, 2024
3c2bddb
Update
XuehaiPan Oct 16, 2024
8164ebf
Update
XuehaiPan Oct 16, 2024
cac8c23
Update
XuehaiPan Oct 16, 2024
8ba2581
Update
XuehaiPan Oct 16, 2024
1b97c3f
Update
XuehaiPan Oct 16, 2024
16a35da
Update
XuehaiPan Oct 17, 2024
348dd58
Update
XuehaiPan Oct 17, 2024
c230fa7
Update
XuehaiPan Oct 25, 2024
fd704f8
Update
XuehaiPan Oct 26, 2024
cfce0bb
Update
XuehaiPan Oct 29, 2024
0e0f86e
Update
XuehaiPan Oct 29, 2024
b539cb9
Update
XuehaiPan Oct 30, 2024
7f6e449
Update
XuehaiPan Oct 30, 2024
a8a4315
Update
XuehaiPan Nov 2, 2024
13c5af3
Update
XuehaiPan Nov 5, 2024
f0cea0b
Update
XuehaiPan Nov 11, 2024
29ffbbf
Update
XuehaiPan Nov 17, 2024
9ff3312
Update
XuehaiPan Nov 20, 2024
cb9abcc
Update
XuehaiPan Nov 20, 2024
32200cf
Update
XuehaiPan Nov 20, 2024
31a7f76
Update
XuehaiPan Nov 20, 2024
4477993
Update
XuehaiPan Nov 20, 2024
7ad1a4a
Update
XuehaiPan Nov 20, 2024
5820dcd
Update
XuehaiPan Nov 20, 2024
9d3affb
Update
XuehaiPan Nov 21, 2024
2089924
Update
XuehaiPan Nov 21, 2024
c6ee2f9
Update
XuehaiPan Nov 21, 2024
531cbe4
Update
XuehaiPan Nov 21, 2024
28a1225
Update
XuehaiPan Nov 21, 2024
4757b69
Update
XuehaiPan Nov 22, 2024
473c032
Update
XuehaiPan Nov 22, 2024
2c42b81
Update
XuehaiPan Nov 26, 2024
75fbfed
Update
XuehaiPan Nov 26, 2024
87ab4da
Update
XuehaiPan Nov 27, 2024
bb53b36
Update
XuehaiPan Dec 2, 2024
2a6463c
Update
XuehaiPan Dec 2, 2024
49cd5ce
Update
XuehaiPan Dec 7, 2024
17d9823
Update
XuehaiPan Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update
[ghstack-poisoned]
  • Loading branch information
XuehaiPan committed Oct 5, 2024
commit a090e3c3e599c0bd9e5f519890ea9db5e2bea93e
27 changes: 18 additions & 9 deletions torch/_dynamo/polyfills/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,27 +79,35 @@ class PyTreeSpec:
_type: builtins.type | None
_metadata: Any
_entries: tuple[Any] | None
_unflatten_func: Callable
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None

num_nodes: int = field(init=False)
num_leaves: int = field(init=False)
num_children: int = field(init=False)
none_is_leaf: bool = field(init=False)
namespace: str = field(init=False)

def __post_init__(self):
def __post_init__(self) -> None:
if self._type is None:
assert len(self._children) == 0
assert self._metadata is None
assert self._entries is None
assert self._unflatten_func is None
object.__setattr__(self, "num_nodes", 1)
object.__setattr__(self, "num_leaves", 1)
object.__setattr__(self, "num_children", 0)
else:
assert callable(self._unflatten_func)
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
num_leaves = sum(spec.num_leaves for spec in self._children)
num_children = len(self._children)
object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children)

object.__setattr__(self, "none_is_leaf", True)
object.__setattr__(self, "namespace", "torch")

@property
def type(self) -> builtins.type | None:
return self._type
Expand Down Expand Up @@ -142,16 +150,17 @@ def unflatten(self, leaves: Iterable[Any]) -> PyTree:
subtrees.append(subspec.unflatten(leaves[start:end]))
start = end

assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees)

leafspec = PyTreeSpec([], None, None, None, lambda x: x)
leafspec = PyTreeSpec([], None, None, None, None)

@substitute_in_graph(cxx_pytree.tree_flatten, can_constant_fold_through=False)
@substitute_in_graph(cxx_pytree.tree_flatten, can_constant_fold_through=False) # type: ignore[arg-type]
def tree_flatten(
tree: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[Any], PyTreeSpec]:
def helper(node, leaves):
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
if node is None or (is_leaf is not None and is_leaf(node)):
leaves.append(node)
return leafspec
Expand All @@ -173,16 +182,16 @@ def helper(node, leaves):
namespace="torch",
)

treespecs = [helper(child, leaves) for child in children]
return PyTreeSpec(treespecs, node_type, metadata, entries, unflatten_func)
subspecs = [helper(child, leaves) for child in children]
return PyTreeSpec(subspecs, node_type, metadata, entries, unflatten_func) # typ 6AA5 e: ignore[arg-type]

leaves = []
leaves: list[Any] = []
treespec = helper(tree, leaves)
return leaves, treespec

__all__ += ["tree_flatten"]

@substitute_in_graph(cxx_pytree.tree_unflatten, can_constant_fold_through=False)
@substitute_in_graph(cxx_pytree.tree_unflatten, can_constant_fold_through=False) # type: ignore[arg-type]
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
if not isinstance(treespec, PyTreeSpec):
raise TypeError(
Expand Down
Loading
0