|
1 |
| -# mypy: allow-untyped-defs |
2 |
| -from collections import namedtuple |
3 |
| -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type |
| 1 | +from typing import Any, Callable, List, Optional, Type |
| 2 | +from typing_extensions import deprecated |
4 | 3 |
|
5 |
| -import torch.return_types |
| 4 | +import torch.utils._pytree as python_pytree |
6 | 5 | from torch.utils._pytree import PyTree, TreeSpec
|
7 | 6 |
|
8 | 7 |
|
9 | 8 | FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]
|
10 | 9 | FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]
|
11 | 10 |
|
12 |
| -SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {} |
13 |
| -SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {} |
14 |
| - |
15 | 11 |
|
| 12 | +@deprecated( |
| 13 | + "torch.fx._pytree.register_pytree_flatten_spec is deprecated and it is now a no-op. " |
| 14 | + "Please register the class with `flatten_with_keys` function as pytree node instead.", |
| 15 | + category=FutureWarning, |
| 16 | +) |
16 | 17 | def register_pytree_flatten_spec(
|
17 | 18 | cls: Type[Any],
|
18 | 19 | flatten_fn_spec: FlattenFuncSpec,
|
19 | 20 | flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
|
20 | 21 | ) -> None:
|
21 |
| - SUPPORTED_NODES[cls] = flatten_fn_spec |
22 |
| - SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec |
| 22 | + # no-op, just check if the node is registered and has flatten_with_keys_fn |
| 23 | + handler = python_pytree.SUPPORTED_NODES.get(cls) |
| 24 | + if handler is None: |
| 25 | + raise ValueError( |
| 26 | + f"Unsupported node type {cls}, " |
| 27 | + "please consider registering it as pytree node first." |
| 28 | + ) |
| 29 | + if handler.flatten_with_keys_fn is None: |
| 30 | + raise ValueError( |
| 31 | + f"Unsupported node type {cls}, " |
| 32 | + "please consider registering the pytree node with `flatten_with_keys` function first." |
| 33 | + ) |
23 | 34 |
|
24 | 35 |
|
| 36 | +# The pytree may be wrapped with torch.fx.Proxy, so we cannot use `treespec.flatten_up_to(pytree)`. |
| 37 | +# Use the key path API to index into the pytree instead. |
25 | 38 | def tree_flatten_spec(
|
26 | 39 | pytree: PyTree,
|
27 | 40 | spec: TreeSpec,
|
28 |
| - exact_structural_match=False, |
| 41 | + exact_structural_match: bool = False, |
29 | 42 | ) -> List[Any]:
|
30 |
| - if spec.is_leaf(): |
31 |
| - return [pytree] |
32 |
| - if spec.type not in SUPPORTED_NODES: |
33 |
| - raise RuntimeError( |
34 |
| - f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with " |
35 |
| - "torch.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make " |
36 |
| - "sure that any custom pytrees have been registered before loading it.", |
37 |
| - ) |
38 |
| - flatten_fn_spec = SUPPORTED_NODES[spec.type] |
39 |
| - child_pytrees = flatten_fn_spec(pytree, spec) |
40 |
| - if exact_structural_match: |
41 |
| - flatten_fn_exact_match_spec = SUPPORTED_NODES_EXACT_MATCH[spec.type] |
42 |
| - if flatten_fn_exact_match_spec and not flatten_fn_exact_match_spec( |
43 |
| - pytree, |
44 |
| - spec, |
45 |
| - ): |
46 |
| - raise RuntimeError(f"Cannot flatten pytree {pytree}, given spec: {spec}") |
47 |
| - result = [] |
48 |
| - for child, child_spec in zip(child_pytrees, spec.children()): |
49 |
| - flat = tree_flatten_spec(child, child_spec, exact_structural_match) |
50 |
| - result += flat |
51 |
| - return result |
52 |
| - |
53 |
| - |
54 |
| -def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]: |
55 |
| - return [d[k] for k in spec.context] |
56 |
| - |
57 |
| - |
58 |
| -def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]: |
59 |
| - return [d[i] for i in range(spec.num_children)] |
60 |
| - |
61 |
| - |
62 |
| -def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]: |
63 |
| - return [d[i] for i in range(spec.num_children)] |
64 |
| - |
65 |
| - |
66 |
| -def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]: |
67 |
| - return [d[i] for i in range(spec.num_children)] |
68 |
| - |
69 |
| - |
70 |
| -def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool: |
71 |
| - return len(d) == spec.num_children |
72 |
| - |
73 |
| - |
74 |
| -def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool: |
75 |
| - return len(d) == spec.num_children |
76 |
| - |
77 |
| - |
78 |
| -def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool: |
79 |
| - return len(d) == spec.num_children |
80 |
| - |
81 |
| - |
82 |
| -def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool: |
83 |
| - return len(d) == spec.num_children |
84 |
| - |
85 |
| - |
86 |
| -register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match) |
87 |
| -register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match) |
88 |
| -register_pytree_flatten_spec( |
89 |
| - tuple, |
90 |
| - _tuple_flatten_spec, |
91 |
| - _tuple_flatten_spec_exact_match, |
92 |
| -) |
93 |
| -for return_type in torch.return_types.all_return_types: |
94 |
| - register_pytree_flatten_spec( |
95 |
| - return_type, |
96 |
| - _tuple_flatten_spec, |
97 |
| - _tuple_flatten_spec_exact_match, |
98 |
| - ) |
99 |
| -register_pytree_flatten_spec( |
100 |
| - namedtuple, # type: ignore[arg-type] |
101 |
| - _namedtuple_flatten_spec, |
102 |
| - _namedtuple_flatten_spec_exact_match, |
103 |
| -) |
| 43 | + if not isinstance(spec, TreeSpec): |
| 44 | + assert python_pytree._cxx_pytree_exists, "C++ PyTree is not available" |
| 45 | + |
| 46 | + from torch.utils._cxx_pytree import PyTreeSpec |
| 47 | + |
| 48 | + assert isinstance(spec, PyTreeSpec), "Expected a PyTreeSpec" |
| 49 | + return [accessor(pytree) for accessor in spec.accessors()] |
| 50 | + |
| 51 | + # FX `tracer.create_arg(x)` and Dynamo does not support `dummy_leaf = object()` |
| 52 | + # as a sentinel value. Use None here. |
| 53 | + dummy_leaf = None |
| 54 | + dummy_tree = python_pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec) |
| 55 | + return [ |
| 56 | + python_pytree.key_get(pytree, key_path) |
| 57 | + for key_path, _ in python_pytree.tree_leaves_with_path(dummy_tree) |
| 58 | + ] |
0 commit comments