8000 [dynamo][pytree][2/N] make CXX pytree traceable: `tree_flatten` / `tr… · pytorch/pytorch@e1816ba · GitHub
[go: up one dir, main page]

Skip to content

Commit e1816ba

Browse files
committed
[dynamo][pytree][2/N] make CXX pytree traceable: tree_flatten / tree_unflatten
ghstack-source-id: 7e69b37 Pull Request resolved: #137398
1 parent 34c119e commit e1816ba

File tree

2 files changed

+174
-0
lines changed

2 files changed

+174
-0
lines changed

test/dynamo/test_misc.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9878,6 +9878,48 @@ def fn(x):
98789878

98799879
self.assertEqual(actual, expected)
98809880

9881+
def test_cxx_pytree_tree_flatten(self):
9882+
import torch.utils._cxx_pytree as pytree
9883+
9884+
def fn(x):
9885+
tree = {
9886+
"a": [x, x - 1],
9887+
"b": x + 2,
9888+
"c": (x, 3.0),
9889+
"d": {"e": [2 * x, torch.ones(1, 1)]},
9890+
}
9891+
leaves = pytree.tree_flatten(tree)[0]
9892+
return sum(leaves)
9893+
9894+
x = torch.randn(3, 2)
9895+
expected = fn(x)
9896+
fn_opt = torch.compile(fullgraph=True)(fn)
9897+
actual = fn_opt(x)
9898+
9899+
self.assertEqual(actual, expected)
9900+
9901+
def test_cxx_pytree_tree_unflatten(self):
9902+
import torch.utils._cxx_pytree as pytree
9903+
9904+
def fn(x, y):
9905+
tree = {
9906+
"a": [x, x - 1],
9907+
"b": x + 2,
9908+
"c": (x, 3.0),
9909+
"d": {"e": [2 * x, torch.ones(1, 1)]},
9910+
}
9911+
treespec = pytree.tree_flatten(tree)[1]
9912+
leaves = [x - 1, y, x * y, 3.0, y - 2, torch.zeros(2, 2), 2 * y]
9913+
return pytree.tree_unflatten(leaves, treespec)
9914+
9915+
x = torch.randn(3, 2)
9916+
y = torch.randn(3, 2)
9917+
expected = fn(x, y)
9918+
fn_opt = torch.compile(fullgraph=True)(fn)
9919+
actual = fn_opt(x, y)
9920+
9921+
self.assertEqual(actual, expected)
9922+
98819923
def test_shape_env_no_recording(self):
98829924
main = ShapeEnv(should_record_events=False)
98839925

torch/_dynamo/polyfills/pytree.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from __future__ import annotations
66

7+
from dataclasses import dataclass, field
78
from typing import Any, Callable, Iterable, Literal, TYPE_CHECKING
89

910
import torch.utils._pytree as python_pytree
@@ -12,6 +13,8 @@
1213

1314

1415
if TYPE_CHECKING:
16+
import builtins
17+
1518
from torch.utils._cxx_pytree import PyTree
1619

1720

@@ -69,3 +72,132 @@ def tree_leaves(
6972
return list(tree_iter(tree, is_leaf=is_leaf))
7073

7174
__all__ += ["tree_leaves"]
75+
76+
@dataclass(frozen=True)
77+
class PyTreeSpec:
78+
_children: list[PyTreeSpec]
79+
_type: builtins.type | None
80+
_metadata: Any
81+
_entries: tuple[Any] | None
82+
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
83+
84+
num_nodes: int = field(init=False)
85+
num_leaves: int = field(init=False)
86+
num_children: int = field(init=False)
87+
none_is_leaf: bool = field(init=False)
88+
namespace: str = field(init=False)
89+
90+
def __post_init__(self) -> None:
91+
if self._type is None:
92+
assert len(self._children) == 0
93+
assert self._metadata is None
94+
assert self._entries is None
95+
assert self._unflatten_func is None
96+
object.__setattr__(self, "num_nodes", 1)
97+
object.__setattr__(self, "num_leaves", 1)
98+
object.__setattr__(self, "num_children", 0)
99+
else:
100+
assert callable(self._unflatten_func)
101+
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
102+
num_leaves = sum(spec.num_leaves for spec in self._children)
103+
num_children = len(self._children)
104+
object.__setattr__(self, "num_nodes", num_nodes)
105+
object.__setattr__(self, "num_leaves", num_leaves)
106+
object.__setattr__(self, "num_children", num_children)
107+
108+
object.__setattr__(self, "none_is_leaf", True)
109+
object.__setattr__(self, "namespace", "torch")
110+
111+
@property
112+
def type(self) -> builtins.type | None:
113+
return self._type
114+
115+
def is_leaf(self) -> bool:
116+
return self.num_nodes == 1 and self.num_leaves == 1
117+
118+
def children(self) -> list[PyTreeSpec]:
119+
return self._children.copy()
120+
121+
def child(self, index: int) -> PyTreeSpec:
122+
return self._children[index]
123+
124+
def entries(self) -> list[Any]:
125+
if self._entries is None:
126+
return list(range(self.num_children))
127 F438 +
return list(self._entries)
128+
129+
def entry(self, index: int) -> Any:
130+
return self.entries()[index]
131+
132+
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
133+
if not isinstance(leaves, (list, tuple)):
134+
leaves = list(leaves)
135+
if len(leaves) != self.num_leaves:
136+
raise ValueError(
137+
f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
138+
f"but the spec refers to a pytree that holds {self.num_leaves} "
139+
f"items ({self}).",
140+
)
141+
if self.is_leaf():
142+
return leaves[0]
143+
144+
# Recursively unflatten the children
145+
start = 0
146+
end = 0
147+
subtrees = []
148+
for subspec in self._children:
149+
end += subspec.num_leaves
150+
subtrees.append(subspec.unflatten(leaves[start:end]))
151+
start = end
152+
153+
assert callable(self._unflatten_func)
154+
return self._unflatten_func(self._metadata, subtrees)
155+
156+
leafspec = PyTreeSpec([], None, None, None, None)
157+
158+
@substitute_in_graph(cxx_pytree.tree_flatten, can_constant_fold_through=False) # type: ignore[arg-type]
159+
def tree_flatten(
160+
tree: PyTree,
161+
is_leaf: Callable[[PyTree], bool] | None = None,
162+
) -> tuple[list[Any], PyTreeSpec]:
163+
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
164+
if node is None or (is_leaf is not None and is_leaf(node)):
165+
leaves.append(node)
166+
return leafspec
167+
168+
node_type = type(node)
169+
if optree.register_pytree_node.get(node_type, namespace="torch") is None: # type: ignore[attr-defined]
170+
leaves.append(node)
171+
return leafspec
172+
173+
(
174+
children,
175+
metadata,
176+
entries,
177+
unflatten_func,
178+
) = optree.tree_flatten_one_level(
179+
node,
180+
is_leaf=is_leaf,
181+
none_is_leaf=True,
182+
namespace="torch",
183+
)
184+
185+
subspecs = [helper(child, leaves) for child in children]
186+
return PyTreeSpec(subspecs, node_type, metadata, entries, unflatten_func) # type: ignore[arg-type]
187+
188+
leaves: list[Any] = []
189+
treespec = helper(tree, leaves)
190+
return leaves, treespec
191+
192+
__all__ += ["tree_flatten"]
193+
194+
@substitute_in_graph(cxx_pytree.tree_unflatten, can_constant_fold_through=False) # type: ignore[arg-type]
195+
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
196+
if not isinstance(treespec, PyTreeSpec):
197+
raise TypeError(
198+
f"tree_unflatten(values, spec): Expected `spec` to be instance of "
199+
f"TreeSpec but got item of type {type(treespec)}."
200+
)
201+
return treespec.unflatten(leaves)
202+
203+
__all__ += ["tree_unflatten"]

0 commit comments

Comments
 (0)
0