8000 [dynamo][pytree][2/N] make CXX pytree traceable: `tree_flatten` / `tree_unflatten` / `tree_structure` by XuehaiPan · Pull Request #137398 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[dynamo][pytree][2/N] make CXX pytree traceable: tree_flatten / tree_unflatten / tree_structure #137398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 64 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
1b1d91a
Update
XuehaiPan Oct 5, 2024
2ee4eb2
Update
XuehaiPan Oct 5, 2024
a090e3c
Update
XuehaiPan Oct 5, 2024
5bbc747
Update
XuehaiPan Oct 5, 2024
ccf7344
Update
XuehaiPan Oct 5, 2024
cd1f888
Update
XuehaiPan Oct 5, 2024
6a5a8a5
Update
XuehaiPan Oct 5, 2024
c498d86
Update
XuehaiPan Oct 5, 2024
9e2f0fa
Update
XuehaiPan Oct 6, 2024
91651fe
Update
XuehaiPan Oct 6, 2024
82ed62c
Update
XuehaiPan Oct 6, 2024
ab204c7
Update
XuehaiPan Oct 6, 2024
d9a23dc
Update
XuehaiPan Oct 6, 2024
1c045eb
Update
XuehaiPan Oct 6, 2024
d7dad64
Update
XuehaiPan Oct 6, 2024
2065856
Update
XuehaiPan Oct 6, 2024
45d3400
Update
XuehaiPan Oct 6, 2024
0e41f97
Update
XuehaiPan Oct 6, 2024
5315c2f
Update
XuehaiPan Oct 7, 2024
42c519b
Update
XuehaiPan Oct 7, 2024
2b9f79e
Update
XuehaiPan Oct 13, 2024
118e901
Update
XuehaiPan Oct 13, 2024
3422498
Update
XuehaiPan Oct 16, 2024
732057f
Update
XuehaiPan Oct 16, 2024
5e9c01c
Update
XuehaiPan Oct 16, 2024
16602a5
Update
XuehaiPan Oct 16, 2024
3c2bddb
Update
XuehaiPan Oct 16, 2024
8164ebf
Update
XuehaiPan Oct 16, 2024
cac8c23
Update
XuehaiPan Oct 16, 2024
8ba2581
Update
XuehaiPan Oct 16, 2024
1b97c3f
Update
XuehaiPan Oct 16, 2024
16a35da
Update
XuehaiPan Oct 17, 2024
348dd58
Update
XuehaiPan Oct 17, 2024
c230fa7
Update
XuehaiPan Oct 25, 2024
fd704f8
Update
XuehaiPan Oct 26, 2024
cfce0bb
Update
XuehaiPan Oct 29, 2024
0e0f86e
Update
XuehaiPan Oct 29, 2024
b539cb9
Update
XuehaiPan Oct 30, 2024
7f6e449
Update
XuehaiPan Oct 30, 2024
a8a4315
Update
XuehaiPan Nov 2, 2024
13c5af3
Update
XuehaiPan Nov 5, 2024
f0cea0b
Update
XuehaiPan Nov 11, 2024
29ffbbf
Update
XuehaiPan Nov 17, 2024
9ff3312
Update
XuehaiPan Nov 20, 2024
cb9abcc
Update
XuehaiPan Nov 20, 2024
32200cf
Update
XuehaiPan Nov 20, 2024
31a7f76
Update
XuehaiPan Nov 20, 2024
4477993
Update
XuehaiPan Nov 20, 2024
7ad1a4a
Update
XuehaiPan Nov 20, 2024
5820dcd
Update
XuehaiPan Nov 20, 2024
9d3affb
Update
XuehaiPan Nov 21, 2024
2089924
Update
XuehaiPan Nov 21, 2024
c6ee2f9
Update
XuehaiPan Nov 21, 2024
531cbe4
Update
XuehaiPan Nov 21, 2024
28a1225
Update
XuehaiPan Nov 21, 2024
4757b69
Update
XuehaiPan Nov 22, 2024
473c032
Update
XuehaiPan Nov 22, 2024
2c42b81
Update
XuehaiPan Nov 26, 2024
75fbfed
Update
XuehaiPan Nov 26, 2024
87ab4da
Update
XuehaiPan Nov 27, 2024
bb53b36
Update
XuehaiPan Dec 2, 2024
2a6463c
Update
XuehaiPan Dec 2, 2024
49cd5ce
Update
XuehaiPan Dec 7, 2024
17d9823
Update
XuehaiPan Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10263,6 +10263,8 @@ def fn(x):

def test_pytree_tree_flatten_unflatten(self):
implemtations = [("python", python_pytree)]
if cxx_pytree is not None:
implemtations.append(("cxx", cxx_pytree))

for name, module in implemtations:
with self.subTest(f"pytree implement: {name}"):
Expand Down Expand Up @@ -10299,7 +10301,7 @@ def fn(x, y):
torch.ones(3, 2),
1,
]
new_tree = module.tree_unflatten(leaves, treespec)
new_tree = module.tree_unflatten(new_leaves, treespec)
return leaves, new_tree

x = torch.randn(3, 2)
Expand Down
16 changes: 9 additions & 7 deletions test/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ def __init__(self, x, y):
self.y = y


cxx_pytree.register_pytree_node(
GlobalDummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: GlobalDummyType(*xs),
serialized_type_name="GlobalDummyType",
)


class TestGenericPytree(TestCase):
def test_aligned_public_apis(self):
public_apis = py_pytree.__all__
Expand Down Expand Up @@ -1328,7 +1336,7 @@ def test_treespec_repr_dynamo(self):
_, spec = cxx_pytree.tree_flatten(pytree)
self.assertExpectedInline(
repr(spec),
"PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)",
"PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')",
)

@parametrize(
Expand Down Expand Up @@ -1383,12 +1391,6 @@ def test_pytree_serialize_namedtuple(self):
self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)

def test_pytree_custom_type_serialize(self):
cxx_pytree.register_pytree_node(
GlobalDummyType,
lambda dummy: ([dummy.x, dummy.y], None),
lambda xs, _: GlobalDummyType(*xs),
serialized_type_name="GlobalDummyType",
)
spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1))
serialized_spec = cxx_pytree.treespec_dumps(spec)
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
Expand Down
202 changes: 201 additions & 1 deletion torch/_dynamo/polyfills/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,20 @@

from __future__ import annotations

from typing import Any, Callable, Iterable, TYPE_CHECKING
from dataclasses import dataclass, field
from typing import Any, Callable, Iterable, Literal, TYPE_CHECKING
from typing_extensions import TypeIs

import torch.utils._pytree as python_pytree
from torch.utils._pytree import BUILTIN_TYPES

from ..decorators import substitute_in_graph


if TYPE_CHECKING:
import builtins
from typing_extensions import Self

from torch.utils._cxx_pytree import PyTree


Expand Down Expand Up @@ -95,3 +101,197 @@ def tree_leaves(
return list(tree_iter(tree, is_leaf=is_leaf))

__all__ += ["tree_leaves"]

class _Asterisk(str):
def __new__(cls) -> Self:
return super().__new__(cls, "*")

def __repr__(self) -> str:
return "*" # no quotes

_asterisk = _Asterisk()
del _Asterisk

@dataclass(frozen=True)
class PyTreeSpec:
Comment on lines +115 to +116
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@XuehaiPan, if I'm understanding this correctly, you're adding a polyfill for cxx_pytree.tree_flatten such that it will return an instance of this PyTreeSpec class. I'm not sure this works if we are trying to return a PyTreeSpec from a torch.compile'd function: does it create an instance of this "class PyTreeSpec" object, or does it create an instance of torch.utils._cxx_pytree.TreeSpec?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a compiled function, it returns torch._dynamo.polyfills.pytree.PyTreeSpec. This class provides the exactly same interfaces with torch.utils._cxx_pytree.TreeSpec.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is wrong, see the following:

import torch
import torch.utils._cxx_pytree as pytree

@torch.compile(backend="eager", fullgraph=True)
def f(x, y):
    vals, spec = pytree.tree_flatten(x)
    return vals, spec, y.sin()

y = torch.randn(3)
x = [1, [2, [3, 4]]]
vals, spec, _ = f(x, y)
this_doesnt_work = pytree.tree_unflatten(vals, spec)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. It may cause problems if we only compile part of the program.

There is an in-progress polyfill infra for C++ classes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""Analog for :class:`optree.PyTreeSpec` in Python."""

_children: tuple[PyTreeSpec, ...]
_type: builtins.type | None
_metadata: Any
_entries: tuple[Any, ...]
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None

num_nodes: int = field(init=False)
num_leaves: int = field(init=False)
num_children: int = field(init=False)
none_is_leaf: Literal[True] = field(init=False)
namespace: Literal["torch"] = field(init=False)

def __post_init__(self) -> None:
if self._type is None:
assert len(self._children) == 0
assert self._metadata is None
assert self._entries == ()
assert self._unflatten_func is None
num_nodes = 1
num_leaves = 1
num_children = 0
else:
assert callable(self._unflatten_func)
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
num_leaves = sum(spec.num_leaves for spec in self._children)
num_children = len(self._children)

object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children)
object.__setattr__(self, "none_is_leaf", True)
object.__setattr__(self, "namespace", "torch")

def __repr__(self) -> str:
def helper(treespec: PyTreeSpec) -> str:
if treespec.is_leaf():
assert treespec.type is None
return _asterisk

assert treespec.type is not None
assert callable(treespec._unflatten_func)
children_representations = [
helper(subspec) for subspec in treespec._children
]
if (
treespec.type in BUILTIN_TYPES
or optree.is_namedtuple_class(treespec.type)
or optree.is_structseq_class(treespec.type)
):
return treespec._unflatten_func(
treespec._metadata,
children_representations,
)
return (
f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], "
f"[{', '.join(children_representations)}])"
)

return (
f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
)

def __len__(self) -> int:
return self.num_leaves

@property
def type(self) -> builtins.type | None:
return self._type

def is_leaf(self) -> bool:
return self.num_nodes == 1 and self.num_leaves == 1

def children(self) -> list[PyTreeSpec]:
return list(self._children)

def child(self, index: int) -> PyTreeSpec:
return self._children[index]

def entries(self) -> list[Any]:
return list(self._entries)

def entry(self, index: int) -> Any:
return self._entries[index]

def unflatten(self, leaves: Iterable[Any]) -> PyTree:
if not isinstance(leaves, (list, tuple)):
leaves = list(leaves)
if len(leaves) != self.num_leaves:
raise ValueError(
f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
f"but the spec refers to a pytree that holds {self.num_leaves} "
f"items ({self}).",
)
if self.is_leaf():
return leaves[0]

# Recursively unflatten the children
start = 0
end = 0
subtrees = []
for subspec in self._children:
end += subspec.num_leaves
subtrees.append(subspec.unflatten(leaves[start:end]))
start = end

assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees)

_LEAF_SPEC = PyTreeSpec((), None, None, (), None)

def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)

@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_flatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_flatten(
tree: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[Any], PyTreeSpec]:
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
if tree_is_leaf(node, is_leaf=is_leaf):
leaves.append(node)
return _LEAF_SPEC

(
children,
metadata,
entries,
unflatten_func,
) = optree.tree_flatten_one_level(
node,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)

# Recursively flatten the children
subspecs = tuple(helper(child, leaves) for child in children)
return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type]

leaves: list[Any] = []
treespec = helper(tree, leaves)
return leaves, treespec

__all__ += ["tree_flatten"]

@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_structure,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_structure(
tree: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTreeSpec:
return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value]

__all__ += ["tree_structure"]

@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_unflatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return treespec.unflatten(leaves)

__all__ += ["tree_unflatten"]
22 changes: 13 additions & 9 deletions torch/utils/_cxx_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
TypeVar,
Union,
)
from typing_extensions import deprecated
from typing_extensions import deprecated, TypeIs

import optree
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
Expand Down Expand Up @@ -240,6 +240,10 @@ def _private_register_pytree_node(
)


def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
return isinstance(obj, TreeSpec)


def tree_is_leaf(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
Expand Down Expand Up @@ -345,10 +349,10 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``.
"""
if not isinstance(treespec, TreeSpec):
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(values, spec): Expected `spec` to be instance of "
f"TreeSpec but got item of type {type(treespec)}."
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]

Expand Down Expand Up @@ -891,7 +895,7 @@ def _broadcast_to_and_flatten(
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[List[Any]]:
assert isinstance(treespec, TreeSpec)
assert _is_pytreespec_instance(treespec)
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
Expand All @@ -901,10 +905,10 @@ def _broadcast_to_and_flatten(

def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
"""Serialize a treespec to a JSON string."""
if not isinstance(treespec, TreeSpec):
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"treespec_dumps(spec): Expected `spec` to be instance of "
f"TreeSpec but got item of type {type(treespec)}."
f"treespec_dumps(treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)

dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
Expand Down Expand Up @@ -938,7 +942,7 @@ def treespec_pprint(treespec: TreeSpec) -> str:

class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc]
def __instancecheck__(self, instance: object) -> bool:
return isinstance(instance, TreeSpec) and instance.is_leaf()
return _is_pytreespec_instance(instance) and instance.is_leaf()


class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
Expand Down
39 changes: 16 additions & 23 deletions torch/utils/_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,37 +869,30 @@ def __repr__(self, indent: int = 0) -> str:
_LEAF_SPEC = LeafSpec()


def _tree_flatten_helper(
tree: PyTree,
leaves: List[Any],
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> TreeSpec:
if _is_leaf(tree, is_leaf=is_leaf):
leaves.append(tree)
return _LEAF_SPEC

node_type = _get_node_type(tree)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(tree)

# Recursively flatten the children
children_specs = [
_tree_flatten_helper(child, leaves, is_leaf=is_leaf) for child in child_pytrees
]

return TreeSpec(node_type, context, children_specs)


def tree_flatten(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Any], TreeSpec]:
"""Flattens a pytree into a list of values and a TreeSpec that can be used
to reconstruct the pytree.
"""

def helper(node: PyTree, leaves: List[Any]) -> TreeSpec:
if _is_leaf(node, is_leaf=is_leaf):
leaves.append(node)
return _LEAF_SPEC

node_type = _get_node_type(node)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
children, context = flatten_fn(node)

# Recursively flatten the children
subspecs = [helper(child, leaves) for child in children]
return TreeSpec(node_type, context, subspecs)

leaves: List[Any] = []
spec = _tree_flatten_helper(tree, leaves, is_leaf=is_leaf)
return leaves, spec
treespec = helper(tree, leaves)
return leaves, treespec


def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
Expand Down
Loading
0