8000 [dynamo][pytree][2/N] make CXX pytree traceable: `tree_flatten` / `tr… · pytorch/pytorch@7edeb10 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7edeb10

Browse files
XuehaiPanpytorchmergebot
authored andcommitted
[dynamo][pytree][2/N] make CXX pytree traceable: tree_flatten / tree_unflatten / tree_structure (#137398)
Pull Request resolved: #137398 Approved by: https://github.com/jansel
1 parent c85323c commit 7edeb10

File tree

5 files changed

+242
-41
lines changed

5 files changed

+242
-41
lines changed

test/dynamo/test_misc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10102,6 +10102,8 @@ def fn(x):
1010210102

1010310103
def test_pytree_tree_flatten_unflatten(self):
1010410104
implemtations = [("python", python_pytree)]
10105+
if cxx_pytree is not None:
10106+
implemtations.append(("cxx", cxx_pytree))
1010510107

1010610108
for name, module in implemtations:
1010710109
with self.subTest(f"pytree implement: {name}"):
@@ -10138,7 +10140,7 @@ def fn(x, y):
1013810140
torch.ones(3, 2),
1013910141
1,
1014010142
]
10141-
new_tree = module.tree_unflatten(leaves, treespec)
10143+
new_tree = module.tree_unflatten(new_leaves, treespec)
1014210144
return leaves, new_tree
1014310145

1014410146
x = torch.randn(3, 2)

test/test_pytree.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ def __init__(self, x, y):
4343
self.y = y
4444

4545

46+
cxx_pytree.register_pytree_node(
47+
GlobalDummyType,
48+
lambda dummy: ([dummy.x, dummy.y], None),
49+
lambda xs, _: GlobalDummyType(*xs),
50+
serialized_type_name="GlobalDummyType",
51+
)
52+
53+
4654
class TestGenericPytree(TestCase):
4755
def test_aligned_public_apis(self):
4856
public_apis = py_pytree.__all__
@@ -1328,7 +1336,7 @@ def test_treespec_repr_dynamo(self):
13281336
_, spec = cxx_pytree.tree_flatten(pytree)
13291337
self.assertExpectedInline(
13301338
repr(spec),
1331-
"PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)",
1339+
"PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')",
13321340
)
13331341

13341342
@parametrize(
@@ -1383,12 +1391,6 @@ def test_pytree_serialize_namedtuple(self):
13831391
self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)
13841392

13851393
def test_pytree_custom_type_serialize(self):
1386-
cxx_pytree.register_pytree_node(
1387-
GlobalDummyType,
1388-
lambda dummy: ([dummy.x, dummy.y], None),
1389-
lambda xs, _: GlobalDummyType(*xs),
1390-
serialized_type_name="GlobalDummyType",
1391-
)
13921394
spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1))
13931395
serialized_spec = cxx_pytree.treespec_dumps(spec)
13941396
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)

torch/_dynamo/polyfills/pytree.py

Lines changed: 201 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,20 @@
44

55
from __future__ import annotations
66

7-
from typing import Any, Callable, Iterable, TYPE_CHECKING
7+
from dataclasses import dataclass, field
8+
from typing import Any, Callable, Iterable, Literal, TYPE_CHECKING
9+
from typing_extensions import TypeIs
810

911
import torch.utils._pytree as python_pytree
12+
from torch.utils._pytree import BUILTIN_TYPES
1013

1114
from ..decorators import substitute_in_graph
1215

1316

1417
if TYPE_CHECKING:
18+
import builtins
19+
from typing_extensions import Self
20+
1521
from torch.utils._cxx_pytree import PyTree
1622

1723

@@ -95,3 +101,197 @@ def tree_leaves(
95101
return list(tree_iter(tree, is_leaf=is_leaf))
96102

97103
__all__ += ["tree_leaves"]
104+
105+
class _Asterisk(str):
106+
def __new__(cls) -> Self:
107+
return super().__new__(cls, "*")
108+
109+
def __repr__(self) -> str:
110+
return "*" # no quotes
111+
112+
_asterisk = _Asterisk()
113+
del _Asterisk
114+
115+
@dataclass(frozen=True)
116+
class PyTreeSpec:
117+
"""Analog for :class:`optree.PyTreeSpec` in Python."""
118+
119+
_children: tuple[PyTreeSpec, ...]
120+
_type: builtins.type | None
121+
_metadata: Any
122+
_entries: tuple[Any, ...]
123+
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
124+
125+
num_nodes: int = field(init=False)
126+
num_leaves: int = field(init=False)
127+
num_children: int = field(init=False)
128+
none_is_leaf: Literal[True] = field(init=False)
129+
namespace: Literal["torch"] = field(init=False)
130+
131+
def __post_init__(self) -> None:
132+
if self._type is None:
133+
assert len(self._children) == 0
134+
assert self._metadata is None
135+
assert self._entries == ()
136+
assert self._unflatten_func is None
137+
num_nodes = 1
138+
num_leaves = 1
139+
num_children = 0
140+
else:
141+
assert callable(self._unflatten_func)
142+
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
143+
num_leaves = sum(spec.num_leaves for spec in self._children)
144+
num_children = len(self._children)
145+
146+
object.__setattr__(self, "num_nodes", num_nodes)
147+
object.__setattr__(self, "num_leaves", num_leaves)
148+
object.__setattr__(self, "num_children", num_children)
149+
object.__setattr__(self, "none_is_leaf", True)
150+
object.__setattr__(self, "namespace", "torch")
151+
152+
def __repr__(self) -> str:
153+
def helper(treespec: PyTreeSpec) -> str:
154+
if treespec.is_leaf():
155+
assert treespec.type is None
156+
return _asterisk
157+
158+
assert treespec.type is not None
159+
assert callable(treespec._unflatten_func)
160+
children_representations = [
161+
helper(subspec) for subspec in treespec._children
162+
]
163+
if (
164+
treespec.type in BUILTIN_TYPES
165+
or optree.is_namedtuple_class(treespec.type)
166+
or optree.is_structseq_class(treespec.type)
167+
):
168+
return treespec._unflatten_func(
169+
treespec._metadata,
170+
children_representations,
171+
)
172+
return (
173+
f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], "
174+
f"[{', '.join(children_representations)}])"
175+
)
176+
177+
return (
178+
f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
179+
)
180+
181+
def __len__(self) -> int:
182+
return self.num_leaves
183+
184+
@property
185+
def type(self) -> builtins.type | None:
186+
return self._type
187+
188+
def is_leaf(self) -> bool:
189+
return self.num_nodes == 1 and self.num_leaves == 1
190+
191+
def children(self) -> list[PyTreeSpec]:
192+
return list(self._children)
193+
194+
def child(self, index: int) -> PyTreeSpec:
195+
return self._children[index]
196+
197+
def entries(self) -> list[Any]:
198+
return list(self._entries)
199+
200+
def entry(self, index: int) -> Any:
201+
return self._entries[index]
202+
203+
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
204+
if not isinstance(leaves, (list, tuple)):
205+
leaves = list(leaves)
206+
if len(leaves) != self.num_leaves:
207+
raise ValueError(
208+
f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
209+
f"but the spec refers to a pytree that holds {self.num_leaves} "
210+
f"items ({self}).",
211+
)
212+
if self.is_leaf():
213+
return leaves[0]
214+
215+
# Recursively unflatten the children
216+
start = 0
217+
end = 0
218+
subtrees = []
219+
for subspec in self._children:
220+
end += subspec.num_leaves
221+
subtrees.append(subspec.unflatten(leaves[start:end]))
222+
start = end
223+
224+
assert callable(self._unflatten_func)
225+
return self._unflatten_func(self._metadata, subtrees)
226+
227+
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
228+
229+
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
230+
return isinstance(obj, PyTreeSpec)
231+
232+
@substitute_in_graph( # type: ignore[arg-type]
233+
cxx_pytree.tree_flatten,
234+
# We need to disable constant folding here because we want the function to reference the
235+
# PyTreeSpec class defined above, not the one in the C++ module.
236+
can_constant_fold_through=False,
237+
)
238+
def tree_flatten(
239+
tree: PyTree,
240+
is_leaf: Callable[[PyTree], bool] | None = None,
241+
) -> tuple[list[Any], PyTreeSpec]:
242+
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
243+
if tree_is_leaf(node, is_leaf=is_leaf):
244+
leaves.append(node)
245+
return _LEAF_SPEC
246+
247+
(
248+
children,
249+
metadata,
250+
entries,
251+
unflatten_func,
252+
) = optree.tree_flatten_one_level(
253+
node,
254+
is_leaf=is_leaf,
255+
none_is_leaf=True,
256+
namespace="torch",
257+
)
258+
259+
# Recursively flatten the children
260+
subspecs = tuple(helper(child, leaves) for child in children)
261+
return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type]
262+
263+
leaves: list[Any] = []
264+
treespec = helper(tree, leaves)
265+
return leaves, treespec
266+
267+
__all__ += ["tree_flatten"]
268+
269+
@substitute_in_graph( # type: ignore[arg-type]
270+
cxx_pytree.tree_structure,
271+
# We need to disable constant folding here because we want the function to reference the
272+
# PyTreeSpec class defined above, not the one in the C++ module.
273+
can_constant_fold_through=False,
274+
)
275+
def tree_structure(
276+
tree: PyTree,
277+
is_leaf: Callable[[PyTree], bool] | None = None,
278+
) -> PyTreeSpec:
279+
return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value]
280+
281+
__all__ += ["tree_structure"]
282+
283+
@substitute_in_graph( # type: ignore[arg-type]
284+
cxx_pytree.tree_unflatten,
285+
# We need to disable constant folding here because we want the function to reference the
286+
# PyTreeSpec class defined above, not the one in the C++ module.
287+
can_constant_fold_through=False,
288+
)
289+
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
290+
if not _is_pytreespec_instance(treespec):
291+
raise TypeError(
292+
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
293+
f"PyTreeSpec but got item of type {type(treespec)}."
294+
)
295+
return treespec.unflatten(leaves)
296+
297+
__all__ += ["tree_unflatten"]

torch/utils/_cxx_pytree.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
TypeVar,
2828
Union,
2929
)
30-
from typing_extensions import deprecated
30+
from typing_extensions import deprecated, TypeIs
3131

3232
import optree
3333
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
@@ -240,6 +240,10 @@ def _private_register_pytree_node(
240240
)
241241

242242

243+
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
244+
return isinstance(obj, TreeSpec)
245+
246+
243247
def tree_is_leaf(
244248
tree: PyTree,
245249
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@@ -345,10 +349,10 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
345349
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
346350
``treespec``.
347351
"""
348-
if not isinstance(treespec, TreeSpec):
352+
if not _is_pytreespec_instance(treespec):
349353
raise TypeError(
350-
f"tree_unflatten(values, spec): Expected `spec` to be instance of "
351-
f"TreeSpec but got item of type {type(treespec)}."
354+
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
355+
f"PyTreeSpec but got item of type {type(treespec)}."
352356
)
353357
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
354358

@@ -891,7 +895,7 @@ def _broadcast_to_and_flatten(
891895
treespec: TreeSpec,
892896
is_leaf: Optional[Callable[[PyTree], bool]] = None,
893897
) -> Optional[List[Any]]:
894-
assert isinstance(treespec, TreeSpec)
898+
assert _is_pytreespec_instance(treespec)
895899
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
896900
try:
897901
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
@@ -901,10 +905,10 @@ def _broadcast_to_and_flatten(
901905

902906
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
903907
"""Serialize a treespec to a JSON string."""
904-
if not isinstance(treespec, TreeSpec):
908+
if not _is_pytreespec_instance(treespec):
905909
raise TypeError(
906-
f"treespec_dumps(spec): Expected `spec` to be instance of "
907-
f"TreeSpec but got item of type {type(treespec)}."
910+
f"treespec_dumps(treespec): Expected `treespec` to be instance of "
911+
f"PyTreeSpec but got item of type {type(treespec)}."
908912
)
909913

910914
dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
@@ -938,7 +942,7 @@ def treespec_pprint(treespec: TreeSpec) -> str:
938942

939943
class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc]
940944
def __instancecheck__(self, instance: object) -> bool:
941-
return isinstance(instance, TreeSpec) and instance.is_leaf()
945+
return _is_pytreespec_instance(instance) and instance.is_leaf()
942946

943947

944948
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):

torch/utils/_pytree.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -869,37 +869,30 @@ def __repr__(self, indent: int = 0) -> str:
869869
_LEAF_SPEC = LeafSpec()
870870

871871

872-
def _tree_flatten_helper(
873-
tree: PyTree,
874-
leaves: List[Any],
875-
is_leaf: Optional[Callable[[PyTree], bool]] = None,
876-
) -> TreeSpec:
877-
if _is_leaf(tree, is_leaf=is_leaf):
878-
leaves.append(tree)
879-
return _LEAF_SPEC
880-
881-
node_type = _get_node_type(tree)
882-
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
883-
child_pytrees, context = flatten_fn(tree)
884-
885-
# Recursively flatten the children
886-
children_specs = [
887-
_tree_flatten_helper(child, leaves, is_leaf=is_leaf) for child in child_pytrees
888-
]
889-
890-
return TreeSpec(node_type, context, children_specs)
891-
892-
893872
def tree_flatten(
894873
tree: PyTree,
895874
is_leaf: Optional[Callable[[PyTree], bool]] = None,
896875
) -> Tuple[List[Any], TreeSpec]:
897876
"""Flattens a pytree into a list of values and a TreeSpec that can be used
898877
to reconstruct the pytree.
899878
"""
879+
880+
def helper(node: PyTree, leaves: List[Any]) -> TreeSpec:
881+
if _is_leaf(node, is_leaf=is_leaf):
882+
leaves.append(node)
883+
return _LEAF_SPEC
884+
885+
node_type = _get_node_type(node)
886+
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
887+
children, context = flatten_fn(node)
888+
889+
# Recursively flatten the children
890+
subspecs = [helper(child, leaves) for child in children]
891+
return TreeSpec(node_type, context, subspecs)
892+
900893
leaves: List[Any] = []
901-
spec = _tree_flatten_helper(tree, leaves, is_leaf=is_leaf)
902-
return leaves, spec
894+
treespec = helper(tree, leaves)
895+
return leaves, treespec
903896

904897

905898
def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:

0 commit comments

Comments
 (0)
0