[pytree] simplify public API exposition with `__module__` by XuehaiPan · Pull Request #148328 · pytorch/pytorch · GitHub

Status: Open. Wants to merge 70 commits into base: gh/XuehaiPan/253/base.

Changes from all commits (70 commits, all titled "Update" and authored by XuehaiPan):
Mar 3, 2025: 5869417, c45807e, ae1f74d, 15f6dfa, c76deaa, 085e755, f1b7f65, 322b9e6, a5c44e4, 715d6bf, cfebbe1, a473599, b6f8a95
Mar 4, 2025: 67cd29c, e14940d, 1eb9443, 4c10ac5, 57bd9bc, 8ff9e5c, 8ba9e80, b34ea1d, 9d095c8, cd3b63d, bd50bf5, 2f55ede
Mar 5, 2025: 7880252, 64c9bf6, ad363d1, fb42877, 6eff9de, 0757d9d, fdd8622, 376677f, e89cb35
Mar 6, 2025: 96f6b30, 0046802, 51f91d7
Mar 7, 2025: 176e34a
Mar 12, 2025: 877a79e
Mar 14, 2025: 8e0f02a
Mar 20, 2025: 63874e5
Apr 1, 2025: e053358
Apr 3, 2025: 8400373, ec05c85, 10815ff
Apr 5, 2025: e2843b9
Apr 7, 2025: 3fa91e0
Apr 10, 2025: 874a8c1
Apr 11, 2025: b380eee, 92d4ac0, 22feaf3, 7f7de4f
Apr 15, 2025: 80b28bb, c562fb9, 00d35d2, 22166a8
Apr 23, 2025: f7ca766
Apr 26, 2025: ca4a9b4
May 1, 2025: a5ecb1f, c86c47f, 5ece761, eb075c0
May 2, 2025: 66ef529, 11501ce, ece5408, 4eaaa56
May 3, 2025: 0c1cda4
May 8, 2025: e45acce
May 14, 2025: c06df34
May 16, 2025: 7e34dae
4 changes: 2 additions & 2 deletions test/dynamo/test_flat_apply.py
@@ -147,8 +147,8 @@ def forward(self, L_x_: "f32[10]", L_y_: "f32[10]"):

t: "f32[10]" = l_x_ + l_y_

- trace_point_tensor_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_spec
- trace_point_tensor_input_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_input_spec
+ trace_point_tensor_spec : torch.utils.pytree.python.PyTreeSpec = self.trace_point_tensor_spec
+ trace_point_tensor_input_spec : torch.utils.pytree.python.PyTreeSpec = self.trace_point_tensor_input_spec
res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None
return (res,)
""", # NOQA: B950
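(These expectation strings change because the printed graphs qualify types and callables through their `__module__`; after this PR, objects defined in `torch.utils._pytree` report the public path `torch.utils.pytree.python` instead. The same applies to the `test_swap` and `_swap.py` graph dumps below.)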
8 changes: 4 additions & 4 deletions test/export/test_swap.py
@@ -246,19 +246,19 @@ def forward(self, x, y):
_spec_0 = self._spec_0
_spec_1 = self._spec_1
_spec_4 = self._spec_4
- tree_flatten = torch.utils.pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None
+ tree_flatten = torch.utils.pytree.python.tree_flatten((x_1, y_1)); x_1 = y_1 = None
getitem = tree_flatten[0]; tree_flatten = None
x = getitem[0]
y = getitem[1]; getitem = None
- tree_unflatten_1 = torch.utils.pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
+ tree_unflatten_1 = torch.utils.pytree.python.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
getitem_1 = tree_unflatten_1[0]; tree_unflatten_1 = None
getitem_2 = getitem_1[0]
getitem_3 = getitem_1[1]; getitem_1 = None
foo = self.foo(getitem_2, getitem_3); getitem_2 = getitem_3 = None
bar = self.bar(foo); foo = None
tree_flatten_spec_1 = torch.fx._pytree.tree_flatten_spec(bar, _spec_4); bar = _spec_4 = None
getitem_10 = tree_flatten_spec_1[0]; tree_flatten_spec_1 = None
- tree_unflatten = torch.utils.pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
+ tree_unflatten = torch.utils.pytree.python.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
return tree_unflatten""",
)

@@ -321,7 +321,7 @@ def forward(self, x, y):
x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
_spec_0 = self._spec_0
_spec_3 = self._spec_3
- tree_unflatten = torch.utils._pytree.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None
+ tree_unflatten = torch.utils.pytree.python.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None
getitem = tree_unflatten[0]; tree_unflatten = None
getitem_1 = getitem[0]
getitem_2 = getitem[1]; getitem = None
38 changes: 38 additions & 0 deletions test/test_pytree.py
@@ -138,6 +138,44 @@ def test_aligned_public_apis(self):
),
)

    @parametrize(
        "modulename",
        [
            subtest("python", name="py"),
            *([subtest("cxx", name="cxx")] if not IS_FBCODE else []),
        ],
    )
    def test_public_api_import(self, modulename):
        for use_cxx_pytree in [None, "", "0", *(["1"] if not IS_FBCODE else [])]:
            env = os.environ.copy()
            if use_cxx_pytree is not None:
                env["PYTORCH_USE_CXX_PYTREE"] = str(use_cxx_pytree)
            else:
                env.pop("PYTORCH_USE_CXX_PYTREE", None)
            for statement in (
                f"import torch.utils.pytree.{modulename}",
                f"from torch.utils.pytree import {modulename}",
                f"from torch.utils.pytree.{modulename} import tree_map",
                f"import torch.utils.pytree; torch.utils.pytree.{modulename}",
                f"import torch.utils.pytree; torch.utils.pytree.{modulename}.tree_map",
            ):
                try:
                    subprocess.check_output(
                        [sys.executable, "-c", statement],
                        stderr=subprocess.STDOUT,
                        # On Windows, opening the subprocess with the default CWD makes `import torch`
                        # fail, so just set CWD to this script's directory
                        cwd=os.path.dirname(os.path.realpath(__file__)),
                        env=env,
                    )
                except subprocess.CalledProcessError as e:
                    self.fail(
                        msg=(
                            f"Subprocess exception while attempting to run statement `{statement}`: "
                            + e.output.decode("utf-8")
                        )
                    )

@parametrize(
"pytree_impl",
[
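For orientation, here is a minimal sketch of the public facade these statements exercise (it assumes `torch` is installed; the `cxx` variants additionally require `optree`):

import torch.utils.pytree as pytree

leaves, spec = pytree.tree_flatten({"a": 1, "b": (2, 3)})
assert leaves == [1, 2, 3]
assert pytree.tree_unflatten(leaves, spec) == {"a": 1, "b": (2, 3)}

# Both backends expose the same API surface under stable names:
from torch.utils.pytree import python

assert python.tree_map(lambda x: x * 2, [1, 2]) == [2, 4]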
4 changes: 2 additions & 2 deletions torch/export/_swap.py
@@ -42,7 +42,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
%foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
%tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
- %tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
+ %tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils.pytree.python.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {})
%getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1), kwargs = {})
%getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {})
@@ -291,7 +291,7 @@ def _swap_module_helper(
%y : [num_users=1] = placeholder[target=y]

%_spec_0 : [num_users=1] = get_attr[target=_spec_0]
- %tree_unflatten : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
+ %tree_unflatten : [num_users=2] = call_function[target=torch.utils.pytree.python.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
%getitem : [num_users=2] = call_function[target=operator.getitem](args = (%tree_unflatten, 0), kwargs = {})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {})
%getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 1), kwargs = {})
2 changes: 1 addition & 1 deletion torch/fx/graph.py
@@ -87,7 +87,7 @@ def _register_custom_builtin(name: str, import_str: str, obj: Any):
_register_custom_builtin("torch", "import torch", torch)
_register_custom_builtin("device", "from torch import device", torch.device)
_register_custom_builtin("fx_pytree", "import torch.fx._pytree as fx_pytree", fx_pytree)
_register_custom_builtin("pytree", "import torch.utils._pytree as pytree", pytree)
_register_custom_builtin("pytree", "import torch.utils.pytree.python as pytree", pytree)


def _is_magic(x: str) -> bool:
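The import string registered here matters because FX-generated source refers to the bare name `pytree`, and the paired import statement is what gets emitted when generated code is written out (for example by `GraphModule.to_folder`). A rough standalone sketch of the registry pattern, with hypothetical names rather than FX's actual internals:

from typing import Any, NamedTuple

class CustomBuiltin(NamedTuple):
    import_str: str  # emitted verbatim into serialized generated source
    obj: Any         # injected into globals when executing in-process

custom_builtins: dict[str, CustomBuiltin] = {}

def register_custom_builtin(name: str, import_str: str, obj: Any) -> None:
    custom_builtins[name] = CustomBuiltin(import_str, obj)

# Mirrors the change above: generated code that references `pytree` now
# serializes with the public import path instead of the private one.
register_custom_builtin("pytree", "import torch.utils.pytree.python as pytree", object())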
3 changes: 3 additions & 0 deletions torch/utils/_cxx_pytree.py
@@ -99,6 +99,9 @@
]


__name__ = "torch.utils.pytree.cxx" # sets the __module__ attribute of all functions in this module


# In-tree installation may have VCS-based versioning. Update the previous static version.
python_pytree._optree_version = _TorchVersion(optree.__version__) # type: ignore[attr-defined]

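The one-line `__name__` rebinding works because CPython stamps `__module__` onto functions and classes at definition time, reading it from the defining module's `__name__` global. A minimal sketch of that behavior, using hypothetical module names:

import types

mod = types.ModuleType("torch.utils._demo")  # hypothetical private name
src = """
def before():
    pass

__name__ = "torch.utils.demo.public"  # hypothetical public name

def after():
    pass
"""
exec(compile(src, "<demo>", "exec"), mod.__dict__)

print(mod.before.__module__)  # torch.utils._demo
print(mod.after.__module__)   # torch.utils.demo.public

Since the assignment sits above all definitions in `_cxx_pytree.py` and `_pytree.py`, every public function and class below it reports the public path; the `sys.modules` aliases installed by `torch/utils/pytree/__init__.py` keep that path importable, which `repr` output and pickling round-trips rely on.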
3 changes: 3 additions & 0 deletions torch/utils/_pytree.py
@@ -92,6 +92,9 @@
]


__name__ = "torch.utils.pytree.python" # sets the __module__ attribute of all functions in this module


T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
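With both files rebound, the effect is directly observable (a sketch; it assumes a build with this PR applied):

import torch.utils.pytree.python as pytree

# Previously this reported the private path "torch.utils._pytree".
print(pytree.tree_map.__module__)  # "torch.utils.pytree.python"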
113 changes: 66 additions & 47 deletions torch/utils/pytree/__init__.py
@@ -16,10 +16,10 @@

import os as _os
import sys as _sys
- from typing import Any as _Any, Optional as _Optional
+ from types import ModuleType as _ModuleType
+ from typing import Any as _Any, Optional as _Optional, TYPE_CHECKING as _TYPE_CHECKING

import torch.utils._pytree as python
- from torch.utils._exposed_in import exposed_in as _exposed_in
from torch.utils._pytree import ( # these type aliases are identical in both implementations
FlattenFunc,
FlattenWithKeysFunc,
@@ -30,6 +30,10 @@
)


+ if _TYPE_CHECKING:
+     import torch.utils._cxx_pytree as cxx


__all__ = [
"PyTreeSpec",
"register_pytree_node",
@@ -64,18 +68,73 @@
}


- if PYTORCH_USE_CXX_PYTREE:
-     import torch.utils._cxx_pytree as cxx  # noqa: F401

+ def _import_cxx_pytree_and_store() -> _ModuleType:
+     if not python._cxx_pytree_dynamo_traceable:
+         raise ImportError(
+             "Cannot import package `optree`. "
+             "Please install `optree` via `python -m pip install --upgrade optree`. "
+             "Or set the environment variable `PYTORCH_USE_CXX_PYTREE=0`."
+         )
+
+     import torch.utils._cxx_pytree as cxx
+
+     # This allows the following statements to work properly:
+     #
+     #     import torch.utils.pytree
+     #
+     #     torch.utils.pytree.cxx
+     #     torch.utils.pytree.cxx.tree_map
+     #
+     _sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx
+     return cxx


+ if PYTORCH_USE_CXX_PYTREE:
+     cxx = _import_cxx_pytree_and_store()  # noqa: F811
+ else:
+     cxx = _sys.modules.get("torch.utils._cxx_pytree")  # type: ignore[assignment]


_sys.modules[f"{__name__}.python"] = python
+ if cxx is not None:
+     _sys.modules[f"{__name__}.cxx"] = cxx
+ else:
+     del cxx

+     class LazyCxxModule(_ModuleType):
+         def __getattr__(self, name: str) -> _Any:
+             if name == "__name__":
+                 return f"{__name__}.cxx"
+             if name == "__file__":
+                 return python.__file__.removesuffix("_pytree.py") + "_cxx_pytree.py"
+
+             cxx = globals().get("cxx")
+             if cxx is None:
+                 if name.startswith("_"):
+                     raise AttributeError(
+                         f"module {self.__name__!r} has not been imported yet: "
+                         f"accessing attribute {name!r}. "
+                         f"Please import {self.__name__!r} explicitly first."
+                     )
+
+                 # Lazy import on first member access
+                 cxx = _import_cxx_pytree_and_store()
+
+             return getattr(cxx, name)
+
+         def __setattr__(self, name: str, value: _Any) -> None:
+             # Lazy import
+             cxx = _import_cxx_pytree_and_store()
+             return setattr(cxx, name, value)
+
+     # This allows the following statements to work properly:
+     #
+     #     import torch.utils.pytree.cxx
+     #     from torch.utils.pytree.cxx import tree_map
+     #
+     _sys.modules[f"{__name__}.cxx"] = LazyCxxModule(f"{__name__}.cxx")

_sys.modules[f"{__name__}.cxx"] = _sys.modules.get("torch.utils._cxx_pytree") # type: ignore[assignment]
+     del LazyCxxModule


if not PYTORCH_USE_CXX_PYTREE:
@@ -103,8 +162,6 @@
tree_unflatten,
treespec_pprint,
)

- PyTreeSpec = _exposed_in(__name__)(PyTreeSpec)  # type: ignore[misc]
else:
from torch.utils._cxx_pytree import ( # type: ignore[assignment,no-redef]
is_namedtuple,
@@ -132,41 +189,6 @@
)


- # Change `__module__` of reexported public APIs to 'torch.utils.pytree'
- __func_names = frozenset(
-     {
-         "tree_all",
-         "tree_all_only",
-         "tree_any",
-         "tree_any_only",
-         "tree_flatten",
-         "tree_iter",
-         "tree_leaves",
-         "tree_map",
-         "tree_map_",
-         "tree_map_only",
-         "tree_map_only_",
-         "tree_structure",
-         "tree_unflatten",
-         "treespec_pprint",
-         "is_namedtuple",
-         "is_namedtuple_class",
-         "is_namedtuple_instance",
-         "is_structseq",
-         "is_structseq_class",
-         "is_structseq_instance",
-     }
- )
- globals().update(
-     {
-         name: _exposed_in(__name__)(member)
-         for name, member in globals().items()
-         if name in __func_names
-     }
- )
- del __func_names, _exposed_in


def register_pytree_node(
cls: type[_Any],
/,
@@ -208,9 +230,6 @@ def register_pytree_node(
def __getattr__(name: str) -> _Any:
    if name == "cxx":
        # Lazy import
-       import torch.utils._cxx_pytree as cxx  # noqa: F811
-
-       _sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx
-       return cxx
+       return _import_cxx_pytree_and_store()

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
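Taken together, the `sys.modules` pre-registration, the `LazyCxxModule` shim, and the module-level `__getattr__` (PEP 562) form a standard lazy-submodule pattern. A condensed standalone sketch with hypothetical names (`lazy_pkg`, `heavy`):

# lazy_pkg/__init__.py (hypothetical)
import importlib
import sys
from types import ModuleType
from typing import Any

class _LazyModule(ModuleType):
    """Stands in for `lazy_pkg.heavy` until an attribute is first accessed."""

    def __getattr__(self, name: str) -> Any:
        real = importlib.import_module("lazy_pkg._heavy_impl")
        sys.modules["lazy_pkg.heavy"] = real  # replace the shim for future lookups
        return getattr(real, name)

# Lets `import lazy_pkg.heavy` and `from lazy_pkg.heavy import f` succeed
# without paying for the heavy import up front.
sys.modules["lazy_pkg.heavy"] = _LazyModule("lazy_pkg.heavy")

def __getattr__(name: str) -> Any:  # PEP 562: serves `lazy_pkg.heavy` attribute access
    if name == "heavy":
        return sys.modules["lazy_pkg.heavy"]
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")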
4 changes: 2 additions & 2 deletions torch/utils/pytree/cxx.pyi
@@ -1,7 +1,7 @@
# Owner(s): ["module: pytree"]

- from .._cxx_pytree import *  # noqa: F403
- from .._cxx_pytree import (
+ from .._cxx_pytree import *  # previously public APIs  # noqa: F403
+ from .._cxx_pytree import (  # non-public internal APIs
__all__ as __all__,
_broadcast_to_and_flatten as _broadcast_to_and_flatten,
KeyPath as KeyPath,
4 changes: 2 additions & 2 deletions torch/utils/pytree/python.pyi
@@ -1,7 +1,7 @@
# Owner(s): ["module: pytree"]

- from .._pytree import *  # noqa: F403
- from .._pytree import (
+ from .._pytree import *  # previously public APIs  # noqa: F403
+ from .._pytree import (  # non-public internal APIs
__all__ as __all__,
_broadcast_to_and_flatten as _broadcast_to_and_flatten,
arg_tree_leaves as arg_tree_leaves,
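A note on the stub idiom touched here: in `.pyi` files, `from m import name as name` is the typing convention for marking a name as an intentional public re-export, so the added comments only document which block re-exports the public API and which block pulls in internal helpers. A hypothetical stub fragment showing the convention:

# some_module.pyi (hypothetical)
from ._impl import *  # re-exports everything listed in `_impl.__all__`  # noqa: F403
from ._impl import (
    helper as helper,  # `X as X` marks an explicit public re-export
    _private_helper as _private_helper,  # internal, typed for in-repo use only
)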