8000 [POC][FX][pytree] cleanup fx pytree implementation · pytorch/pytorch@0ab8666 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0ab8666

Browse files
committed
[POC][FX][pytree] cleanup fx pytree implementation
ghstack-source-id: fbbcb4a Pull Request resolved: #138202
1 parent 729b1a4 commit 0ab8666

File tree

5 files changed

+55
-98
lines changed

5 files changed

+55
-98
lines changed

test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_lu_cpu_float32

Whitespace-only changes.

test/test_fx.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from torch.testing._internal.common_methods_invocations import op_db
2727
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
2828
import torch.utils._pytree as pytree
29-
import torch.fx._pytree as fx_pytree
3029
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
3130
from torch.fx.node import Target, Argument, _format_arg
3231
from torch.fx.passes import shape_prop
@@ -3605,8 +3604,11 @@ def f_namedtuple_add(x):
36053604
Foo,
36063605
lambda x: ([x.a, x.b], None),
36073606
lambda x, _: Foo(x[0], x[1]),
3607+
flatten_with_keys_fn=lambda x: (
3608+
((pytree.GetAttrKey("a"), x.a), (pytree.GetAttrKey("b"), x.b)),
3609+
None,
3610+
),
36083611
)
3609-
fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b])
36103612

36113613
def f_custom(x):
36123614
return x.a + x.b

test/test_proxy_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2051,7 +2051,6 @@ def f(t):
20512051
xfail('gather', ''),
20522052
xfail('linalg.pinv', ''),
20532053
xfail('linalg.pinv', 'hermitian'),
2054-
xfail('lu', ''),
20552054
xfail('scatter_add', ''),
20562055
xfail('scatter', ''),
20572056
xfail('take_along_dim', ''),

torch/fx/_pytree.py

Lines changed: 39 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,58 @@
1-
# mypy: allow-untyped-defs
2-
from collections import namedtuple
3-
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type
1+
from typing import Any, Callable, List, Optional, Type
2+
from typing_extensions import deprecated
43

5-
import torch.return_types
4+
import torch.utils._pytree as python_pytree
65
from torch.utils._pytree import PyTree, TreeSpec
76

87

98
FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]
109
FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]
1110

12-
SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {}
13-
SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
14-
1511

12+
@deprecated(
13+
"torch.fx._pytree.register_pytree_flatten_spec is deprecated and it is now a no-op. "
14+
"Please register the class with `flatten_with_keys` function as pytree node instead.",
15+
category=FutureWarning,
16+
)
1617
def register_pytree_flatten_spec(
1718
cls: Type[Any],
1819
flatten_fn_spec: FlattenFuncSpec,
1920
flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
2021
) -> None:
21-
SUPPORTED_NODES[cls] = flatten_fn_spec
22-
SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec
22+
# no-op, just check if the node is registered and has flatten_with_keys_fn
23+
handler = python_pytree.SUPPORTED_NODES.get(cls)
24+
if handler is None:
25+
raise ValueError(
26+
f"Unsupported node type {cls}, "
27+
"please consider registering it as pytree node first."
28+
)
29+
if handler.flatten_with_keys_fn is None:
30+
raise ValueError(
31+
f"Unsupported node type {cls}, "
32+
"please consider registering the pytree node with `flatten_with_keys` function first."
33+
)
2334

2435

36+
# The pytree may be wrapped with torch.fx.Proxy, so we cannot use `treespec.flatten_up_to(pytree)`.
37+
# Use the key path API to index into the pytree instead.
2538
def tree_flatten_spec(
2639
pytree: PyTree,
2740
spec: TreeSpec,
28-
exact_structural_match=False,
41+
exact_structural_match: bool = False,
2942
) -> List[Any]:
30-
if spec.is_leaf():
31-
return [pytree]
32-
if spec.type not in SUPPORTED_NODES:
33-
raise RuntimeError(
34-
f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with "
35-
"torch.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make "
36-
"sure that any custom pytrees have been registered before loading it.",
37-
)
38-
flatten_fn_spec = SUPPORTED_NODES[spec.type]
39-
child_pytrees = flatten_fn_spec(pytree, spec)
40-
if exact_structural_match:
41-
flatten_fn_exact_match_spec = SUPPORTED_NODES_EXACT_MATCH[spec.type]
42-
if flatten_fn_exact_match_spec and not flatten_fn_exact_match_spec(
43-
pytree,
44-
spec,
45-
):
46-
raise RuntimeError(f"Cannot flatten pytree {pytree}, given spec: {spec}")
47-
result = []
48-
for child, child_spec in zip(child_pytrees, spec.children()):
49-
flat = tree_flatten_spec(child, child_spec, exact_structural_match)
50-
result += flat
51-
return result
52-
53-
54-
def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]:
55-
return [d[k] for k in spec.context]
56-
57-
58-
def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]:
59-
return [d[i] for i in range(spec.num_children)]
60-
61-
62-
def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]:
63-
return [d[i] for i in range(spec.num_children)]
64-
65-
66-
def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]:
67-
return [d[i] for i in range(spec.num_children)]
68-
69-
70-
def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool:
71-
return len(d) == spec.num_children
72-
73-
74-
def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool:
75-
return len(d) == spec.num_children
76-
77-
78-
def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool:
79-
return len(d) == spec.num_children
80-
81-
82-
def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool:
83-
return len(d) == spec.num_children
84-
85-
86-
register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match)
87-
register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match)
88-
register_pytree_flatten_spec(
89-
tuple,
90-
_tuple_flatten_spec,
91-
_tuple_flatten_spec_exact_match,
92-
)
93-
for return_type in torch.return_types.all_return_types:
94-
register_pytree_flatten_spec(
95-
return_type,
96-
_tuple_flatten_spec,
97-
_tuple_flatten_spec_exact_match,
98-
)
99-
register_pytree_flatten_spec(
100-
namedtuple, # type: ignore[arg-type]
101-
_namedtuple_flatten_spec,
102-
_namedtuple_flatten_spec_exact_match,
103-
)
43+
if not isinstance(spec, TreeSpec):
44+
assert python_pytree._cxx_pytree_exists, "C++ PyTree is not available"
45+
46+
from torch.utils._cxx_pytree import PyTreeSpec
47+
48+
assert isinstance(spec, PyTreeSpec), "Expected a PyTreeSpec"
49+
return [accessor(pytree) for accessor in spec.accessors()]
50+
51+
# FX `tracer.create_arg(x)` and Dynamo does not support `dummy_leaf = object()`
52+
# as a sentinel value. Use None here.
53+
dummy_leaf = None
54+
dummy_tree = python_pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec)
55+
return [
56+
python_pytree.key_get(pytree, key_path)
57+
for key_path, _ in python_pytree.tree_leaves_with_path(dummy_tree)
58+
]

torch/fx/proxy.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -289,20 +289,23 @@ def create_arg(self, a: Any) -> Argument:
289289
"""
290290
if isinstance(a, Proxy):
291291
return a.node # most common arg type goes first
292-
elif hasattr(a, "__fx_create_arg__"):
292+
if hasattr(a, "__fx_create_arg__"):
293293
return a.__fx_create_arg__(self)
294+
if a in (None, ...) or isinstance(a, (*base_types, enum.Enum)):
295+
return a
296+
294297
# aggregates
295-
elif isinstance(a, tuple):
298+
if isinstance(a, tuple):
296299
if hasattr(a, "_fields"):
297300
# NamedTuple constructors don't seem to like getting a generator
298301
# expression as an argument to their constructor, so build this
299302
# intermediate tuple and unpack it into the NamedTuple constructor
300303
args = [self.create_arg(elem) for elem in a]
301304
return type(a)(*args) # type: ignore[arg-type]
302305
return type(a)([self.create_arg(elem) for elem in a])
303-
elif isinstance(a, list):
306+
if isinstance(a, list):
304307
return [self.create_arg(elem) for elem in a]
305-
elif isinstance(a, dict):
308+
if isinstance(a, dict):
306309

307310
def no_node(arg):
308311
if isinstance(arg, Node):
@@ -321,33 +324,31 @@ def no_node(arg):
321324

322325
r[k] = self.create_arg(v)
323326
return r
324-
elif isinstance(a, slice):
327+
328+
if isinstance(a, slice):
325329
return slice(
326330
self.create_arg(a.start),
327331
self.create_arg(a.stop),
328332
self.create_arg(a.step),
329333
)
330334

331-
elif isinstance(a, range):
335+
if isinstance(a, range):
332336
return range(
333337
self.create_arg(a.start),
334338
self.create_arg(a.stop),
335339
self.create_arg(a.step),
336340
)
337341

338-
elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
342+
if isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
339343
return a
340344

341-
elif is_dataclass(a):
345+
if is_dataclass(a):
342346
kwargs = {
343347
field.name: self.create_arg(getattr(a, field.name))
344348
for field in fields(a)
345349
}
346350
return self.create_node("call_function", a.__class__, (), kwargs)
347351

348-
elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
349-
return a
350-
351352
raise NotImplementedError(f"argument of type: {type(a)}")
352353

353354
@compatibility(is_backward_compatible=True)

0 commit comments

Comments
 (0)
0