|
1 |
| -from collections import namedtuple |
2 |
| -from typing import Any, Callable, Optional, TypeVar |
3 |
| -from typing_extensions import NamedTuple |
| 1 | +from typing import Any, Callable, Optional |
| 2 | +from typing_extensions import deprecated |
4 | 3 |
|
5 |
| -import torch.return_types |
| 4 | +import torch.utils._pytree as python_pytree |
6 | 5 | from torch.utils._pytree import PyTree, TreeSpec
|
7 | 6 |
|
8 | 7 |
|
9 | 8 | FlattenFuncSpec = Callable[[PyTree, TreeSpec], list]
|
10 | 9 | FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]
|
11 | 10 |
|
12 |
| -SUPPORTED_NODES: dict[type[Any], FlattenFuncSpec] = {} |
13 |
| -SUPPORTED_NODES_EXACT_MATCH: dict[type[Any], Optional[FlattenFuncExactMatchSpec]] = {} |
14 |
| - |
15 |
| -_T = TypeVar("_T") |
16 |
| -_K = TypeVar("_K") |
17 |
| -_V = TypeVar("_V") |
18 |
| - |
19 | 11 |
|
| 12 | +@deprecated( |
| 13 | + "torch.fx._pytree.register_pytree_flatten_spec is deprecated and it is now a no-op. " |
| 14 | + "Please register the class with `flatten_with_keys` function as pytree node instead.", |
| 15 | + category=FutureWarning, |
| 16 | +) |
20 | 17 | def register_pytree_flatten_spec(
|
21 | 18 | cls: type[Any],
|
22 | 19 | flatten_fn_spec: FlattenFuncSpec,
|
23 | 20 | flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
|
24 | 21 | ) -> None:
|
25 |
| - SUPPORTED_NODES[cls] = flatten_fn_spec |
26 |
| - SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec |
27 |
| - |
28 |
| - |
29 |
| -def _deregister_pytree_flatten_spec( |
30 |
| - cls: type[Any], |
31 |
| -) -> None: |
32 |
| - del SUPPORTED_NODES[cls] |
33 |
| - del SUPPORTED_NODES_EXACT_MATCH[cls] |
| 22 | + # no-op, just check if the node is registered and has flatten_with_keys_fn |
| 23 | + handler = python_pytree.SUPPORTED_NODES.get(cls) |
| 24 | + if handler is None: |
| 25 | + raise ValueError( |
| 26 | + f"Unsupported node type {cls}, " |
| 27 | + "please consider registering it as pytree node first." |
| 28 | + ) |
| 29 | + if handler.flatten_with_keys_fn is None: |
| 30 | + raise ValueError( |
| 31 | + f"Unsupported node type {cls}, " |
| 32 | + "please consider registering the pytree node with `flatten_with_keys` function first." |
| 33 | + ) |
34 | 34 |
|
35 | 35 |
|
| 36 | +# The pytree may be wrapped with torch.fx.Proxy, so we cannot use `treespec.flatten_up_to(pytree)`. |
| 37 | +# Use the key path API to index into the pytree instead. |
36 | 38 | def tree_flatten_spec(
|
37 | 39 | pytree: PyTree,
|
38 | 40 | spec: TreeSpec,
|
39 | 41 | exact_structural_match: bool = False,
|
40 | 42 | ) -> list[Any]:
|
41 |
| - if spec.is_leaf(): |
42 |
| - return [pytree] |
43 |
| - if spec.type not in SUPPORTED_NODES: |
44 |
| - raise RuntimeError( |
45 |
| - f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with " |
46 |
| - "torch.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make " |
47 |
| - "sure that any custom pytrees have been registered before loading it.", |
48 |
| - ) |
49 |
| - flatten_fn_spec = SUPPORTED_NODES[spec.type] |
50 |
| - child_pytrees = flatten_fn_spec(pytree, spec) |
51 |
| - if exact_structural_match: |
52 |
| - flatten_fn_exact_match_spec = SUPPORTED_NODES_EXACT_MATCH[spec.type] |
53 |
| - if flatten_fn_exact_match_spec and not flatten_fn_exact_match_spec( |
54 |
| - pytree, |
55 |
| - spec, |
56 |
| - ): |
57 |
| - raise RuntimeError(f"Cannot flatten pytree {pytree}, given spec: {spec}") |
58 |
| - result = [] |
59 |
| - for child, child_spec in zip(child_pytrees, spec.children()): |
60 |
| - flat = tree_flatten_spec(child, child_spec, exact_structural_match) |
61 |
| - result += flat |
62 |
| - return result |
63 |
| - |
64 |
| - |
65 |
| -def _dict_flatten_spec(d: dict[_K, _V], spec: TreeSpec) -> list[_V]: |
66 |
| - return [d[k] for k in spec.context] |
67 |
| - |
68 |
| - |
69 |
| -def _list_flatten_spec(d: list[_T], spec: TreeSpec) -> list[_T]: |
70 |
| - return [d[i] for i in range(spec.num_children)] |
71 |
| - |
72 |
| - |
73 |
| -def _tuple_flatten_spec(d: tuple[_T, ...], spec: TreeSpec) -> list[_T]: |
74 |
| - return [d[i] for i in range(spec.num_children)] |
75 |
| - |
76 |
| - |
77 |
| -def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> list[Any]: |
78 |
| - return [d[i] for i in range(spec.num_children)] |
79 |
| - |
80 |
| - |
81 |
| -def _dict_flatten_spec_exact_match(d: dict[_K, _V], spec: TreeSpec) -> bool: |
82 |
| - return len(d) == spec.num_children |
83 |
| - |
84 |
| - |
85 |
| -def _list_flatten_spec_exact_match(d: list[_T], spec: TreeSpec) -> bool: |
86 |
| - return len(d) == spec.num_children |
87 |
| - |
88 |
| - |
89 |
| -def _tuple_flatten_spec_exact_match(d: tuple[_T, ...], spec: TreeSpec) -> bool: |
90 |
| - return len(d) == spec.num_children |
91 |
| - |
92 |
| - |
93 |
| -def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool: |
94 |
| - return len(d) == spec.num_children |
95 |
| - |
96 |
| - |
97 |
| -register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match) |
98 |
| -register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match) |
99 |
| -register_pytree_flatten_spec( |
100 |
| - tuple, |
101 |
| - _tuple_flatten_spec, |
102 |
| - _tuple_flatten_spec_exact_match, |
103 |
| -) |
104 |
| -for return_type in torch.return_types.all_return_types: |
105 |
| - register_pytree_flatten_spec( |
106 |
| - return_type, |
107 |
| - _tuple_flatten_spec, |
108 |
| - _tuple_flatten_spec_exact_match, |
109 |
| - ) |
110 |
| -register_pytree_flatten_spec( |
111 |
| - namedtuple, # type: ignore[arg-type] |
112 |
| - _namedtuple_flatten_spec, |
113 |
| - _namedtuple_flatten_spec_exact_match, |
114 |
| -) |
| 43 | + if not isinstance(spec, TreeSpec): |
| 44 | + assert python_pytree._cxx_pytree_exists, "C++ PyTree is not available" |
| 45 | + |
| 46 | + from torch.utils._cxx_pytree import PyTreeSpec |
| 47 | + |
| 48 | + assert isinstance(spec, PyTreeSpec), "Expected a PyTreeSpec" |
| 49 | + return [accessor(pytree) for accessor in spec.accessors()] |
| 50 | + |
| 51 | + # FX `tracer.create_arg(x)` and Dynamo does not support `dummy_leaf = object()` |
| 52 | + # as a sentinel value. Use None here. |
| 53 | + dummy_leaf = None |
| 54 | + dummy_tree = python_pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec) |
| 55 | + return [ |
| 56 | + python_pytree.key_get(pytree, key_path) |
| 57 | + for key_path, _ in python_pytree.tree_leaves_with_path(dummy_tree) |
| 58 | + ] |
0 commit comments