8000 [pytree][4/N] make `torch.utils.pytree` as public API · pytorch/pytorch@0c92aea · GitHub
[go: up one dir, main page]

Skip to content

Commit 0c92aea

Browse files
committed
[pytree][4/N] make torch.utils.pytree as public API
Populate APIs from `torch.utils._pytree` (default) or `torch.utils._cxx_pytree` to a new public module `torch.utils.pytree`. There is a environment varaible `PYTORCH_USE_CXX_PYTREE` (disabled by default) to control this. Since the CXX pytree is now Dynamo traceable, the users can change the underlining pytree implementation by flipping the environment variable while using `torch.utils.pytree`. ghstack-source-id: 2718a77 Pull Request resolved: #137400
1 parent cce5093 commit 0c92aea

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

torch/utils/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import copyreg
44
import os.path as _osp
55
import weakref
6+
from typing import TYPE_CHECKING
67

78
import torch
89
from torch.utils import (
@@ -20,6 +21,10 @@
2021
from torch.utils.throughput_benchmark import ThroughputBenchmark
2122

2223

24+
if TYPE_CHECKING:
25+
from torch.utils import pytree as pytree
26+
27+
2328
def set_module(obj, mod):
2429
"""
2530
Set the module attribute on a python object for a given object for nicer printing

torch/utils/pytree.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
Contains utility functions for working with nested python data structures.
3+
4+
A *pytree* is Python nested data structure. It is a tree in the sense that
5+
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
6+
Python values. Furthermore, a pytree should not contain reference cycles.
7+
8+
pytrees are useful for working with nested collections of Tensors. For example,
9+
one can use `tree_map` to map a function over all Tensors inside some nested
10+
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
11+
inside some nested collection. pytrees are helpful for implementing nested
12+
collection support for PyTorch APIs.
13+
"""
14+
15+
import os
16+
from typing import TYPE_CHECKING
17+
18+
import torch.utils._pytree as python
19+
20+
21+
if TYPE_CHECKING:
22+
from types import ModuleType
23+
24+
import torch.utils._cxx_pytree as cxx
25+
26+
27+
__all__ = [
28+
"tree_flatten",
29+
"tree_unflatten",
30+
"tree_iter",
31+
"tree_leaves",
32+
"tree_structure",
33+
"tree_map",
34+
"tree_map_",
35+
"tree_map_only",
36+
"tree_map_only_",
37+
"tree_all",
38+
"tree_any",
39+
"tree_all_only",
40+
"tree_any_only",
41+
"treespec_pprint",
42+
]
43+
44+
45+
PYTORCH_USE_CXX_PYTREE: bool = os.getenv("PYTORCH_USE_CXX_PYTREE", "0") not in {"0", ""}
46+
47+
48+
if PYTORCH_USE_CXX_PYTREE:
49+
if not python._cxx_pytree_exists:
50+
raise ImportError(
51+
"Cannot import package `optree`. "
52+
"Please install `optree` via `python -m pip install --upgrade optree`."
53+
)
54+
55+
import torch.utils._cxx_pytree as cxx # noqa: F811
56+
57+
implementation: "ModuleType" = cxx
58+
implementation_name: str = "cxx"
59+
60+
from torch.utils._cxx_pytree import (
61+
tree_all as tree_all,
62+
tree_all_only as tree_all_only,
63+
tree_any as tree_any,
64+
tree_any_only as tree_any_only,
65+
tree_flatten as tree_flatten,
66+
tree_iter as tree_iter,
67+
tree_leaves as tree_leaves,
68+
tree_map as tree_map,
69+
tree_map_ as tree_map_,
70+
tree_map_only as tree_map_only,
71+
tree_map_only_ as tree_map_only_,
72+
tree_structure as tree_structure,
73+
tree_unflatten as tree_unflatten,
74+
treespec_pprint as treespec_pprint,
75+
)
76+
else:
77+
implementation: "ModuleType" = python # type: ignore[no-redef]
78+
implementation_name: str = "python" # type: ignore[no-redef]
79+
80+
from torch.utils._pytree import ( # type: ignore[assignment,no-redef]
81+
tree_all as tree_all,
82+
tree_all_only as tree_all_only,
83+
tree_any as tree_any,
84+
tree_any_only as tree_any_only,
85+
tree_flatten as tree_flatten,
86+
tree_iter as tree_iter,
87+
tree_leaves as tree_leaves,
88+
tree_map as tree_map,
89+
tree_map_ as tree_map_,
90+
tree_map_only as tree_map_only,
91+
tree_map_only_ as tree_map_only_,
92+
tree_structure as tree_structure,
93+
tree_unflatten as tree_unflatten,
94+
treespec_pprint as treespec_pprint,
95+
)

0 commit comments

Comments
 (0)
0