@@ -16,10 +16,10 @@
 
 import os as _os
 import sys as _sys
-from typing import Any as _Any, Optional as _Optional
+from types import ModuleType as _ModuleType
+from typing import Any as _Any, Optional as _Optional, TYPE_CHECKING as _TYPE_CHECKING
 
 import torch.utils._pytree as python
-from torch.utils._exposed_in import exposed_in as _exposed_in
 from torch.utils._pytree import (  # these type aliases are identical in both implementations
     FlattenFunc,
     FlattenWithKeysFunc,
@@ -30,6 +30,10 @@
 )
 
 
+if _TYPE_CHECKING:
+    import torch.utils._cxx_pytree as cxx
+
+
 __all__ = [
     "PyTreeSpec",
     "register_pytree_node",
@@ -64,18 +68,73 @@
 }
 
 
-if PYTORCH_USE_CXX_PYTREE:
-    import torch.utils._cxx_pytree as cxx  # noqa: F401
-
+def _import_cxx_pytree_and_store() -> _ModuleType:
     if not python._cxx_pytree_dynamo_traceable:
         raise ImportError(
             "Cannot import package `optree`. "
             "Please install `optree` via `python -m pip install --upgrade optree`. "
             "Or set the environment variable `PYTORCH_USE_CXX_PYTREE=0`."
         )
 
+    import torch.utils._cxx_pytree as cxx
+
+    # This allows the following statements to work properly:
+    #
+    #     import torch.utils.pytree
+    #
+    #     torch.utils.pytree.cxx
+    #     torch.utils.pytree.cxx.tree_map
+    #
+    _sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx
+    return cxx
+
+
+if PYTORCH_USE_CXX_PYTREE:
+    cxx = _import_cxx_pytree_and_store()  # noqa: F811
+else:
+    cxx = _sys.modules.get("torch.utils._cxx_pytree")  # type: ignore[assignment]
+
+
+_sys.modules[f"{__name__}.python"] = python
+if cxx is not None:
+    _sys.modules[f"{__name__}.cxx"] = cxx
+else:
+    del cxx
+
+    class LazyCxxModule(_ModuleType):
+        def __getattr__(self, name: str) -> _Any:
+            if name == "__name__":
+                return f"{__name__}.cxx"
+            if name == "__file__":
+                return python.__file__.removesuffix("_python.py") + "_cxx_pytree.py"
+
+            cxx = globals().get("cxx")
+            if cxx is None:
+                if name.startswith("_"):
+                    raise AttributeError(
+                        f"module {self.__name__!r} has not been imported yet: "
+                        f"accessing attribute {name!r}. "
+                        f"Please import {self.__name__!r} explicitly first."
+                    )
+
+                # Lazy import on first member access
+                cxx = _import_cxx_pytree_and_store()
+
+            return getattr(cxx, name)
+
+        def __setattr__(self, name: str, value: _Any) -> None:
+            # Lazy import
+            cxx = _import_cxx_pytree_and_store()
+            return setattr(cxx, name, value)
+
+    # This allows the following statements to work properly:
+    #
+    #     import torch.utils.pytree.cxx
+    #     from torch.utils.pytree.cxx import tree_map
+    #
+    _sys.modules[f"{__name__}.cxx"] = LazyCxxModule(f"{__name__}.cxx")
 
-_sys.modules[f"{__name__}.cxx"] = _sys.modules.get("torch.utils._cxx_pytree")  # type: ignore[assignment]
+    del LazyCxxModule
 
 
 if not PYTORCH_USE_CXX_PYTREE:
@@ -103,8 +162,6 @@
         tree_unflatten,
         treespec_pprint,
     )
-
-    PyTreeSpec = _exposed_in(__name__)(PyTreeSpec)  # type: ignore[misc]
 else:
     from torch.utils._cxx_pytree import (  # type: ignore[assignment,no-redef]
         is_namedtuple,
@@ -132,41 +189,6 @@
     )
 
 
-# Change `__module__` of reexported public APIs to 'torch.utils.pytree'
-__func_names = frozenset(
-    {
-        "tree_all",
-        "tree_all_only",
-        "tree_any",
-        "tree_any_only",
-        "tree_flatten",
-        "tree_iter",
-        "tree_leaves",
-        "tree_map",
-        "tree_map_",
-        "tree_map_only",
-        "tree_map_only_",
-        "tree_structure",
-        "tree_unflatten",
-        "treespec_pprint",
-        "is_namedtuple",
-        "is_namedtuple_class",
-        "is_namedtuple_instance",
-        "is_structseq",
-        "is_structseq_class",
-        "is_structseq_instance",
-    }
-)
-globals().update(
-    {
-        name: _exposed_in(__name__)(member)
-        for name, member in globals().items()
-        if name in __func_names
-    }
-)
-del __func_names, _exposed_in
-
-
 def register_pytree_node(
     cls: type[_Any],
     /,
@@ -208,9 +230,6 @@ def register_pytree_node(
 def __getattr__(name: str) -> _Any:
     if name == "cxx":
         # Lazy import
-        import torch.utils._cxx_pytree as cxx  # noqa: F811
-
-        _sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx
-        return cxx
+        return _import_cxx_pytree_and_store()
 
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
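
For reference, a minimal usage sketch of the lazy `cxx` submodule wired up in this diff (not part of the change itself; it assumes `optree` is installed so the C++ implementation can actually load, and uses `tree_map` as in the in-code comments):

# `import torch.utils.pytree` alone does not pull in optree; the C++
# implementation is only imported once `.cxx` is actually touched.
import torch.utils.pytree as pytree

# Default path (PYTORCH_USE_CXX_PYTREE unset): pure-Python implementation.
doubled = pytree.tree_map(lambda x: x * 2, {"a": 1, "b": (2, 3)})

# Accessing `pytree.cxx` goes through the module-level __getattr__ (or the
# LazyCxxModule placed in sys.modules) and calls _import_cxx_pytree_and_store().
doubled_cxx = pytree.cxx.tree_map(lambda x: x * 2, {"a": 1, "b": (2, 3)})

# Submodule-style imports also resolve, because the lazy module is registered
# under "torch.utils.pytree.cxx" in sys.modules.
from torch.utils.pytree.cxx import tree_map

Either access path funnels through _import_cxx_pytree_and_store(), so the optree-backed module is imported at most once and is then cached both in the package's globals() and in sys.modules.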