[pytree] support PyStructSequence types for Python pytree · pytorch/pytorch@1b90bf3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1b90bf3

Browse files
committed
[pytree] support PyStructSequence types for Python pytree
ghstack-source-id: d50577d Pull Request resolved: #113258
1 parent 53a2a37 commit 1b90bf3

File tree

6 files changed

+110
-49
lines changed

6 files changed

+110
-49
lines changed

test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_lu_cpu_float32

Whitespace-only changes.

test/test_pytree.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,12 @@ def test_treespec_repr(self):
958958
python_pytree.TreeSpec(dict, [], []),
959959
],
960960
),
961+
# python_pytree.tree_structure(torch.return_types.sort((torch.zeros(1), torch.zeros(1))))
962+
python_pytree.TreeSpec(
963+
python_pytree.structseq,
964+
torch.return_types.sort,
965+
[python_leafspec, python_leafspec],
966+
),
961967
],
962968
)
963969
def test_pytree_serialize(self, spec):
@@ -1471,6 +1477,9 @@ def test_treespec_repr(self):
14711477
cxx_pytree.tree_structure(
14721478
defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})
14731479
),
1480+
cxx_pytree.tree_structure(
1481+
torch.return_types.sort((torch.zeros(1), torch.zeros(1)))
1482+
),
14741483
],
14751484
)
14761485
def test_pytree_serialize(self, spec):

torch/fx/_pytree.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
from typing import Any, Callable, Optional, TypeVar
33
from typing_extensions import NamedTuple
44

5-
import torch.return_types
6-
from torch.utils._pytree import PyTree, tree_flatten, TreeSpec
5+
from torch.utils._pytree import PyTree, structseq, tree_flatten, TreeSpec
76

87

98
FlattenFuncSpec = Callable[[PyTree, TreeSpec], list]
@@ -93,21 +92,28 @@ def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool:
9392
return len(d) == spec.num_children
9493

9594

96-
register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match)
97-
register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match)
9895
register_pytree_flatten_spec(
9996
tuple,
10097
_tuple_flatten_spec,
10198
_tuple_flatten_spec_exact_match,
10299
)
103-
for return_type in torch.return_types.all_return_types:
104-
register_pytree_flatten_spec(
105-
return_type,
106-
_tuple_flatten_spec,
107-
_tuple_flatten_spec_exact_match,
108-
)
100+
register_pytree_flatten_spec(
101+
list,
102+
_list_flatten_spec,
103+
_list_flatten_spec_exact_match,
104+
)
105+
register_pytree_flatten_spec(
106+
dict,
107+
_dict_flatten_spec,
108+
_dict_flatten_spec_exact_match,
109+
)
109110
register_pytree_flatten_spec(
110111
namedtuple, # type: ignore[arg-type]
111112
_namedtuple_flatten_spec,
112113
_namedtuple_flatten_spec_exact_match,
113114
)
115+
register_pytree_flatten_spec(
116+
structseq,
117+
_tuple_flatten_spec,
118+
_tuple_flatten_spec_exact_match,
119+
)

torch/return_types.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,53 @@
1-
import inspect
1+
import warnings
2+
from typing_extensions import deprecated
23

3-
import torch
4-
from torch.utils._pytree import register_pytree_node, SequenceKey
4+
from torch._C import _return_types as return_types
55

66

77
__all__ = ["pytree_register_structseq", "all_return_types"]
88

9-
all_return_types = []
109

11-
# error: Module has no attribute "_return_types"
12-
return_types = torch._C._return_types # type: ignore[attr-defined]
10+
all_return_types = []
1311

1412

13+
@deprecated(
14+
"torch.return_types.pytree_register_structseq is now a no-op "
15+
"and will be removed in a future release.",
16+
category=FutureWarning,
17+
)
1518
def pytree_register_structseq(cls):
16-
def structseq_flatten(structseq):
17-
return list(structseq), None
18-
19-
def structseq_flatten_with_keys(structseq):
20-
values, context = structseq_flatten(structseq)
21-
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
19+
from torch.utils._pytree import is_structseq_class
2220

23-
def structseq_unflatten(values, context):
24-
return cls(values)
21+
if is_structseq_class(cls):
22+
return
2523

26-
register_pytree_node(
27-
cls,
28-
structseq_flatten,
29-
structseq_unflatten,
30-
flatten_with_keys_fn=structseq_flatten_with_keys,
31-
)
24+
raise TypeError(f"Class {cls!r} is not a PyStructSequence class.")
3225

3326

34-
for name in dir(return_types):
35-
if name.startswith("__"):
27+
_name, _attr = "", None
28+
for _name in dir(return_types):
29+
if _name.startswith("__"):
3630
continue
3731

38-
_attr = getattr(return_types, name)
39-
globals()[name] = _attr
32+
_attr = getattr(return_types, _name)
33+
globals()[_name] = _attr
4034

41-
if not name.startswith("_"):
42-
__all__.append(name)
35+
if not _name.startswith("_"):
36+
__all__.append(_name)
4337
all_return_types.append(_attr)
4438

39+
with warnings.catch_warnings():
40+
warnings.filterwarnings(
41+
"ignore",
42+
category=FutureWarning,
43+
module=__name__,
44+
append=False,
45+
)
4546
# Today everything in torch.return_types is a structseq, aka a "namedtuple"-like
4647
# thing defined by the Python C-API. We're going to need to modify this when that
4748
# is no longer the case.
48-
# NB: I don't know how to check that something is a "structseq" so we do a fuzzy
49-
# check for tuple
50-
if inspect.isclass(_attr) and issubclass(_attr, tuple):
51-
pytree_register_structseq(_attr)
49+
for _attr in all_return_types:
50+
if isinstance(_attr, type) and issubclass(_attr, tuple):
51+
pytree_register_structseq(_attr)
52+
53+
del _name, _attr, warnings, deprecated

torch/utils/_cxx_pytree.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -238,15 +238,12 @@ def _private_register_pytree_node(
238238
for the C++ pytree only. End-users should use :func:`register_pytree_node`
239239
instead.
240240
"""
241-
# TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
242-
# PyStructSequence types
243-
if not optree.is_structseq_class(cls):
244-
optree.register_pytree_node(
245-
cls,
246-
flatten_fn,
247-
_reverse_args(unflatten_fn),
248-
namespace="torch",
249-
)
241+
optree.register_pytree_node(
242+
cls,
243+
flatten_fn,
244+
_reverse_args(unflatten_fn),
245+
namespace="torch",
246+
)
250247

251248

252249
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:

torch/utils/_pytree.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,39 @@ def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]:
926926
return deque(values, maxlen=context)
927927

928928

929+
def _structseq_flatten(d: structseq[T]) -> tuple[list[T], Context]:
930+
return list(d), type(d)
931+
932+
933+
def _structseq_flatten_with_keys(
934+
d: structseq[T],
935+
) -> tuple[list[tuple[KeyEntry, T]], Context]:
936+
values, context = _structseq_flatten(d)
937+
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
938+
939+
940+
def _structseq_unflatten(values: Iterable[T], context: Context) -> structseq[T]:
941+
return context(values) # type: ignore[no-any-return]
942+
943+
944+
def _structseq_serialize(context: Context) -> DumpableContext:
945+
json_structseq = {
946+
"class_module": context.__module__,
947+
"class_name": context.__qualname__,
948+
}
949+
return json_structseq
950+
951+
952+
def _structseq_deserialize(dumpable_context: DumpableContext) -> Context:
953+
class_module = dumpable_context["class_module"]
954+
class_name = dumpable_context["class_name"]
955+
assert isinstance(class_module, str)
956+
assert isinstance(class_name, str)
957+
module = importlib.import_module(class_module)
958+
context = getattr(module, class_name)
959+
return context
960+
961+
929962
_private_register_pytree_node(
930963
tuple,
931964
_tuple_flatten,
@@ -979,6 +1012,15 @@ def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]:
9791012
serialized_type_name="collections.deque",
9801013
flatten_with_keys_fn=_deque_flatten_with_keys,
9811014
)
1015+
_private_register_pytree_node(
1016+
structseq,
1017+
_structseq_flatten,
1018+
_structseq_unflatten,
1019+
serialized_type_name="structseq",
1020+
to_dumpable_context=_structseq_serialize,
1021+
from_dumpable_context=_structseq_deserialize,
1022+
flatten_with_keys_fn=_structseq_flatten_with_keys,
1023+
)
9821024

9831025

9841026
STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict})
@@ -991,6 +1033,7 @@ def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]:
9911033
OrderedDict,
9921034
defaultdict,
9931035
deque,
1036+
structseq,
9941037
},
9951038
)
9961039

@@ -1006,6 +1049,10 @@ def _is_namedtuple_instance(tree: Any) -> bool:
10061049

10071050
def _get_node_type(tree: Any) -> Any:
10081051
node_type = type(tree)
1052+
# Only structseq types that are not explicitly registered should return `structseq`.
1053+
# If a structseq type is explicitly registered, then the actual type will be returned.
1054+
if node_type not in SUPPORTED_NODES and is_structseq_class(node_type):
1055+
return structseq
10091056
# All namedtuple types are implicitly registered as pytree nodes.
10101057
# XXX: Other parts of the codebase expect namedtuple types always return
10111058
# `namedtuple` instead of the actual namedtuple type. Even if the type

0 commit comments

Comments
 (0)
0