8000 [pytree] simplify public API exposition with `__module__` · pytorch/pytorch@d4c4b84 · GitHub
[go: up one dir, main page]

Skip to content

Commit d4c4b84

Browse files
committed
[pytree] simplify public API exposition with __module__
ghstack-source-id: 031e10f Pull Request resolved: #148328
1 parent 9ba6927 commit d4c4b84

File tree

10 files changed

+123
-60
lines changed

10 files changed

+123
-60
lines changed

test/dynamo/test_flat_apply.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def forward(self, L_x_: "f32[10]", L_y_: "f32[10]"):
147147
148148
t: "f32[10]" = l_x_ + l_y_
149149
150-
trace_point_tensor_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_spec
151-
trace_point_tensor_input_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_input_spec
150+
trace_point_tensor_spec : torch.utils.pytree.python.PyTreeSpec = self.trace_point_tensor_spec
151+
trace_point_tensor_input_spec : torch.utils.pytree.python.PyTreeSpec = self.trace_point_tensor_input_spec
152152
res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None
153153
return (res,)
154154
""", # NOQA: B950

test/export/test_swap.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,19 +246,19 @@ def forward(self, x, y):
246246
_spec_0 = self._spec_0
247247
_spec_1 = self._spec_1
248248
_spec_4 = self._spec_4
249-
tree_flatten = torch.utils.pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None
249+
tree_flatten = torch.utils.pytree.python.tree_flatten((x_1, y_1)); x_1 = y_1 = None
250250
getitem = tree_flatten[0]; tree_flatten = None
251251
x = getitem[0]
252252
y = getitem[1]; getitem = None
253-
tree_unflatten_1 = torch.utils.pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
253+
tree_unflatten_1 = torch.utils.pytree.python.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
254254
getitem_1 = tree_unflatten_1[0]; tree_unflatten_1 = None
255255
getitem_2 = getitem_1[0]
256256
getitem_3 = getitem_1[1]; getitem_1 = None
257257
foo = self.foo(getitem_2, getitem_3); getitem_2 = getitem_3 = None
258258
bar = self.bar(foo); foo = None
259259
tree_flatten_spec_1 = torch.fx._pytree.tree_flatten_spec(bar, _spec_4); bar = _spec_4 = None
260260
getitem_10 = tree_flatten_spec_1[0]; tree_flatten_spec_1 = None
261-
tree_unflatten = torch.utils.pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
261+
tree_unflatten = torch.utils.pytree.python.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
262262
return tree_unflatten""",
263263
)
264264

@@ -321,7 +321,7 @@ def forward(self, x, y):
321321
x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
322322
_spec_0 = self._spec_0
323323
_spec_3 = self._spec_3
324-
tree_unflatten = torch.utils._pytree.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None
324+
tree_unflatten = torch.utils.pytree.python.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None
325325
getitem = tree_unflatten[0]; tree_unflatten = None
326326
getitem_1 = getitem[0]
327327
getitem_2 = getitem[1]; getitem = None

test/test_pytree.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,44 @@ def test_aligned_public_apis(self):
138138
),
139139
)
140140

141+
@parametrize(
142+
"modulename",
143+
[
144+
subtest("python", name="py"),
145+
*([subtest("cxx", name="cxx")] if not IS_FBCODE else []),
146+
],
147+
)
148+
def test_public_api_import(self, modulename):
149+
for use_cxx_pytree in [None, "", "0", *(["1"] if not IS_FBCODE else [])]:
150+
env = os.environ.copy()
151+
if use_cxx_pytree is not None:
152+
env["PYTORCH_USE_CXX_PYTREE"] = str(use_cxx_pytree)
153+
else:
154+
env.pop("PYTORCH_USE_CXX_PYTREE", None)
155+
for statement in (
156+
f"import torch.utils.pytree.{modulename}",
157+
f"from torch.utils.pytree import {modulename}",
158+
f"from torch.utils.pytree.{modulename} import tree_map",
159+
f"import torch.utils.pytree; torch.utils.pytree.{modulename}",
160+
f"import torch.utils.pytree; torch.utils.pytree.{modulename}.tree_map",
161+
):
162+
try:
163+
subprocess.check_output(
164+
[sys.executable, "-c", statement],
165+
stderr=subprocess.STDOUT,
166+
# On Windows, opening the subprocess with the default CWD makes `import torch`
167+
# fail, so just set CWD to this script's directory
168+
cwd=os.path.dirname(os.path.realpath(__file__)),
169+
env=env,
170+
)
171+
except subprocess.CalledProcessError as e:
172+
self.fail(
173+
msg=(
174+
f"Subprocess exception while attempting to run statement `{statement}`: "
175+
+ e.output.decode("utf-8")
176+
)
177+
)
178+
141179
@parametrize(
142180
"pytree_impl",
143181
[

torch/export/_swap.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
4242
%foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
4343
%tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1), kwargs = {})
4444
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
45-
%tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
45+
%tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils.pytree.python.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
4646
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {})
4747
%getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1), kwargs = {})
4848
%getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {})
@@ -291,7 +291,7 @@ def _swap_module_helper(
291291
%y : [num_users=1] = placeholder[target=y]
292292
293293
%_spec_0 : [num_users=1] = get_attr[target=_spec_0]
294-
%tree_unflatten : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
294+
%tree_unflatten : [num_users=2] = call_function[target=torch.utils.pytree.python.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
295295
%getitem : [num_users=2] = call_function[target=operator.getitem](args = (%tree_unflatten, 0), kwargs = {})
296296
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {})
297297
%getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 1), kwargs = {})

torch/fx/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _register_custom_builtin(name: str, import_str: str, obj: Any):
8787
_register_custom_builtin("torch", "import torch", torch)
8888
_register_custom_builtin("device", "from torch import device", torch.device)
8989
_register_custom_builtin("fx_pytree", "import torch.fx._pytree as fx_pytree", fx_pytree)
90-
_register_custom_builtin("pytree", "import torch.utils._pytree as pytree", pytree)
90+
_register_custom_builtin("pytree", "import torch.utils.pytree.python as pytree", pytree)
9191

9292

9393
def _is_magic(x: str) -> bool:

torch/utils/_cxx_pytree.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@
9999
]
100100

101101

102+
__name__ = "torch.utils.pytree.cxx" # sets the __module__ attribute of all functions in this module
103+
104+
102105
# In-tree installation may have VCS-based versioning. Update the previous static version.
103106
python_pytree._optree_version = _TorchVersion(optree.__version__) # type: ignore[attr-defined]
104107

2254

torch/utils/_pytree.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@
9292
]
9393

9494

95+
__name__ = "torch.utils.pytree.python" # sets the __module__ attribute of all functions in this module
96+
97+
9598
T = TypeVar("T")
9699
S = TypeVar("S")
97100
U = TypeVar("U")

torch/utils/pytree/__init__.py

Lines changed: 66 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616

1717
import os as _os
1818
import sys as _sys
19-
from typing import Any as _Any, Optional as _Optional
19+
from types import ModuleType as _ModuleType
20+
from typing import Any as _Any, Optional as _Optional, TYPE_CHECKING as _TYPE_CHECKING
2021

2122
import torch.utils._pytree as python
22-
from torch.utils._exposed_in import exposed_in as _exposed_in
2323
from torch.utils._pytree import ( # these type aliases are identical in both implementations
2424
FlattenFunc,
2525
FlattenWithKeysFunc,
@@ -30,6 +30,10 @@
3030
)
3131

3232

33+
if _TYPE_CHECKING:
34+
import torch.utils._cxx_pytree as cxx
35+
36+
3337
__all__ = [
3438
"PyTreeSpec",
3539
"register_pytree_node",
@@ -64,18 +68,73 @@
6468
}
6569

6670

67-
if PYTORCH_USE_CXX_PYTREE:
68-
import torch.utils._cxx_pytree as cxx # noqa: F401
69-
71+
def _import_cxx_pytree_and_store() -> _ModuleType:
7072
if not python._cxx_pytree_dynamo_traceable:
7173
raise ImportError(
7274
"Cannot import package `optree`. "
7375
"Please install `optree` via `python -m pip install --upgrade optree`. "
7476
"Or set the environment variable `PYTORCH_USE_CXX_PYTREE=0`."
7577
)
7678

79+
import torch.utils._cxx_pytree as cxx
80+
81+
# This allows the following statements to work properly:
82+
#
83+
# import torch.utils.pytree
84+
#
85+
# torch.utils.pytree.cxx
86+
# torch.utils.pytree.cxx.tree_map
87+
#
88+
_sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx
89+
return cxx
90+
91+
92+
if PYTORCH_USE_CXX_PYTREE:
93+
cxx = _import_cxx_pytree_and_store() # noqa: F811
94+
else:
95+
cxx = _sys.modules.get("torch.utils._cxx_pytree") # type: ignore[assignment]
96+
97+
98+
_sys.modules[f"{__name__}.python"] = python
99+
if cxx is not None:
100+
_sys.modules[f"{__name__}.cxx"] = cxx
101+
else:
102+
del cxx
103+
104+
class LazyCxxModule(_ModuleType):
105+
def __getattr__(self, name: str) -> _Any:
106+
if name == "__name__":
107+
return f"{__name__}.cxx"
108+
if name == "__file__":
109+
return python.__file__.removesuffix("_python.py") + "_cxx_pytree.py"
110+
111+
cxx = globals().get("cxx")
112+
if cxx is None:
113+
if name.startswith("_"):
114+
raise AttributeError(
115+
f"module {self.__name__!r} has not been imported yet: "
116+
f"accessing attribute {name!r}. "
117+
f"Please import {self.__name__!r} explicitly first."
118+
)
119+
120+
# Lazy import on first member access
121+
cxx = _import_cxx_pytree_and_store()
122+
123+
return getattr(cxx, name)
124+
125+
def __setattr__(self, name: str, value: _Any) -> None:
126+
# Lazy import
127+
cxx = _import_cxx_pytree_and_store()
128+
return setattr(cxx, name, value)
129+
130+
# This allows the following statements to work properly:
131+
#
132+
# import torch.utils.pytree.cxx
133+
# from torch.utils.pytree.cxx import tree_map
134+
#
135+
_sys.modules[f"{__name__}.cxx"] = LazyCxxModule(f"{__name__}.cxx")
77136

78-
_sys.modules[f"{__name__}.cxx"] = _sys.modules.get("torch.utils._cxx_pytree") # type: ignore[assignment]
137+
del LazyCxxModule
79138

80139

81140
if not PYTORCH_USE_CXX_PYTREE:
@@ -103,8 +162,6 @@
103162
tree_unflatten,
104163
treespec_pprint,
105164
)
106-
107-
PyTreeSpec = _exposed_in(__name__)(PyTreeSpec) # type: ignore[misc]
108165
else:
109166
from torch.utils._cxx_pytree import ( # type: ignore[assignment,no-redef]
110167
is_namedtuple,
@@ -132,41 +189,6 @@
132189
)
133190

134191

135-
# Change `__module__` of reexported public APIs to 'torch.utils.pytree'
136-
__func_names = frozenset(
137-
{
138-
"tree_all",
139-
"tree_all_only",
140-
"tree_any",
141-
"tree_any_only",
142-
"tree_flatten",
143-
"tree_iter",
144-
"tree_leaves",
145-
"tree_map",
146-
"tree_map_",
147-
"tree_map_only",
148-
"tree_map_only_",
149-
"tree_structure",
150-
"tree_unflatten",
151-
"treespec_pprint",
152-
"is_namedtuple",
153-
"is_namedtuple_class",
154-
"is_namedtuple_instance",
155-
"is_structseq",
156-
"is_structseq_class",
157-
"is_structseq_instance",
158-
}
159-
)
160-
globals().update(
161-
{
162-
name: _exposed_in(__name__)(member)
163-
for name, member in globals().items()
164-
if name in __func_names
165-
}
166-
)
167-
del __func_names, _exposed_in
168-
169-
170192
def register_pytree_node(
171193
cls: type[_Any],
172194
/,
@@ -208,9 +230,6 @@ def register_pytree_node(
208230
def __getattr__(name: str) -> _Any:
209231
if name == "cxx":
210232
# Lazy import
211-
import torch.utils._cxx_pytree as cxx # noqa: F811
212-
213-
_sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx
214-
return cxx
233+
return _import_cxx_pytree_and_store()
215234

216235
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

torch/utils/pytree/cxx.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Owner(s): ["module: pytree"]
22

3-
from .._cxx_pytree import * # noqa: F403
4-
from .._cxx_pytree import (
3+
from .._cxx_pytree import * # previously public APIs # noqa: F403
4+
from .._cxx_pytree import ( # non-public internal APIs
55
__all__ as __all__,
66
_broadcast_to_and_flatten as _broadcast_to_and_flatten,
77
KeyPath as KeyPath,

torch/utils/pytree/python.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Owner(s): ["module: pytree"]
22

3-
from .._pytree import * # noqa: F403
4-
from .._pytree import (
3+
from .._pytree import * # previously public APIs # noqa: F403
4+
from .._pytree import ( # non-public internal APIs
55
__all__ as __all__,
66
_broadcast_to_and_flatten as _broadcast_to_and_flatten,
77
arg_tree_leaves as arg_tree_leaves,

0 commit comments

Comments
 (0)
0