diff --git a/test/dynamo/test_flat_apply.py b/test/dynamo/test_flat_apply.py index e26bfee0bc47e..20843326125a4 100644 --- a/test/dynamo/test_flat_apply.py +++ b/test/dynamo/test_flat_apply.py @@ -147,8 +147,8 @@ def forward(self, L_x_: "f32[10]", L_y_: "f32[10]"): t: "f32[10]" = l_x_ + l_y_ - trace_point_tensor_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_spec - trace_point_tensor_input_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_input_spec + trace_point_tensor_spec : torch.utils.pytree.python.PyTreeSpec = self.trace_point_tensor_spec + trace_point_tensor_input_spec : torch.utils.pytree.python.PyTreeSpec = self.trace_point_tensor_input_spec res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None return (res,) """, # NOQA: B950 diff --git a/test/export/test_swap.py b/test/export/test_swap.py index 10e003057f0a1..9a04b8228a2d6 100644 --- a/test/export/test_swap.py +++ b/test/export/test_swap.py @@ -246,11 +246,11 @@ def forward(self, x, y): _spec_0 = self._spec_0 _spec_1 = self._spec_1 _spec_4 = self._spec_4 - tree_flatten = torch.utils.pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None + tree_flatten = torch.utils.pytree.python.tree_flatten((x_1, y_1)); x_1 = y_1 = None getitem = tree_flatten[0]; tree_flatten = None x = getitem[0] y = getitem[1]; getitem = None - tree_unflatten_1 = torch.utils.pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None + tree_unflatten_1 = torch.utils.pytree.python.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None getitem_1 = tree_unflatten_1[0]; tree_unflatten_1 = None getitem_2 = getitem_1[0] getitem_3 = getitem_1[1]; getitem_1 = None @@ -258,7 +258,7 @@ def forward(self, x, y): bar = self.bar(foo); foo = None tree_flatten_spec_1 = torch.fx._pytree.tree_flatten_spec(bar, _spec_4); bar = _spec_4 = None getitem_10 = tree_flatten_spec_1[0]; 
tree_flatten_spec_1 = None - tree_unflatten = torch.utils.pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None + tree_unflatten = torch.utils.pytree.python.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None return tree_unflatten""", ) @@ -321,7 +321,7 @@ def forward(self, x, y): x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) _spec_0 = self._spec_0 _spec_3 = self._spec_3 - tree_unflatten = torch.utils._pytree.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None + tree_unflatten = torch.utils.pytree.python.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None getitem = tree_unflatten[0]; tree_unflatten = None getitem_1 = getitem[0] getitem_2 = getitem[1]; getitem = None diff --git a/test/test_pytree.py b/test/test_pytree.py index 7de4ce5c77c79..3ebf16e85f9e1 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -138,6 +138,44 @@ def test_aligned_public_apis(self): ), ) + @parametrize( + "modulename", + [ + subtest("python", name="py"), + *([subtest("cxx", name="cxx")] if not IS_FBCODE else []), + ], + ) + def test_public_api_import(self, modulename): + for use_cxx_pytree in [None, "", "0", *(["1"] if not IS_FBCODE else [])]: + env = os.environ.copy() + if use_cxx_pytree is not None: + env["PYTORCH_USE_CXX_PYTREE"] = str(use_cxx_pytree) + else: + env.pop("PYTORCH_USE_CXX_PYTREE", None) + for statement in ( + f"import torch.utils.pytree.{modulename}", + f"from torch.utils.pytree import {modulename}", + f"from torch.utils.pytree.{modulename} import tree_map", + f"import torch.utils.pytree; torch.utils.pytree.{modulename}", + f"import torch.utils.pytree; torch.utils.pytree.{modulename}.tree_map", + ): + try: + subprocess.check_output( + [sys.executable, "-c", statement], + stderr=subprocess.STDOUT, + # On Windows, opening the subprocess with the default CWD makes `import torch` + # fail, so just set CWD to this script's directory + cwd=os.path.dirname(os.path.realpath(__file__)), + env=env, + ) + 
except subprocess.CalledProcessError as e: + self.fail( + msg=( + f"Subprocess exception while attempting to run statement `{statement}`: " + + e.output.decode("utf-8") + ) + ) + @parametrize( "pytree_impl", [ diff --git a/torch/export/_swap.py b/torch/export/_swap.py index 74b564c9fccbe..6e63e8d547b47 100644 --- a/torch/export/_swap.py +++ b/torch/export/_swap.py @@ -42,7 +42,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None: %foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {}) %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1), kwargs = {}) %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {}) - %tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {}) + %tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils.pytree.python.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {}) %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {}) %getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1), kwargs = {}) %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {}) @@ -291,7 +291,7 @@ def _swap_module_helper( %y : [num_users=1] = placeholder[target=y] %_spec_0 : [num_users=1] = get_attr[target=_spec_0] - %tree_unflatten : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {}) + %tree_unflatten : [num_users=2] = call_function[target=torch.utils.pytree.python.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {}) %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%tree_unflatten, 0), kwargs = {}) %getitem_1 : [num_users=1] = 
call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {}) %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 1), kwargs = {}) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 83b288196d302..d7961a3314687 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -87,7 +87,7 @@ def _register_custom_builtin(name: str, import_str: str, obj: Any): _register_custom_builtin("torch", "import torch", torch) _register_custom_builtin("device", "from torch import device", torch.device) _register_custom_builtin("fx_pytree", "import torch.fx._pytree as fx_pytree", fx_pytree) -_register_custom_builtin("pytree", "import torch.utils._pytree as pytree", pytree) +_register_custom_builtin("pytree", "import torch.utils.pytree.python as pytree", pytree) def _is_magic(x: str) -> bool: diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 54f8b359e70dc..c24c3a3687049 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -99,6 +99,9 @@ ] +__name__ = "torch.utils.pytree.cxx" # sets the __module__ attribute of all functions in this module + + # In-tree installation may have VCS-based versioning. Update the previous static version. 
python_pytree._optree_version = _TorchVersion(optree.__version__) # type: ignore[attr-defined] diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index e81ab5e4e8948..402fcd6d79b79 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -92,6 +92,9 @@ ] +__name__ = "torch.utils.pytree.python" # sets the __module__ attribute of all functions in this module + + T = TypeVar("T") S = TypeVar("S") U = TypeVar("U") diff --git a/torch/utils/pytree/__init__.py b/torch/utils/pytree/__init__.py index 3f8229778ad1b..8c5954c87c7eb 100644 --- a/torch/utils/pytree/__init__.py +++ b/torch/utils/pytree/__init__.py @@ -16,10 +16,10 @@ import os as _os import sys as _sys -from typing import Any as _Any, Optional as _Optional +from types import ModuleType as _ModuleType +from typing import Any as _Any, Optional as _Optional, TYPE_CHECKING as _TYPE_CHECKING import torch.utils._pytree as python -from torch.utils._exposed_in import exposed_in as _exposed_in from torch.utils._pytree import ( # these type aliases are identical in both implementations FlattenFunc, FlattenWithKeysFunc, @@ -30,6 +30,10 @@ ) +if _TYPE_CHECKING: + import torch.utils._cxx_pytree as cxx + + __all__ = [ "PyTreeSpec", "register_pytree_node", @@ -64,9 +68,7 @@ } -if PYTORCH_USE_CXX_PYTREE: - import torch.utils._cxx_pytree as cxx # noqa: F401 - +def _import_cxx_pytree_and_store() -> _ModuleType: if not python._cxx_pytree_dynamo_traceable: raise ImportError( "Cannot import package `optree`. " @@ -74,8 +76,65 @@ "Or set the environment variable `PYTORCH_USE_CXX_PYTREE=0`." 
) + + import torch.utils._cxx_pytree as cxx + + # This allows the following statements to work properly: + # + # import torch.utils.pytree + # + # torch.utils.pytree.cxx + # torch.utils.pytree.cxx.tree_map + # + _sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx + return cxx + + +if PYTORCH_USE_CXX_PYTREE: + cxx = _import_cxx_pytree_and_store() # noqa: F811 +else: + cxx = _sys.modules.get("torch.utils._cxx_pytree") # type: ignore[assignment] + + +_sys.modules[f"{__name__}.python"] = python +if cxx is not None: + _sys.modules[f"{__name__}.cxx"] = cxx +else: + del cxx + + + class LazyCxxModule(_ModuleType): + def __getattr__(self, name: str) -> _Any: + if name == "__name__": + return f"{__name__}.cxx" + if name == "__file__": + return python.__file__.removesuffix("_pytree.py") + "_cxx_pytree.py" + + cxx = globals().get("cxx") + if cxx is None: + if name.startswith("_"): + raise AttributeError( + f"module {self.__name__!r} has not been imported yet: " + f"accessing attribute {name!r}. " + f"Please import {self.__name__!r} explicitly first."
+ ) + + # Lazy import on first member access + cxx = _import_cxx_pytree_and_store() + + return getattr(cxx, name) + + def __setattr__(self, name: str, value: _Any) -> None: + # Lazy import + cxx = _import_cxx_pytree_and_store() + return setattr(cxx, name, value) + + # This allows the following statements to work properly: + # + # import torch.utils.pytree.cxx + # from torch.utils.pytree.cxx import tree_map + # + _sys.modules[f"{__name__}.cxx"] = LazyCxxModule(f"{__name__}.cxx") -_sys.modules[f"{__name__}.cxx"] = _sys.modules.get("torch.utils._cxx_pytree") # type: ignore[assignment] + del LazyCxxModule if not PYTORCH_USE_CXX_PYTREE: @@ -103,8 +162,6 @@ tree_unflatten, treespec_pprint, ) - - PyTreeSpec = _exposed_in(__name__)(PyTreeSpec) # type: ignore[misc] else: from torch.utils._cxx_pytree import ( # type: ignore[assignment,no-redef] is_namedtuple, @@ -132,41 +189,6 @@ ) -# Change `__module__` of reexported public APIs to 'torch.utils.pytree' -__func_names = frozenset( - { - "tree_all", - "tree_all_only", - "tree_any", - "tree_any_only", - "tree_flatten", - "tree_iter", - "tree_leaves", - "tree_map", - "tree_map_", - "tree_map_only", - "tree_map_only_", - "tree_structure", - "tree_unflatten", - "treespec_pprint", - "is_namedtuple", - "is_namedtuple_class", - "is_namedtuple_instance", - "is_structseq", - "is_structseq_class", - "is_structseq_instance", - } -) -globals().update( - { - name: _exposed_in(__name__)(member) - for name, member in globals().items() - if name in __func_names - } -) -del __func_names, _exposed_in - - def register_pytree_node( cls: type[_Any], /, @@ -208,9 +230,6 @@ def register_pytree_node( def __getattr__(name: str) -> _Any: if name == "cxx": # Lazy import - import torch.utils._cxx_pytree as cxx # noqa: F811 - - _sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx - return cxx + return _import_cxx_pytree_and_store() raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/torch/utils/pytree/cxx.pyi 
b/torch/utils/pytree/cxx.pyi index 1c8a69c4bb429..d7f95d883a1e6 100644 --- a/torch/utils/pytree/cxx.pyi +++ b/torch/utils/pytree/cxx.pyi @@ -1,7 +1,7 @@ # Owner(s): ["module: pytree"] -from .._cxx_pytree import * # noqa: F403 -from .._cxx_pytree import ( +from .._cxx_pytree import * # previously public APIs # noqa: F403 +from .._cxx_pytree import ( # non-public internal APIs __all__ as __all__, _broadcast_to_and_flatten as _broadcast_to_and_flatten, KeyPath as KeyPath, diff --git a/torch/utils/pytree/python.pyi b/torch/utils/pytree/python.pyi index 61f72707cd055..f7e647e4b08ff 100644 --- a/torch/utils/pytree/python.pyi +++ b/torch/utils/pytree/python.pyi @@ -1,7 +1,7 @@ # Owner(s): ["module: pytree"] -from .._pytree import * # noqa: F403 -from .._pytree import ( +from .._pytree import * # previously public APIs # noqa: F403 +from .._pytree import ( # non-public internal APIs __all__ as __all__, _broadcast_to_and_flatten as _broadcast_to_and_flatten, arg_tree_leaves as arg_tree_leaves,