|
4 | 4 |
|
5 | 5 | from __future__ import annotations
|
6 | 6 |
|
7 |
| -from typing import Any, Callable, Iterable, TYPE_CHECKING |
| 7 | +from dataclasses import dataclass, field |
| 8 | +from typing import Any, Callable, Iterable, Literal, TYPE_CHECKING |
| 9 | +from typing_extensions import TypeIs |
8 | 10 |
|
9 | 11 | import torch.utils._pytree as python_pytree
|
| 12 | +from torch.utils._pytree import BUILTIN_TYPES |
10 | 13 |
|
11 | 14 | from ..decorators import substitute_in_graph
|
12 | 15 |
|
13 | 16 |
|
14 | 17 | if TYPE_CHECKING:
|
| 18 | + import builtins |
| 19 | + from typing_extensions import Self |
| 20 | + |
15 | 21 | from torch.utils._cxx_pytree import PyTree
|
16 | 22 |
|
17 | 23 |
|
@@ -95,3 +101,197 @@ def tree_leaves(
|
95 | 101 | return list(tree_iter(tree, is_leaf=is_leaf))
|
96 | 102 |
|
97 | 103 | __all__ += ["tree_leaves"]
|
| 104 | + |
class _Asterisk(str):
    """The one-character string ``"*"`` whose ``repr`` carries no quotes.

    Placing an instance inside a container makes the container's own
    ``repr``/``str`` show a bare ``*`` instead of ``'*'``.
    """

    def __new__(cls) -> Self:
        # The value is fixed, so the constructor accepts no arguments.
        return str.__new__(cls, "*")

    def __repr__(self) -> str:
        # Plain text of the string, i.e. ``*`` with no quotes.
        return str(self)


# Only this single instance is kept; the class name is removed right after
# the instance has been created.
_asterisk = _Asterisk()
del _Asterisk
| 114 | + |
@dataclass(frozen=True)
class PyTreeSpec:
    """Analog for :class:`optree.PyTreeSpec` in Python.

    A spec is either a leaf (``_type is None``, no children, no metadata,
    no entries, no unflatten function) or an internal node that stores the
    specs of its children together with the metadata, entries and unflatten
    function needed to rebuild the node from its children.
    """

    _children: tuple[PyTreeSpec, ...]
    _type: builtins.type | None
    _metadata: Any
    _entries: tuple[Any, ...]
    _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None

    # Derived in ``__post_init__``; none of these is a constructor argument.
    num_nodes: int = field(init=False)
    num_leaves: int = field(init=False)
    num_children: int = field(init=False)
    none_is_leaf: Literal[True] = field(init=False)
    namespace: Literal["torch"] = field(init=False)

    def __post_init__(self) -> None:
        """Check the leaf/node invariants and fill in the derived counters."""
        if self._type is None:
            # Leaf spec: it must not carry any structural information.
            assert len(self._children) == 0
            assert self._metadata is None
            assert self._entries == ()
            assert self._unflatten_func is None
            num_nodes = 1
            num_leaves = 1
            num_children = 0
        else:
            assert callable(self._unflatten_func)
            # ``start=1`` counts this node in addition to all descendants.
            num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
            num_leaves = sum(spec.num_leaves for spec in self._children)
            num_children = len(self._children)

        # The dataclass is frozen, so plain attribute assignment would raise;
        # ``object.__setattr__`` bypasses the frozen ``__setattr__``.
        object.__setattr__(self, "num_nodes", num_nodes)
        object.__setattr__(self, "num_leaves", num_leaves)
        object.__setattr__(self, "num_children", num_children)
        object.__setattr__(self, "none_is_leaf", True)
        object.__setattr__(self, "namespace", "torch")

    def __repr__(self) -> str:
        """Return ``PyTreeSpec(<structure>, NoneIsLeaf, namespace='torch')``."""

        def helper(treespec: PyTreeSpec) -> Any:
            # Returns ``_asterisk`` for a leaf, a rebuilt container for
            # built-in / namedtuple / structseq nodes, and a ``str`` for
            # every other node type.
            if treespec.is_leaf():
                assert treespec.type is None
                return _asterisk

            assert treespec.type is not None
            assert callable(treespec._unflatten_func)
            children_representations = [
                helper(subspec) for subspec in treespec._children
            ]
            if (
                treespec.type in BUILTIN_TYPES
                or optree.is_namedtuple_class(treespec.type)
                or optree.is_structseq_class(treespec.type)
            ):
                # Rebuild the container around the child representations so
                # that its own ``str`` renders the structure.
                return treespec._unflatten_func(
                    treespec._metadata,
                    children_representations,
                )
            # A child representation may be a container (see the branch
            # above) rather than a ``str``; ``str.join`` only accepts strings,
            # so every child is stringified first.
            # NOTE(review): this returns a plain ``str``, so when it ends up
            # inside a rebuilt built-in container it is shown with quotes.
            children_text = ", ".join(str(rep) for rep in children_representations)
            return (
                f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], "
                f"[{children_text}])"
            )

        return (
            f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
        )

    def __len__(self) -> int:
        """Return the number of leaves covered by this spec."""
        return self.num_leaves

    @property
    def type(self) -> builtins.type | None:
        """The node type, or ``None`` for a leaf spec."""
        return self._type

    def is_leaf(self) -> bool:
        """Return whether this spec consists of exactly one node that is a leaf."""
        return self.num_nodes == 1 and self.num_leaves == 1

    def children(self) -> list[PyTreeSpec]:
        """Return a new list holding the specs of the direct children."""
        return list(self._children)

    def child(self, index: int) -> PyTreeSpec:
        """Return the spec of the direct child at ``index``."""
        return self._children[index]

    def entries(self) -> list[Any]:
        """Return a new list holding the stored entries."""
        return list(self._entries)

    def entry(self, index: int) -> Any:
        """Return the stored entry at ``index``."""
        return self._entries[index]

    def unflatten(self, leaves: Iterable[Any]) -> PyTree:
        """Rebuild a pytree with this structure from ``leaves``.

        Args:
            leaves: The leaf values; any iterable is accepted and is
                materialized into a list unless it already is a list or tuple.

        Returns:
            The reconstructed pytree (the single leaf itself for a leaf spec).

        Raises:
            ValueError: If the number of leaves differs from ``num_leaves``.
        """
        if not isinstance(leaves, (list, tuple)):
            # Length check and slicing below need a sequence.
            leaves = list(leaves)
        if len(leaves) != self.num_leaves:
            raise ValueError(
                f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
                f"but the spec refers to a pytree that holds {self.num_leaves} "
                f"items ({self}).",
            )
        if self.is_leaf():
            return leaves[0]

        # Recursively unflatten the children: each child consumes the next
        # ``num_leaves`` items of ``leaves``.
        start = 0
        end = 0
        subtrees = []
        for subspec in self._children:
            end += subspec.num_leaves
            subtrees.append(subspec.unflatten(leaves[start:end]))
            start = end

        assert callable(self._unflatten_func)
        return self._unflatten_func(self._metadata, subtrees)
| 226 | + |
# Spec describing a leaf: no children, no node type, no metadata, no entries
# and no unflatten function.
_LEAF_SPEC = PyTreeSpec(
    _children=(),
    _type=None,
    _metadata=None,
    _entries=(),
    _unflatten_func=None,
)
| 228 | + |
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
    """Return whether ``obj`` is a :class:`PyTreeSpec` instance.

    The ``TypeIs`` return annotation lets static type checkers narrow ``obj``
    in both branches of a conditional on this call.
    """
    is_spec = isinstance(obj, PyTreeSpec)
    return is_spec
| 231 | + |
@substitute_in_graph(  # type: ignore[arg-type]
    cxx_pytree.tree_flatten,
    # Constant folding is disabled on purpose: the function has to reference
    # the PyTreeSpec class defined in this file, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_flatten(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[Any], PyTreeSpec]:
    """Flatten ``tree`` into a list of leaves plus a spec of its structure.

    Args:
        tree: The pytree to flatten.
        is_leaf: Optional predicate forwarded to the leaf check and to the
            one-level flatten.

    Returns:
        A ``(leaves, treespec)`` pair; leaves are appended in the order the
        recursion visits them (children left to right, depth first).
    """

    def flatten_into(node: PyTree, collected: list[Any]) -> PyTreeSpec:
        # Leaves end the recursion and all share the same leaf spec.
        if tree_is_leaf(node, is_leaf=is_leaf):
            collected.append(node)
            return _LEAF_SPEC

        one_level = optree.tree_flatten_one_level(
            node,
            is_leaf=is_leaf,
            none_is_leaf=True,
            namespace="torch",
        )
        children, metadata, entries, unflatten_func = one_level

        # Recurse into the children in order, gathering their specs.
        child_specs = []
        for child in children:
            child_specs.append(flatten_into(child, collected))
        return PyTreeSpec(tuple(child_specs), type(node), metadata, entries, unflatten_func)  # type: ignore[arg-type]

    flat: list[Any] = []
    spec = flatten_into(tree, flat)
    return flat, spec


__all__ += ["tree_flatten"]
| 268 | + |
@substitute_in_graph(  # type: ignore[arg-type]
    cxx_pytree.tree_structure,
    # Constant folding is disabled on purpose: the function has to reference
    # the PyTreeSpec class defined in this file, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_structure(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTreeSpec:
    """Return only the spec that ``tree_flatten`` produces for ``tree``.

    Args:
        tree: The pytree whose structure is wanted.
        is_leaf: Optional predicate forwarded to ``tree_flatten``.
    """
    _leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    return treespec  # type: ignore[return-value]


__all__ += ["tree_structure"]
| 282 | + |
@substitute_in_graph(  # type: ignore[arg-type]
    cxx_pytree.tree_unflatten,
    # Constant folding is disabled on purpose: the function has to reference
    # the PyTreeSpec class defined in this file, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
    """Rebuild a pytree from ``leaves`` using the structure in ``treespec``.

    Args:
        leaves: The leaf values, passed on to ``treespec.unflatten``.
        treespec: The spec describing the structure to rebuild.

    Raises:
        TypeError: If ``treespec`` is not a :class:`PyTreeSpec` instance.
    """
    if _is_pytreespec_instance(treespec):
        return treespec.unflatten(leaves)
    raise TypeError(
        f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
        f"PyTreeSpec but got item of type {type(treespec)}."
    )


__all__ += ["tree_unflatten"]
0 commit comments