10000 [pytree] add another simplified pytree module `torch.pytree` · pytorch/pytorch@66e28d8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 66e28d8

Browse files
committed
[pytree] add another simplified pytree module torch.pytree
Differences between `torch.pytree` and `torch.utils.pytree`: 1. APIs in `torch.utils.pytree` have a `tree_` prefix: ```python leaves, treespec = torch.utils.pytree.tree_flatten(tree) new_tree = torch.utils.pytree.tree_map(func, tree) leaevs, treespec = torch.pytree.flatten(tree) new_tree = torch.pytree.map(func, tree) ``` 2. The argument order of `unflatten` is reversed for better `functools.partial` support: ```python tree = torch.utils.pytree.tree_unflatten(leaves, treespec) tree = torch.pytree.unflatten(treespec, leaves) unflatten_fn = functools.partial(torch.pytree.unflatten, treespec) tree1 = unflatten_fn(leaves1) tree2 = unflatten_fn(leaves2) ``` This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`. ghstack-source-id: 3d16082 Pull Request resolved: #148180
1 parent 4c3a694 commit 66e28d8

File tree

6 files changed

+142
-1
lines changed

6 files changed

+142
-1
lines changed

CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,4 +193,5 @@ torch/backends/cudnn/ @eqy @syed-ahmed
193193
/torch/utils/_pytree.py @XuehaiPan
194194
/torch/utils/_cxx_pytree.py @XuehaiPan
195195
/torch/utils/pytree/ @XuehaiPan
196+
/torch/pytree.py @XuehaiPan
196197
/torch/_dynamo/polyfills/pytree.py @XuehaiPan

docs/source/pytorch-api.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ torch.signal <signal>
5050
torch.special <special>
5151
torch.overrides
5252
torch.package <package>
53+
torch.pytree <pytree>
5354
profiler
5455
nn.init
5556
nn.attention
@@ -67,7 +68,7 @@ sparse
6768
storage
6869
torch.testing <testing>
6970
torch.utils <utils>
70-
torch.utils.pytree <pytree>
71+
torch.utils.pytree <torch.utils.pytree>
7172
torch.utils.benchmark <benchmark_utils>
7273
torch.utils.bottleneck <bottleneck>
7374
torch.utils.checkpoint <checkpoint>

docs/source/pytree.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
torch.pytree
2+
============
3+
4+
.. currentmodule:: torch.pytree
5+
6+
.. automodule:: torch.pytree
7+
:members:

test/allowlist_for_publicAPI.json

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,28 @@
694694
"kineto_available",
695695
"record_function"
696696
],
697+
"torch.pytree": [
698+
"PyTreeSpec",
699+
"register_node",
700+
"all",
701+
"all_only",
702+
"any",
703+
"any_only",
704+
"flatten",
705+
"iter",
706+
"leaves",
707+
"map",
708+
"map_",
709+
"map_only",
710+
"map_only_",
711+
"structure",
712+
"is_namedtuple",
713+
"is_namedtuple_class",
714+
"is_namedtuple_instance",
715+
"is_structseq",
716+
"is_structseq_class",
717+
"is_structseq_instance"
718+
],
697719
"torch.quantization": [
698720
"ABC",
699721
"DeQuantStub",

torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2692,6 +2692,7 @@ def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
26922692
_inductor as _inductor,
26932693
_subclasses as _subclasses,
26942694
onnx as onnx,
2695+
pytree as pytree,
26952696
)
26962697

26972698
else:
@@ -2701,6 +2702,7 @@ def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
27012702
"_export",
27022703
# ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
27032704
"onnx",
2705+
"pytree",
27042706
}
27052707

27062708
def __getattr__(name):

torch/pytree.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Owner(s): ["module: pytree"]
2+
3+
"""
4+
Contains utility functions for working with nested python data structures.
5+
6+
A *pytree* is Python nested data structure. It is a tree in the sense that
7+
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
8+
Python values. Furthermore, a pytree should not contain reference cycles.
9+
10+
pytrees are useful for working with nested collections of Tensors. For example,
11+
one can use `map` to map a function over all Tensors inside some nested
12+
collection of Tensors and `leaves` to get a flat list of all Tensors
13+
inside some nested collection. pytrees are helpful for implementing nested
14+
collection support for PyTorch APIs.
15+
"""
16+
17+
from __future__ import annotations
18+
19+
from typing import Any as _Any, TYPE_CHECKING as _TYPE_CHECKING
20+
21+
from torch.utils.pytree import (
22+
is_namedtuple,
23+
is_namedtuple_class,
24+
is_namedtuple_instance,
25+
is_structseq,
26+
is_structseq_class,
27+
is_structseq_instance,
28+
PyTree,
29+
PyTreeSpec,
30+
register_pytree_node as register_node,
31+
tree_all as all,
32+
tree_all_only as all_only,
33+
tree_any as any,
34+
tree_any_only as any_only,
35+
tree_flatten as flatten,
36+
tree_iter as iter,
37+
tree_leaves as leaves,
38+
tree_map as map,
39+
tree_map_ as map_,
40+
tree_map_only as map_only,
41+
tree_map_only_ as map_only_ F438 ,
42+
tree_structure as structure,
43+
tree_unflatten as _tree_unflatten,
44+
)
45+
46+
47+
if _TYPE_CHECKING:
48+
from collections.abc import Iterable
49+
50+
51+
__all__ = [
52+
"PyTreeSpec",
53+
"register_node",
54+
"flatten",
55+
"unflatten",
56+
"iter",
57+
"leaves",
58+
"structure",
59+
"map",
60+
"map_",
61+
"map_only",
62+
"map_only_",
63+
"all",
64+
"any",
65+
"all_only",
66+
"any_only",
67+
"is_namedtuple",
68+
"is_namedtuple_class",
69+
"is_namedtuple_instance",
70+
"is_structseq",
71+
"is_structseq_class",
72+
"is_structseq_instance",
73+
]
74+
75+
76+
def unflatten(treespec: PyTreeSpec, leaves: Iterable[_Any]) -> PyTree:
77+
"""Reconstruct a pytree from the treespec and the leaves.
78+
79+
The inverse of :func:`flatten`.
80+
81+
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
82+
>>> leaves, treespec = torch.pytree.flatten(tree)
83+
>>> tree == torch.pytree.unflatten(treespec, leaves)
84+
True
85+
86+
.. warning::
87+
88+
This function has a different signature than :func:`torch.utils.pytree.tree_unflatten`.
89+
The ``treespec`` argument comes first to have a better :class:`functools.partial` support:
90+
91+
.. code-block:: python
92+
93+
import functools
94+
95+
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
96+
tree1 = unflatten_fn(leaves1)
97+
tree2 = unflatten_fn(leaves2)
98+
99+
Args:
100+
treespec (PyTreeSpec): The treespec to reconstruct.
101+
leaves (iterable): The list of leaves to use for reconstruction. The list must match the
102+
number of leaves of the treespec.
103+
104+
Returns:
105+
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
106+
``treespec``.
107+
"""
108+
return _tree_unflatten(leaves, treespec)

0 commit comments

Comments
 (0)
0