|
4 | 4 |
|
5 | 5 | from __future__ import annotations
|
6 | 6 |
|
| 7 | +from dataclasses import dataclass, field |
7 | 8 | from typing import Any, Callable, Iterable, Literal, TYPE_CHECKING
|
8 | 9 |
|
9 | 10 | import torch.utils._pytree as python_pytree
|
|
12 | 13 |
|
13 | 14 |
|
14 | 15 | if TYPE_CHECKING:
|
| 16 | + import builtins |
| 17 | + |
15 | 18 | from torch.utils._cxx_pytree import PyTree
|
16 | 19 |
|
17 | 20 |
|
@@ -69,3 +72,132 @@ def tree_leaves(
|
69 | 72 | return list(tree_iter(tree, is_leaf=is_leaf))
|
70 | 73 |
|
71 | 74 | __all__ += ["tree_leaves"]
|
| 75 | + |
| 76 | + @dataclass(frozen=True) |
| 77 | + class PyTreeSpec: |
| 78 | + _children: list[PyTreeSpec] |
| 79 | + _type: builtins.type | None |
| 80 | + _metadata: Any |
| 81 | + _entries: tuple[Any] | None |
| 82 | + _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None |
| 83 | + |
| 84 | + num_nodes: int = field(init=False) |
| 85 | + num_leaves: int = field(init=False) |
| 86 | + num_children: int = field(init=False) |
| 87 | + none_is_leaf: bool = field(init=False) |
| 88 | + namespace: str = field(init=False) |
| 89 | + |
| 90 | + def __post_init__(self) -> None: |
| 91 | + if self._type is None: |
| 92 | + assert len(self._children) == 0 |
| 93 | + assert self._metadata is None |
| 94 | + assert self._entries is None |
| 95 | + assert self._unflatten_func is None |
| 96 | + object.__setattr__(self, "num_nodes", 1) |
| 97 | + object.__setattr__(self, "num_leaves", 1) |
| 98 | + object.__setattr__(self, "num_children", 0) |
| 99 | + else: |
| 100 | + assert callable(self._unflatten_func) |
| 101 | + num_nodes = sum((spec.num_nodes for spec in self._children), start=1) |
| 102 | + num_leaves = sum(spec.num_leaves for spec in self._children) |
| 103 | + num_children = len(self._children) |
| 104 | + object.__setattr__(self, "num_nodes", num_nodes) |
| 105 | + object.__setattr__(self, "num_leaves", num_leaves) |
| 106 | + object.__setattr__(self, "num_children", num_children) |
| 107 | + |
| 108 | + object.__setattr__(self, "none_is_leaf", True) |
| 109 | + object.__setattr__(self, "namespace", "torch") |
| 110 | + |
| 111 | + @property |
| 112 | + def type(self) -> builtins.type | None: |
| 113 | + return self._type |
| 114 | + |
| 115 | + def is_leaf(self) -> bool: |
| 116 | + return self.num_nodes == 1 and self.num_leaves == 1 |
| 117 | + |
| 118 | + def children(self) -> list[PyTreeSpec]: |
| 119 | + return self._children.copy() |
| 120 | + |
| 121 | + def child(self, index: int) -> PyTreeSpec: |
| 122 | + return self._children[index] |
| 123 | + |
| 124 | + def entries(self) -> list[Any]: |
| 125 | + if self._entries is None: |
| 126 | + return list(range(self.num_children)) |
| 127
F438
| + return list(self._entries) |
| 128 | + |
| 129 | + def entry(self, index: int) -> Any: |
| 130 | + return self.entries()[index] |
| 131 | + |
| 132 | + def unflatten(self, leaves: Iterable[Any]) -> PyTree: |
| 133 | + if not isinstance(leaves, (list, tuple)): |
| 134 | + leaves = list(leaves) |
| 135 | + if len(leaves) != self.num_leaves: |
| 136 | + raise ValueError( |
| 137 | + f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " |
| 138 | + f"but the spec refers to a pytree that holds {self.num_leaves} " |
| 139 | + f"items ({self}).", |
| 140 | + ) |
| 141 | + if self.is_leaf(): |
| 142 | + return leaves[0] |
| 143 | + |
| 144 | + # Recursively unflatten the children |
| 145 | + start = 0 |
| 146 | + end = 0 |
| 147 | + subtrees = [] |
| 148 | + for subspec in self._children: |
| 149 | + end += subspec.num_leaves |
| 150 | + subtrees.append(subspec.unflatten(leaves[start:end])) |
| 151 | + start = end |
| 152 | + |
| 153 | + assert callable(self._unflatten_func) |
| 154 | + return self._unflatten_func(self._metadata, subtrees) |
| 155 | + |
| 156 | + leafspec = PyTreeSpec([], None, None, None, None) |
| 157 | + |
| 158 | + @substitute_in_graph(cxx_pytree.tree_flatten, can_constant_fold_through=False) # type: ignore[arg-type] |
| 159 | + def tree_flatten( |
| 160 | + tree: PyTree, |
| 161 | + is_leaf: Callable[[PyTree], bool] | None = None, |
| 162 | + ) -> tuple[list[Any], PyTreeSpec]: |
| 163 | + def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: |
| 164 | + if node is None or (is_leaf is not None and is_leaf(node)): |
| 165 | + leaves.append(node) |
| 166 | + return leafspec |
| 167 | + |
| 168 | + node_type = type(node) |
| 169 | + if optree.register_pytree_node.get(node_type, namespace="torch") is None: # type: ignore[attr-defined] |
| 170 | + leaves.append(node) |
| 171 | + return leafspec |
| 172 | + |
| 173 | + ( |
| 174 | + children, |
| 175 | + metadata, |
| 176 | + entries, |
| 177 | + unflatten_func, |
| 178 | + ) = optree.tree_flatten_one_level( |
| 179 | + node, |
| 180 | + is_leaf=is_leaf, |
| 181 | + none_is_leaf=True, |
| 182 | + namespace="torch", |
| 183 | + ) |
| 184 | + |
| 185 | + subspecs = [helper(child, leaves) for child in children] |
| 186 | + return PyTreeSpec(subspecs, node_type, metadata, entries, unflatten_func) # type: ignore[arg-type] |
| 187 | + |
| 188 | + leaves: list[Any] = [] |
| 189 | + treespec = helper(tree, leaves) |
| 190 | + return leaves, treespec |
| 191 | + |
| 192 | + __all__ += ["tree_flatten"] |
| 193 | + |
| 194 | + @substitute_in_graph(cxx_pytree.tree_unflatten, can_constant_fold_through=False) # type: ignore[arg-type] |
| 195 | + def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree: |
| 196 | + if not isinstance(treespec, PyTreeSpec): |
| 197 | + raise TypeError( |
| 198 | + f"tree_unflatten(values, spec): Expected `spec` to be instance of " |
| 199 | + f"TreeSpec but got item of type {type(treespec)}." |
| 200 | + ) |
| 201 | + return treespec.unflatten(leaves) |
| 202 | + |
| 203 | + __all__ += ["tree_unflatten"] |
0 commit comments