[pytree] simplify public API exposition with `__module__` · pytorch/pytorch@d4c4b84 · GitHub

Commit d4c4b84

[pytree] simplify public API exposition with __module__
ghstack-source-id: 031e10f
Pull Request resolved: #148328
1 parent 9ba6927 commit d4c4b84

File tree

10 files changed: +123 −60 lines changed

  • test/dynamo/test_flat_apply.py
  • test/export/test_swap.py
  • test/test_pytree.py
  • torch/export/_swap.py
  • torch/fx/graph.py
  • torch/utils/_cxx_pytree.py
  • torch/utils/_pytree.py
  • torch/utils/pytree/__init__.py
  • torch/utils/pytree/cxx.pyi
  • torch/utils/pytree/python.pyi
test/dynamo/test_flat_apply.py

+2 −2

@@ -147,8 +147,8 @@ def forward(self, L_x_: "f32[10]", L_y_: "f32[10]"):
 t: "f32[10]" = l_x_ + l_y_
-trace_point_tensor_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_spec
-trace_point_tensor_input_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_input_spec
+trace_point_tensor_spec : torch.utils.pytree.python.PyTreeSpec = self.trace_point_tensor_spec
+trace_point_tensor_input_spec : torch.utils.pytree.python.PyTreeSpec = self.trace_point_tensor_input_spec
 res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None
 return (res,)
 """, # NOQA: B950

test/export/test_swap.py

+4 −4

@@ -246,19 +246,19 @@ def forward(self, x, y):
 _spec_0 = self._spec_0
 _spec_1 = self._spec_1
 _spec_4 = self._spec_4
-tree_flatten = torch.utils.pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None
+tree_flatten = torch.utils.pytree.python.tree_flatten((x_1, y_1)); x_1 = y_1 = None
 getitem = tree_flatten[0]; tree_flatten = None
 x = getitem[0]
 y = getitem[1]; getitem = None
-tree_unflatten_1 = torch.utils.pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
+tree_unflatten_1 = torch.utils.pytree.python.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
 getitem_1 = tree_unflatten_1[0]; tree_unflatten_1 = None
 getitem_2 = getitem_1[0]
 getitem_3 = getitem_1[1]; getitem_1 = None
 foo = self.foo(getitem_2, getitem_3); getitem_2 = getitem_3 = None
 bar = self.bar(foo); foo = None
 tree_flatten_spec_1 = torch.fx._pytree.tree_flatten_spec(bar, _spec_4); bar = _spec_4 = None
 getitem_10 = tree_flatten_spec_1[0]; tree_flatten_spec_1 = None
-tree_unflatten = torch.utils.pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
+tree_unflatten = torch.utils.pytree.python.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
 return tree_unflatten""",
 )

@@ -321,7 +321,7 @@ def forward(self, x, y):
 x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
 _spec_0 = self._spec_0
 _spec_3 = self._spec_3
-tree_unflatten = torch.utils._pytree.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None
+tree_unflatten = torch.utils.pytree.python.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None
 getitem = tree_unflatten[0]; tree_unflatten = None
 getitem_1 = getitem[0]
 getitem_2 = getitem[1]; getitem = None

test/test_pytree.py

+38

@@ -138,6 +138,44 @@ def test_aligned_public_apis(self):
             ),
         )

+    @parametrize(
+        "modulename",
+        [
+            subtest("python", name="py"),
+            *([subtest("cxx", name="cxx")] if not IS_FBCODE else []),
+        ],
+    )
+    def test_public_api_import(self, modulename):
+        for use_cxx_pytree in [None, "", "0", *(["1"] if not IS_FBCODE else [])]:
+            env = os.environ.copy()
+            if use_cxx_pytree is not None:
+                env["PYTORCH_USE_CXX_PYTREE"] = str(use_cxx_pytree)
+            else:
+                env.pop("PYTORCH_USE_CXX_PYTREE", None)
+            for statement in (
+                f"import torch.utils.pytree.{modulename}",
+                f"from torch.utils.pytree import {modulename}",
+                f"from torch.utils.pytree.{modulename} import tree_map",
+                f"import torch.utils.pytree; torch.utils.pytree.{modulename}",
+                f"import torch.utils.pytree; torch.utils.pytree.{modulename}.tree_map",
+            ):
+                try:
+                    subprocess.check_output(
+                        [sys.executable, "-c", statement],
+                        stderr=subprocess.STDOUT,
+                        # On Windows, opening the subprocess with the default CWD makes `import torch`
+                        # fail, so just set CWD to this script's directory
+                        cwd=os.path.dirname(os.path.realpath(__file__)),
+                        env=env,
+                    )
+                except subprocess.CalledProcessError as e:
+                    self.fail(
+                        msg=(
+                            f"Subprocess exception while attempting to run statement `{statement}`: "
+                            + e.output.decode("utf-8")
+                        )
+                    )
+
     @parametrize(
         "pytree_impl",
         [
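
The new test drives each statement through a fresh interpreter, so the `sys.modules` registrations made in `torch/utils/pytree/__init__.py` (further down in this commit) are exercised from a cold start rather than inherited from the test process. A standalone sketch of the same kind of check, with the statement string chosen purely as an example:

    import subprocess
    import sys

    # One-off version of the per-statement check in test_public_api_import
    # (hypothetical, simplified: no env-var matrix, no CWD handling).
    statement = "from torch.utils.pytree import python"  # example; any of the tested forms fits
    try:
        subprocess.check_output([sys.executable, "-c", statement], stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as exc:
        print(f"`{statement}` failed:\n{exc.output.decode('utf-8')}")
    else:
        print(f"`{statement}` imported cleanly")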

torch/export/_swap.py

+2 −2

@@ -42,7 +42,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
 %foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
 %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1), kwargs = {})
 %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
-%tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
+%tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils.pytree.python.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
 %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {})
 %getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1), kwargs = {})
 %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {})

@@ -291,7 +291,7 @@ def _swap_module_helper(
 %y : [num_users=1] = placeholder[target=y]

 %_spec_0 : [num_users=1] = get_attr[target=_spec_0]
-%tree_unflatten : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
+%tree_unflatten : [num_users=2] = call_function[target=torch.utils.pytree.python.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
 %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%tree_unflatten, 0), kwargs = {})
 %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {})
 %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 1), kwargs = {})

torch/fx/graph.py

+1 −1

@@ -87,7 +87,7 @@ def _register_custom_builtin(name: str, import_str: str, obj: Any):
 _register_custom_builtin("torch", "import torch", torch)
 _register_custom_builtin("device", "from torch import device", torch.device)
 _register_custom_builtin("fx_pytree", "import torch.fx._pytree as fx_pytree", fx_pytree)
-_register_custom_builtin("pytree", "import torch.utils._pytree as pytree", pytree)
+_register_custom_builtin("pytree", "import torch.utils.pytree.python as pytree", pytree)


 def _is_magic(x: str) -> bool:
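
For context, `_register_custom_builtin` pairs a short global name with the import statement FX codegen emits whenever generated code references that name; changing the import string is what makes regenerated graph modules bind `pytree` through the public `torch.utils.pytree.python` path. A rough, hypothetical mirror of that bookkeeping (illustrative only, not the real FX internals):

    from typing import Any, NamedTuple

    class _Builtin(NamedTuple):
        import_str: str  # line emitted at the top of generated module code
        obj: Any         # live object used when the graph runs in-process

    _custom_builtins: dict[str, _Builtin] = {}

    def register_custom_builtin(name: str, import_str: str, obj: Any) -> None:
        _custom_builtins[name] = _Builtin(import_str, obj)

    # After this commit the "pytree" name resolves through the public alias:
    register_custom_builtin("pytree", "import torch.utils.pytree.python as pytree", object())

    # Codegen would then prepend the import line for each builtin the generated code uses:
    print("\n".join(b.import_str for b in _custom_builtins.values()))
    # -> import torch.utils.pytree.python as pytree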

torch/utils/_cxx_pytree.py

+3

@@ -99,6 +99,9 @@
 ]


+__name__ = "torch.utils.pytree.cxx"  # sets the __module__ attribute of all functions in this module
+
+
 # In-tree installation may have VCS-based versioning. Update the previous static version.
 python_pytree._optree_version = _TorchVersion(optree.__version__)  # type: ignore[attr-defined]
torch/utils/_pytree.py

+3

@@ -92,6 +92,9 @@
 ]


+__name__ = "torch.utils.pytree.python"  # sets the __module__ attribute of all functions in this module
+
+
 T = TypeVar("T")
 S = TypeVar("S")
 U = TypeVar("U")
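
The added line (here and in `_cxx_pytree.py` above) works because `def` and `class` statements read the enclosing module's global `__name__` at definition time when they set `__module__`; reassigning `__name__` near the top of the file therefore relabels everything defined after it, which is what lets the hand-maintained `_exposed_in` rewriting be deleted from `torch/utils/pytree/__init__.py` below. A minimal sketch of the mechanism with throwaway names (not the real pytree functions):

    # Run as a standalone script: `before` keeps the script's own module name,
    # `after` picks up the reassigned one.
    def before():
        pass

    __name__ = "torch.utils.pytree.python"  # mimic the reassignment added in this commit

    def after():
        pass

    print(before.__module__)  # e.g. "__main__"
    print(after.__module__)   # "torch.utils.pytree.python"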

torch/utils/pytree/__init__.py

+66 −47

@@ -16,10 +16,10 @@

 import os as _os
 import sys as _sys
-from typing import Any as _Any, Optional as _Optional
+from types import ModuleType as _ModuleType
+from typing import Any as _Any, Optional as _Optional, TYPE_CHECKING as _TYPE_CHECKING

 import torch.utils._pytree as python
-from torch.utils._exposed_in import exposed_in as _exposed_in
 from torch.utils._pytree import (  # these type aliases are identical in both implementations
     FlattenFunc,
     FlattenWithKeysFunc,

@@ -30,6 +30,10 @@
 )


+if _TYPE_CHECKING:
+    import torch.utils._cxx_pytree as cxx
+
+
 __all__ = [
     "PyTreeSpec",
     "register_pytree_node",

@@ -64,18 +68,73 @@
 }


-if PYTORCH_USE_CXX_PYTREE:
-    import torch.utils._cxx_pytree as cxx  # noqa: F401
-
+def _import_cxx_pytree_and_store() -> _ModuleType:
     if not python._cxx_pytree_dynamo_traceable:
         raise ImportError(
             "Cannot import package `optree`. "
             "Please install `optree` via `python -m pip install --upgrade optree`. "
             "Or set the environment variable `PYTORCH_USE_CXX_PYTREE=0`."
         )

+    import torch.utils._cxx_pytree as cxx
+
+    # This allows the following statements to work properly:
+    #
+    #     import torch.utils.pytree
+    #
+    #     torch.utils.pytree.cxx
+    #     torch.utils.pytree.cxx.tree_map
+    #
+    _sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx
+    return cxx
+
+
+if PYTORCH_USE_CXX_PYTREE:
+    cxx = _import_cxx_pytree_and_store()  # noqa: F811
+else:
+    cxx = _sys.modules.get("torch.utils._cxx_pytree")  # type: ignore[assignment]
+
+
+_sys.modules[f"{__name__}.python"] = python
+if cxx is not None:
+    _sys.modules[f"{__name__}.cxx"] = cxx
+else:
+    del cxx
+
+    class LazyCxxModule(_ModuleType):
+        def __getattr__(self, name: str) -> _Any:
+            if name == "__name__":
+                return f"{__name__}.cxx"
+            if name == "__file__":
+                return python.__file__.removesuffix("_python.py") + "_cxx_pytree.py"
+
+            cxx = globals().get("cxx")
+            if cxx is None:
+                if name.startswith("_"):
+                    raise AttributeError(
+                        f"module {self.__name__!r} has not been imported yet: "
+                        f"accessing attribute {name!r}. "
+                        f"Please import {self.__name__!r} explicitly first."
+                    )
+
+                # Lazy import on first member access
+                cxx = _import_cxx_pytree_and_store()
+
+            return getattr(cxx, name)
+
+        def __setattr__(self, name: str, value: _Any) -> None:
+            # Lazy import
+            cxx = _import_cxx_pytree_and_store()
+            return setattr(cxx, name, value)
+
+    # This allows the following statements to work properly:
+    #
+    #     import torch.utils.pytree.cxx
+    #     from torch.utils.pytree.cxx import tree_map
+    #
+    _sys.modules[f"{__name__}.cxx"] = LazyCxxModule(f"{__name__}.cxx")

-_sys.modules[f"{__name__}.cxx"] = _sys.modules.get("torch.utils._cxx_pytree")  # type: ignore[assignment]
+    del LazyCxxModule


 if not PYTORCH_USE_CXX_PYTREE:

@@ -103,8 +162,6 @@
         tree_unflatten,
         treespec_pprint,
     )
-
-    PyTreeSpec = _exposed_in(__name__)(PyTreeSpec)  # type: ignore[misc]
 else:
     from torch.utils._cxx_pytree import (  # type: ignore[assignment,no-redef]
         is_namedtuple,

@@ -132,41 +189,6 @@
     )


-# Change `__module__` of reexported public APIs to 'torch.utils.pytree'
-__func_names = frozenset(
-    {
-        "tree_all",
-        "tree_all_only",
-        "tree_any",
-        "tree_any_only",
-        "tree_flatten",
-        "tree_iter",
-        "tree_leaves",
-        "tree_map",
-        "tree_map_",
-        "tree_map_only",
-        "tree_map_only_",
-        "tree_structure",
-        "tree_unflatten",
-        "treespec_pprint",
-        "is_namedtuple",
-        "is_namedtuple_class",
-        "is_namedtuple_instance",
-        "is_structseq",
-        "is_structseq_class",
-        "is_structseq_instance",
-    }
-)
-globals().update(
-    {
-        name: _exposed_in(__name__)(member)
-        for name, member in globals().items()
-        if name in __func_names
-    }
-)
-del __func_names, _exposed_in
-
-
 def register_pytree_node(
     cls: type[_Any],
     /,

@@ -208,9 +230,6 @@ def register_pytree_node(
 def __getattr__(name: str) -> _Any:
     if name == "cxx":
         # Lazy import
-        import torch.utils._cxx_pytree as cxx  # noqa: F811
-
-        _sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx
-        return cxx
+        return _import_cxx_pytree_and_store()

     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
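
The `sys.modules` registrations plus the `LazyCxxModule` placeholder are what make `import torch.utils.pytree.python`, `import torch.utils.pytree.cxx`, and `from torch.utils.pytree.cxx import tree_map` all resolve while deferring the optional `optree`-backed import until something in it is actually used; the module-level `__getattr__` covers plain attribute access such as `torch.utils.pytree.cxx`. A self-contained sketch of the same pattern under hypothetical names (`pkg`, `pkg.heavy`, with `json` standing in for the expensive backend):

    import sys
    import types

    def _import_and_store() -> types.ModuleType:
        import json as real              # stand-in for the optional/expensive backend
        sys.modules["pkg.heavy"] = real  # later imports now get the real module directly
        return real

    class _LazySubmodule(types.ModuleType):
        """Placeholder in sys.modules; performs the real import on first attribute access."""

        def __getattr__(self, name: str):
            if name.startswith("_"):
                # Probes from the import machinery (e.g. __path__) must not trigger the import.
                raise AttributeError(name)
            return getattr(_import_and_store(), name)

    pkg = types.ModuleType("pkg")
    sys.modules["pkg"] = pkg
    sys.modules["pkg.heavy"] = _LazySubmodule("pkg.heavy")

    # Nothing heavy has been imported yet; the from-import resolves through the placeholder.
    from pkg.heavy import dumps
    print(dumps({"lazy": True}))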

torch/utils/pytree/cxx.pyi

+2 −2

@@ -1,7 +1,7 @@
 # Owner(s): ["module: pytree"]

-from .._cxx_pytree import *  # noqa: F403
-from .._cxx_pytree import (
+from .._cxx_pytree import *  # previously public APIs  # noqa: F403
+from .._cxx_pytree import (  # non-public internal APIs
     __all__ as __all__,
     _broadcast_to_and_flatten as _broadcast_to_and_flatten,
     KeyPath as KeyPath,

torch/utils/pytree/python.pyi

+2 −2

@@ -1,7 +1,7 @@
 # Owner(s): ["module: pytree"]

-from .._pytree import *  # noqa: F403
-from .._pytree import (
+from .._pytree import *  # previously public APIs  # noqa: F403
+from .._pytree import (  # non-public internal APIs
     __all__ as __all__,
     _broadcast_to_and_flatten as _broadcast_to_and_flatten,
     arg_tree_leaves as arg_tree_leaves,
