8000 [POC][FX][pytree] cleanup fx pytree implementation · pytorch/pytorch@e6675fb · GitHub
[go: up one dir, main page]

Skip to content

Commit e6675fb

Browse files
committed
[POC][FX][pytree] cleanup fx pytree implementation
ghstack-source-id: 5c10326 Pull Request resolved: #138202
1 parent db4adcd commit e6675fb

File tree

6 files changed

+54
-123
lines changed

6 files changed

+54
-123
lines changed

test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_lu_cpu_float32

Whitespace-only changes.

test/test_fx.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from torch.testing._internal.common_methods_invocations import op_db
2828
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
2929
import torch.utils._pytree as pytree
30-
import torch.fx._pytree as fx_pytree
3130
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
3231
from torch.fx.node import Target, Argument, _format_arg
3332
from torch.fx.passes import shape_prop
@@ -3623,8 +3622,11 @@ def f_namedtuple_add(x):
36233622
Foo,
36243623
lambda x: ([x.a, x.b], None),
36253624
lambda x, _: Foo(x[0], x[1]),
3625+
flatten_with_keys_fn=lambda x: (
3626+
((pytree.GetAttrKey("a"), x.a), (pytree.GetAttrKey("b"), x.b)),
3627+
None,
3628+
),
36263629
)
3627-
fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b])
36283630

36293631
def f_custom(x):
36303632
return x.a + x.b

test/test_proxy_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2042,7 +2042,6 @@ def f(t):
20422042
xfail('gather', ''),
20432043
xfail('linalg.pinv', ''),
20442044
xfail('linalg.pinv', 'hermitian'),
2045-
xfail('lu', ''),
20462045
xfail('scatter_add', ''),
20472046
xfail('scatter', ''),
20482047
xfail('take_along_dim', ''),

torch/_export/utils.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@
2828
from torch.export.graph_signature import ExportGraphSignature
2929

3030
from torch.export.graph_signature import CustomObjArgument, InputKind, OutputKind
31-
from torch.fx._pytree import (
32-
_deregister_pytree_flatten_spec,
33-
register_pytree_flatten_spec,
34-
)
3531
from torch.utils._pytree import (
3632
_deregister_pytree_node,
3733
_register_pytree_node,
@@ -1290,17 +1286,6 @@ def from_dumpable_context(dumpable):
12901286
from_dumpable_context=from_dumpable_context,
12911287
)
12921288

1293-
def default_flatten_fn_spec(obj, spec) -> list[Any]:
1294-
flats, context = flatten_fn(obj)
1295-
assert context == spec.context
1296-
return flats
1297-
1298-
register_pytree_flatten_spec(
1299-
cls,
1300-
default_flatten_fn_spec,
1301-
)
1302-
13031289

13041290
def deregister_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
13051291
_deregister_pytree_node(cls)
1306-
_deregister_pytree_flatten_spec(cls)

torch/fx/_pytree.py

Lines changed: 38 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,114 +1,58 @@
1-
from collections import namedtuple
2-
from typing import Any, Callable, Optional, TypeVar
3-
from typing_extensions import NamedTuple
1+
from typing import Any, Callable, Optional
2+
from typing_extensions import deprecated
43

5-
import torch.return_types
4+
import torch.utils._pytree as python_pytree
65
from torch.utils._pytree import PyTree, TreeSpec
76

87

98
FlattenFuncSpec = Callable[[PyTree, TreeSpec], list]
109
FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]
1110

12-
SUPPORTED_NODES: dict[type[Any], FlattenFuncSpec] = {}
13-
SUPPORTED_NODES_EXACT_MATCH: dict[type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
14-
15-
_T = TypeVar("_T")
16-
_K = TypeVar("_K")
17-
_V = TypeVar("_V")
18-
1911

12+
@deprecated(
13+
"torch.fx._pytree.register_pytree_flatten_spec is deprecated and it is now a no-op. "
14+
F438 "Please register the class with `flatten_with_keys` function as pytree node instead.",
15+
category=FutureWarning,
16+
)
2017
def register_pytree_flatten_spec(
2118
cls: type[Any],
2219
flatten_fn_spec: FlattenFuncSpec,
2320
flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
2421
) -> None:
25-
SUPPORTED_NODES[cls] = flatten_fn_spec
26-
SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec
27-
28-
29-
def _deregister_pytree_flatten_spec(
30-
cls: type[Any],
31-
) -> None:
32-
del SUPPORTED_NODES[cls]
33-
del SUPPORTED_NODES_EXACT_MATCH[cls]
22+
# no-op, just check if the node is registered and has flatten_with_keys_fn
23+
handler = python_pytree.SUPPORTED_NODES.get(cls)
24+
if handler is None:
25+
raise ValueError(
26+
f"Unsupported node type {cls}, "
27+
"please consider registering it as pytree node first."
28+
)
29+
if handler.flatten_with_keys_fn is None:
30+
raise ValueError(
31+
f"Unsupported node type {cls}, "
32+
"please consider registering the pytree node with `flatten_with_keys` function first."
33+
)
3434

3535

36+
# The pytree may be wrapped with torch.fx.Proxy, so we cannot use `treespec.flatten_up_to(pytree)`.
37+
# Use the key path API to index into the pytree instead.
3638
def tree_flatten_spec(
3739
pytree: PyTree,
3840
spec: TreeSpec,
3941
exact_structural_match: bool = False,
4042
) -> list[Any]:
41-
if spec.is_leaf():
42-
return [pytree]
43-
if spec.type not in SUPPORTED_NODES:
44-
raise RuntimeError(
45-
f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with "
46-
"torch.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make "
47-
"sure that any custom pytrees have been registered before loading it.",
48-
)
49-
flatten_fn_spec = SUPPORTED_NODES[spec.type]
50-
child_pytrees = flatten_fn_spec(pytree, spec)
51-
if exact_structural_match:
52-
flatten_fn_exact_match_spec = SUPPORTED_NODES_EXACT_MATCH[spec.type]
53-
if flatten_fn_exact_match_spec and not flatten_fn_exact_match_spec(
54-
pytree,
55-
spec,
56-
):
57-
raise RuntimeError(f"Cannot flatten pytree {pytree}, given spec: {spec}")
58-
result = []
59-
for child, child_spec in zip(child_pytrees, spec.children()):
60-
flat = tree_flatten_spec(child, child_spec, exact_structural_match)
61-
result += flat
62-
return result
63-
64-
65-
def _dict_flatten_spec(d: dict[_K, _V], spec: TreeSpec) -> list[_V]:
66-
return [d[k] for k in spec.context]
67-
68-
69-
def _list_flatten_spec(d: list[_T], spec: TreeSpec) -> list[_T]:
70-
return [d[i] for i in range(spec.num_children)]
71-
72-
73-
def _tuple_flatten_spec(d: tuple[_T, ...], spec: TreeSpec) -> list[_T]:
74-
return [d[i] for i in range(spec.num_children)]
75-
76-
77-
def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> list[Any]:
78-
return [d[i] for i in range(spec.num_children)]
79-
80-
81-
def _dict_flatten_spec_exact_match(d: dict[_K, _V], spec: TreeSpec) -> bool:
82-
return len(d) == spec.num_children
83-
84-
85-
def _list_flatten_spec_exact_match(d: list[_T], spec: TreeSpec) -> bool:
86-
return len(d) == spec.num_children
87-
88-
89-
def _tuple_flatten_spec_exact_match(d: tuple[_T, ...], spec: TreeSpec) -> bool:
90-
return len(d) == spec.num_children
91-
92-
93-
def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool:
94-
return len(d) == spec.num_children
95-
96-
97-
register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match)
98-
register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match)
99-
register_pytree_flatten_spec(
100-
tuple,
101-
_tuple_flatten_spec,
102-
_tuple_flatten_spec_exact_match,
103-
)
104-
for return_type in torch.return_types.all_return_types:
105-
register_pytree_flatten_spec(
106-
return_type,
107-
_tuple_flatten_spec,
108-
_tuple_flatten_spec_exact_match,
109-
)
110-
register_pytree_flatten_spec(
111-
namedtuple, # type: ignore[arg-type]
112-
_namedtuple_flatten_spec,
113-
_namedtuple_flatten_spec_exact_match,
114-
)
43+
if not isinstance(spec, TreeSpec):
44+
assert python_pytree._cxx_pytree_exists, "C++ PyTree is not available"
45+
46+
from torch.utils._cxx_pytree import PyTreeSpec
47+
48+
assert isinstance(spec, PyTreeSpec), "Expected a PyTreeSpec"
49+
return [accessor(pytree) for accessor in spec.accessors()]
50+
51+
# FX `tracer.create_arg(x)` and Dynamo does not support `dummy_leaf = object()`
52+
# as a sentinel value. Use None here.
53+
dummy_leaf = None
54+
dummy_tree = python_pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec)
55+
return [
56+
python_pytree.key_get(pytree, key_path)
57+
for key_path, _ in python_pytree.tree_leaves_with_path(dummy_tree)
58+
]

torch/fx/proxy.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -291,20 +291,23 @@ def create_arg(self, a: Any) -> Argument:
291291
"""
292292
if isinstance(a, Proxy):
293293
return a.node # most common arg type goes first
294-
elif hasattr(a, "__fx_create_arg__"):
294+
if hasattr(a, "__fx_create_arg__"):
295295
return a.__fx_create_arg__(self)
296+
if a in (None, ...) or isinstance(a, (*base_types, enum.Enum)):
297+
return a
298+
296299
# aggregates
297-
elif isinstance(a, tuple):
300+
if isinstance(a, tuple):
298301
if hasattr(a, "_fields"):
299302
# NamedTuple constructors don't seem to like getting a generator
300303
# expression as an argument to their constructor, so build this
301304
# intermediate tuple and unpack it into the NamedTuple constructor
302305
args = [self.create_arg(elem) for elem in a]
303306
return type(a)(*args) # type: ignore[arg-type]
304307
return type(a)([self.create_arg(elem) for elem in a])
305-
elif isinstance(a, list):
308+
if isinstance(a, list):
306309
return [self.create_arg(elem) for elem in a]
307-
elif isinstance(a, dict):
310+
if isinstance(a, dict):
308311

309312
def no_node(arg):
310313
if isinstance(arg, Node):
@@ -323,33 +326,31 @@ def no_node(arg):
323326

324327
r[k] = self.create_arg(v)
325328
return r
326-
elif isinstance(a, slice):
329+
330+
if isinstance(a, slice):
327331
return slice(
328332
self.create_arg(a.start),
329333
self.create_arg(a.stop),
330334
self.create_arg(a.step),
331335
)
332336

333-
elif isinstance(a, range):
337+
if isinstance(a, range):
334338
return range(
335339
self.crea B268 te_arg(a.start),
336340
self.create_arg(a.stop),
337341
self.create_arg(a.step),
338342
)
339343

340-
elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
344+
if isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
341345
return a
342346

343-
elif is_dataclass(a):
347+
if is_dataclass(a):
344348
kwargs = {
345349
field.name: self.create_arg(getattr(a, field.name))
346350
for field in fields(a)
347351
}
348352
return self.create_node("call_function", a.__class__, (), kwargs)
349353

350-
elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
351-
return a
352-
353354
raise NotImplementedError(f"argument of type: {type(a)}")
354355

355356
@compatibility(is_backward_compatible=True)

0 commit comments

Comments
 (0)
0