diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_lu_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_lu_cpu_float32 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/test_pytree.py b/test/test_pytree.py index bbcbaba43488f..0a03ee1167c38 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -958,6 +958,12 @@ def test_treespec_repr(self): python_pytree.TreeSpec(dict, [], []), ], ), + # python_pytree.tree_structure(torch.return_types.sort((torch.zeros(1), torch.zeros(1)))) + python_pytree.TreeSpec( + python_pytree.structseq, + torch.return_types.sort, + [python_leafspec, python_leafspec], + ), ], ) def test_pytree_serialize(self, spec): @@ -1471,6 +1477,9 @@ def test_treespec_repr(self): cxx_pytree.tree_structure( defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}}) ), + cxx_pytree.tree_structure( + torch.return_types.sort((torch.zeros(1), torch.zeros(1))) + ), ], ) def test_pytree_serialize(self, spec): diff --git a/torch/fx/_pytree.py b/torch/fx/_pytree.py index 3ab5ffbea6f5a..6a66e7a16d78d 100644 --- a/torch/fx/_pytree.py +++ b/torch/fx/_pytree.py @@ -2,8 +2,7 @@ from typing import Any, Callable, Optional, TypeVar from typing_extensions import NamedTuple -import torch.return_types -from torch.utils._pytree import PyTree, tree_flatten, TreeSpec +from torch.utils._pytree import PyTree, structseq, tree_flatten, TreeSpec FlattenFuncSpec = Callable[[PyTree, TreeSpec], list] @@ -93,21 +92,28 @@ def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool: return len(d) == spec.num_children -register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match) -register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match) register_pytree_flatten_spec( tuple, _tuple_flatten_spec, _tuple_flatten_spec_exact_match, ) -for return_type in torch.return_types.all_return_types: - register_pytree_flatten_spec( - return_type, - _tuple_flatten_spec, - _tuple_flatten_spec_exact_match, - ) +register_pytree_flatten_spec( + list, + _list_flatten_spec, + _list_flatten_spec_exact_match, +) +register_pytree_flatten_spec( + dict, + _dict_flatten_spec, + _dict_flatten_spec_exact_match, +) register_pytree_flatten_spec( namedtuple, # type: ignore[arg-type] _namedtuple_flatten_spec, _namedtuple_flatten_spec_exact_match, ) +register_pytree_flatten_spec( + structseq, + _tuple_flatten_spec, + _tuple_flatten_spec_exact_match, +) diff --git a/torch/return_types.py b/torch/return_types.py index d456742be4b88..121e499135669 100644 --- a/torch/return_types.py +++ b/torch/return_types.py @@ -1,51 +1,53 @@ -import inspect +import warnings +from typing_extensions import deprecated -import torch -from torch.utils._pytree import register_pytree_node, SequenceKey +from torch._C import _return_types as return_types __all__ = ["pytree_register_structseq", "all_return_types"] -all_return_types = [] -# error: Module has no attribute "_return_types" -return_types = torch._C._return_types # type: ignore[attr-defined] +all_return_types = [] +@deprecated( + "torch.return_types.pytree_register_structseq is now a no-op " + "and will be removed in a future release.", + category=FutureWarning, +) def pytree_register_structseq(cls): - def structseq_flatten(structseq): - return list(structseq), None - - def structseq_flatten_with_keys(structseq): - values, context = structseq_flatten(structseq) - return [(SequenceKey(i), v) for i, v in enumerate(values)], context + from torch.utils._pytree import is_structseq_class - def structseq_unflatten(values, context): - return cls(values) + if is_structseq_class(cls): + return - register_pytree_node( - cls, - structseq_flatten, - structseq_unflatten, - flatten_with_keys_fn=structseq_flatten_with_keys, - ) + raise TypeError(f"Class {cls!r} is not a PyStructSequence class.") -for name in dir(return_types): - if name.startswith("__"): +_name, _attr = "", None +for _name in dir(return_types): + if _name.startswith("__"): continue - _attr = getattr(return_types, name) - globals()[name] = _attr + _attr = getattr(return_types, _name) + globals()[_name] = _attr - if not name.startswith("_"): - __all__.append(name) + if not _name.startswith("_"): + __all__.append(_name) all_return_types.append(_attr) +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=FutureWarning, + module=__name__, + append=False, + ) # Today everything in torch.return_types is a structseq, aka a "namedtuple"-like # thing defined by the Python C-API. We're going to need to modify this when that # is no longer the case. - # NB: I don't know how to check that something is a "structseq" so we do a fuzzy - # check for tuple - if inspect.isclass(_attr) and issubclass(_attr, tuple): - pytree_register_structseq(_attr) + for _attr in all_return_types: + if isinstance(_attr, type) and issubclass(_attr, tuple): + pytree_register_structseq(_attr) + +del _name, _attr, warnings, deprecated diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 24c73061b716a..923a0b42bebda 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -253,15 +253,12 @@ def _private_register_pytree_node( for the C++ pytree only. End-users should use :func:`register_pytree_node` instead. """ - # TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support - # PyStructSequence types - if not optree.is_structseq_class(cls): - optree.register_pytree_node( - cls, - flatten_fn, - _reverse_args(unflatten_fn), - namespace="torch", - ) + optree.register_pytree_node( + cls, + flatten_fn, + _reverse_args(unflatten_fn), + namespace="torch", + ) def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]: diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index ce08f3bef4045..d7edd9010fa6f 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -927,6 +927,39 @@ def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]: return deque(values, maxlen=context) +def _structseq_flatten(d: structseq[T]) -> tuple[list[T], Context]: + return list(d), type(d) + + +def _structseq_flatten_with_keys( + d: structseq[T], +) -> tuple[list[tuple[KeyEntry, T]], Context]: + values, context = _structseq_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _structseq_unflatten(values: Iterable[T], context: Context) -> structseq[T]: + return context(values) # type: ignore[no-any-return] + + +def _structseq_serialize(context: Context) -> DumpableContext: + json_structseq = { + "class_module": context.__module__, + "class_name": context.__qualname__, + } + return json_structseq + + +def _structseq_deserialize(dumpable_context: DumpableContext) -> Context: + class_module = dumpable_context["class_module"] + class_name = dumpable_context["class_name"] + assert isinstance(class_module, str) + assert isinstance(class_name, str) + module = importlib.import_module(class_module) + context = getattr(module, class_name) + return context + + _private_register_pytree_node( tuple, _tuple_flatten, @@ -980,6 +1013,15 @@ def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]: serialized_type_name="collections.deque", flatten_with_keys_fn=_deque_flatten_with_keys, ) +_private_register_pytree_node( + structseq, + _structseq_flatten, + _structseq_unflatten, + serialized_type_name="structseq", + to_dumpable_context=_structseq_serialize, + from_dumpable_context=_structseq_deserialize, + flatten_with_keys_fn=_structseq_flatten_with_keys, +) STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict}) @@ -992,6 +1034,7 @@ def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]: OrderedDict, defaultdict, deque, + structseq, }, ) @@ -1007,6 +1050,10 @@ def _is_namedtuple_instance(tree: Any) -> bool: def _get_node_type(tree: Any) -> Any: node_type = type(tree) + # Only structseq types that are not explicitly registered should return `structseq`. + # If a structseq type is explicitly registered, then the actual type will be returned. + if node_type not in SUPPORTED_NODES and is_structseq_class(node_type): + return structseq # All namedtuple types are implicitly registered as pytree nodes. # XXX: Other parts of the codebase expect namedtuple types always return # `namedtuple` instead of the actual namedtuple type. Even if the type