[pytree] support PyStructSequence types for Python pytree by XuehaiPan · Pull Request #113254 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[pytree] support PyStructSequence types for Python pytree #113254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions torch/autograd/forward_ad.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
from collections import namedtuple

from typing import Any
from typing import Any, Optional

import torch
from .grad_mode import _DecoratorContextManager
Expand Down Expand Up @@ -118,16 +117,32 @@ def make_dual(tensor, tangent, *, level=None):
f"Expected tangent to be floating point or complex, but got: {tangent.dtype}"
)

return torch._VF._make_dual(tensor, tangent, level=level)
return torch._VF._make_dual(tensor, tangent, level=level) # type: ignore[attr-defined]


_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"])


class UnpackedDualTensor(tuple):
    r"""Tuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor.

    See :func:`unpack_dual` for more details.

    The instance is a 2-tuple ``(primal, tangent)``, so it can be indexed and
    unpacked positionally; ``primal`` and ``tangent`` expose the same two
    elements by name.
    """

    def __new__(
        cls,
        primal: torch.Tensor,
        tangent: Optional[torch.Tensor],  # type: ignore[arg-type]
    ) -> "UnpackedDualTensor":
        """Build the underlying 2-tuple from ``primal`` and ``tangent``."""
        return super().__new__(cls, (primal, tangent))  # type: ignore[arg-type]

    @property
    def primal(self) -> torch.Tensor:
        """First element of the tuple."""
        return self[0]

    @property
    def tangent(self) -> Optional[torch.Tensor]:
        """Second element of the tuple (may be ``None``)."""
        return self[1]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(primal={self.primal}, tangent={self.tangent})"
        )


def unpack_dual(tensor, *, level=None):
Expand Down Expand Up @@ -156,7 +171,7 @@ def unpack_dual(tensor, *, level=None):
if level < 0:
return UnpackedDualTensor(tensor, None)

primal, dual = torch._VF._unpack_dual(tensor, level=level)
primal, dual = torch._VF._unpack_dual(tensor, level=level) # type: ignore[attr-defined]

return UnpackedDualTensor(primal, dual)

Expand Down
6 changes: 6 additions & 0 deletions torch/return_types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import torch
import inspect
import warnings

__all__ = ["pytree_register_structseq"]

# error: Module has no attribute "_return_types"
return_types = torch._C._return_types # type: ignore[attr-defined]

def pytree_register_structseq(cls):
if torch.utils._pytree.is_structseq_class(cls):
return

warnings.warn(f"Class {cls!r} is not a PyStructSequence class.")

def structseq_flatten(structseq):
return list(structseq), None

Expand Down
8 changes: 8 additions & 0 deletions torch/utils/_pytree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
DumpableContext,
FlattenFunc,
FromDumpableContextFn,
is_namedtuple,
is_namedtuple_class,
is_structseq,
is_structseq_class,
LeafSpec,
PyTree,
register_pytree_node,
Expand Down Expand Up @@ -86,4 +90,8 @@
"treespec_dumps",
"treespec_loads",
"treespec_pprint",
"is_namedtuple",
"is_namedtuple_class",
"is_structseq",
"is_structseq_class",
]
8 changes: 8 additions & 0 deletions torch/utils/_pytree/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
"""

from .python import (
is_namedtuple,
is_namedtuple_class,
is_structseq,
is_structseq_class,
LeafSpec,
register_pytree_node,
tree_all,
Expand Down Expand Up @@ -69,4 +73,8 @@
"treespec_dumps",
"treespec_loads",
"treespec_pprint",
"is_namedtuple",
"is_namedtuple_class",
"is_structseq",
"is_structseq_class",
]
27 changes: 17 additions & 10 deletions torch/utils/_pytree/api/cxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
raise ImportError("C++ pytree utilities do not work with torch::deploy.")

import optree
from optree import PyTreeSpec # direct import for type annotations
from optree import (
is_namedtuple,
is_namedtuple_class,
is_structseq,
is_structseq_class,
PyTreeSpec, # direct import for type annotations
)

from .typing import (
Context,
Expand Down Expand Up @@ -64,6 +70,10 @@
"treespec_dumps",
"treespec_loads",
"treespec_pprint",
"is_namedtuple",
"is_namedtuple_class",
"is_structseq",
"is_structseq_class",
]


Expand Down Expand Up @@ -242,15 +252,12 @@ def _private_register_pytree_node(
for the C++ pytree only. End-users should use :func:`register_pytree_node`
instead.
"""
# TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
# PyStructSequence types
if not optree.is_structseq_class(cls):
optree.register_pytree_node(
cls,
flatten_fn,
_reverse_args(unflatten_fn),
namespace=namespace,
)
optree.register_pytree_node(
cls,
flatten_fn,
_reverse_args(unflatten_fn),
namespace=namespace,
)


def tree_flatten(
Expand Down
138 changes: 122 additions & 16 deletions torch/utils/_pytree/api/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,23 @@
DefaultDict,
Deque,
Dict,
Final,
Generic,
Iterable,
List,
NamedTuple,
NoReturn,
Optional,
OrderedDict as GenericOrderedDict,
overload,
Tuple,
Type,
TypeVar,
Union,
)

from typing_extensions import Self # Python 3.11+

from .typing import (
Context,
DumpableContext,
Expand Down Expand Up @@ -79,6 +85,10 @@
"treespec_dumps",
"treespec_loads",
"treespec_pprint",
"is_namedtuple",
"is_namedtuple_class",
"is_structseq",
"is_structseq_class",
]


Expand Down Expand Up @@ -236,6 +246,76 @@ def _private_register_pytree_node(
register_pytree_node = _register_pytree_node


# Reference: https://github.com/metaopt/optree/blob/v0.10.0/optree/typing.py
def is_namedtuple(obj: Union[object, type]) -> bool:
    """Return whether the object is an instance of namedtuple or a subclass of namedtuple.

    Accepts either a class or an instance: an instance is first mapped to its
    type, and the decision is then delegated to :func:`is_namedtuple_class`.
    """
    cls = obj if isinstance(obj, type) else type(obj)
    return is_namedtuple_class(cls)


# Reference: https://github.com/metaopt/optree/blob/v0.10.0/optree/typing.py
def is_namedtuple_class(cls: type) -> bool:
    """Return whether the class is a subclass of namedtuple.

    A class qualifies when it is a ``tuple`` subclass whose ``_fields``
    attribute is a tuple consisting only of strings. Non-class arguments
    (including namedtuple *instances*) yield ``False``.
    """
    return (
        isinstance(cls, type)
        and issubclass(cls, tuple)
        and isinstance(getattr(cls, "_fields", None), tuple)
        and all(isinstance(field, str) for field in cls._fields)  # type: ignore[attr-defined]
    )


_T_co = TypeVar("_T_co", covariant=True)


# Reference: https://github.com/metaopt/optree/blob/v0.10.0/optree/typing.py
class structseq(tuple, Generic[_T_co]):  # type: ignore[misc]
    """A generic type stub for CPython's ``PyStructSequence`` type.

    This class is a placeholder only: it can be neither subclassed
    (``__init_subclass__`` raises ``TypeError``) nor instantiated
    (``__new__`` raises ``NotImplementedError``).
    """

    # Annotation-only declarations; no values are assigned on the stub itself.
    n_fields: Final[int]  # type: ignore[misc]
    n_sequence_fields: Final[int]  # type: ignore[misc]
    n_unnamed_fields: Final[int]  # type: ignore[misc]

    def __init_subclass__(cls) -> NoReturn:
        """Prohibit subclassing."""
        raise TypeError("type 'structseq' is not an acceptable base type")

    def __new__(
        cls: Type[Self],
        sequence: Iterable[_T_co],
        dict: Dict[str, Any] = ...,
    ) -> Self:
        """Signature stub only; always raises ``NotImplementedError``."""
        raise NotImplementedError


# Reference: https://github.com/metaopt/optree/blob/v0.10.0/optree/typing.py
def is_structseq(obj: Union[object, type]) -> bool:
    """Return whether the object is an instance of PyStructSequence or a class of PyStructSequence.

    Accepts either a class or an instance: an instance is first mapped to its
    type, and the decision is then delegated to :func:`is_structseq_class`.
    """
    cls = obj if isinstance(obj, type) else type(obj)
    return is_structseq_class(cls)


# Reference: https://github.com/metaopt/optree/blob/v0.10.0/optree/typing.py
def is_structseq_class(cls: type) -> bool:
    """Return whether the class is a class of PyStructSequence.

    A class qualifies when all of the following hold:

    * it is a class whose direct base is exactly ``tuple``;
    * it exposes integer ``n_sequence_fields``, ``n_fields`` and
      ``n_unnamed_fields`` attributes;
    * attempting to subclass it raises ``TypeError``.
    """
    looks_like_structseq = (
        isinstance(cls, type)
        # Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)`
        and cls.__base__ is tuple
        # Check PyStructSequence members
        and isinstance(getattr(cls, "n_sequence_fields", None), int)
        and isinstance(getattr(cls, "n_fields", None), int)
        and isinstance(getattr(cls, "n_unnamed_fields", None), int)
    )
    if not looks_like_structseq:
        return False

    try:
        # Check the type does not allow subclassing
        class SubClass(cls):  # type: ignore[valid-type,misc]
            pass

    except TypeError:
        return True
    # Subclassing succeeded, so this is an ordinary tuple subclass that merely
    # carries the three attributes.
    return False


def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
return list(d.values()), list(d.keys())

Expand Down Expand Up @@ -339,6 +419,32 @@ def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]:
return deque(values, maxlen=context)


def _structseq_flatten(seq: structseq[Any]) -> Tuple[List[Any], Context]:
    """Flatten a structseq into its elements, using its concrete type as context."""
    return list(seq), type(seq)


def _structseq_unflatten(values: Iterable[Any], context: Context) -> structseq[Any]:
    """Rebuild a structseq by calling ``context`` with ``values``.

    ``context`` is expected to be the structseq class recorded by
    ``_structseq_flatten``.
    """
    return context(values)


def _structseq_serialize(context: Context) -> DumpableContext:
    """Describe the structseq class ``context`` by its module and qualified name."""
    json_structseq = {
        "class_module": context.__module__,
        "class_name": context.__qualname__,
    }
    return json_structseq


def _structseq_deserialize(dumpable_context: DumpableContext) -> Context:
    """Resolve the class described by ``dumpable_context``.

    Reads the ``"class_module"`` and ``"class_name"`` entries, imports the
    module, and looks the name up on it.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the name cannot be resolved on the module.
    """
    class_module = dumpable_context["class_module"]
    class_name = dumpable_context["class_name"]
    assert isinstance(class_module, str)
    assert isinstance(class_name, str)
    # ``class_name`` is written from ``__qualname__`` by ``_structseq_serialize``
    # and may therefore be dotted (nested classes); resolve it one attribute at
    # a time instead of with a single ``getattr``.
    context: Any = importlib.import_module(class_module)
    for attr in class_name.split("."):
        context = getattr(context, attr)
    return context


_private_register_pytree_node(
tuple,
_tuple_flatten,
Expand All @@ -361,9 +467,9 @@ def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]:
namedtuple,
_namedtuple_flatten,
_namedtuple_unflatten,
serialized_type_name="collections.namedtuple",
to_dumpable_context=_namedtuple_serialize,
from_dumpable_context=_namedtuple_deserialize,
serialized_type_name="collections.namedtuple",
)
_private_register_pytree_node(
OrderedDict,
Expand All @@ -385,24 +491,24 @@ def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]:
_deque_unflatten,
serialized_type_name="collections.deque",
)


# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
def _is_namedtuple_instance(tree: Any) -> bool:
typ = type(tree)
bases = typ.__bases__
if len(bases) != 1 or bases[0] != tuple:
return False 8081
fields = getattr(typ, "_fields", None)
if not isinstance(fields, tuple):
return False
return all(type(entry) == str for entry in fields)
# Register the ``structseq`` stub as a pytree node type, wiring up its
# flatten/unflatten functions and the hooks that convert its context to and
# from a dumpable form.
_private_register_pytree_node(
structseq,
_structseq_flatten,
_structseq_unflatten,
serialized_type_name="structseq",
to_dumpable_context=_structseq_serialize,
from_dumpable_context=_structseq_deserialize,
)


def _get_node_type(tree: Any) -> Any:
    """Return the key under which ``tree``'s type is looked up as a pytree node.

    If ``type(tree)`` is itself in ``SUPPORTED_NODES`` it is returned as is.
    Otherwise namedtuple classes map to ``namedtuple`` and PyStructSequence
    classes map to ``structseq``; any other type is returned unchanged.
    """
    node_type = type(tree)
    if node_type not in SUPPORTED_NODES:
        # An explicit entry for the concrete type takes precedence over the
        # generic namedtuple / structseq fallbacks.
        if is_namedtuple_class(node_type):
            return namedtuple
        if is_structseq_class(node_type):
            return structseq
    return node_type


# A leaf is defined as anything that is not a Node.
Expand Down
0