[pytree] add APIs to determine a class is a namedtuple or PyStructSeq… · amathewc/pytorch@1437cab · GitHub
[go: up one dir, main page]

Skip to content

Commit 1437cab

Browse files
XuehaiPan authored and
amathewc committed
[pytree] add APIs to determine a class is a namedtuple or PyStructSequence (pytorch#113257)
Changes in this PR: 1. Add `is_structseq` and `is_structseq_class` functions to determine whether an object or a class is a PyStructSequence. 2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types, just as `namedtuple` is used for Named Tuple types. 3. Change `is_namedtuple` to treat subclasses of a namedtuple class as namedtuples. Before this PR, only namedtuple classes created directly by `collections.namedtuple` or `typing.NamedTuple` were considered namedtuple classes, while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of a namedtuple class. Resolves pytorch#75982. New tests are included in this PR. - pytorch#75982 Pull Request resolved: pytorch#113257 Approved by: https://github.com/zou3519
1 parent 8103e5e commit 1437cab

File tree

8 files changed

+345
-57
lines changed

8 files changed

+345
-57
lines changed

benchmarks/dynamo/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1397,7 +1397,7 @@ def load(cls, model, example_inputs):
13971397
# see https://github.com/pytorch/pytorch/issues/113029
13981398
example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
13991399

1400-
if pytree._is_namedtuple_instance(example_outputs):
1400+
if pytree.is_namedtuple_instance(example_outputs):
14011401
typ = type(example_outputs)
14021402
pytree._register_namedtuple(
14031403
typ,

test/test_pytree.py

Lines changed: 149 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import re
77
import subprocess
88
import sys
9+
import time
910
import unittest
1011
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
1112
from dataclasses import dataclass
@@ -731,6 +732,133 @@ def test_pytree_serialize_bad_input(self, pytree_impl):
731732
with self.assertRaises(TypeError):
732733
pytree_impl.treespec_dumps("random_blurb")
733734

735+
@parametrize(
736+
"pytree",
737+
[
738+
subtest(py_pytree, name="py"),
739+
subtest(cxx_pytree, name="cxx"),
740+
],
741+
)
742+
def test_is_namedtuple(self, pytree):
743+
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
744+
745+
class DirectNamedTuple2(NamedTuple):
746+
x: int
747+
y: int
748+
749+
class IndirectNamedTuple1(DirectNamedTuple1):
750+
pass
751+
752+
class IndirectNamedTuple2(DirectNamedTuple2):
753+
pass
754+
755+
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1(0, 1)))
756+
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2(0, 1)))
757+
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1(0, 1)))
758+
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2(0, 1)))
759+
self.assertFalse(pytree.is_namedtuple(time.gmtime()))
760+
self.assertFalse(pytree.is_namedtuple((0, 1)))
761+
self.assertFalse(pytree.is_namedtuple([0, 1]))
762+
self.assertFalse(pytree.is_namedtuple({0: 1, 1: 2}))
763+
self.assertFalse(pytree.is_namedtuple({0, 1}))
764+
self.assertFalse(pytree.is_namedtuple(1))
765+
766+
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1))
767+
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2))
768+
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1))
769+
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2))
770+
self.assertFalse(pytree.is_namedtuple(time.struct_time))
771+
self.assertFalse(pytree.is_namedtuple(tuple))
772+
self.assertFalse(pytree.is_namedtuple(list))
773+
774+
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple1))
775+
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple2))
776+
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple1))
777+
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple2))
778+
self.assertFalse(pytree.is_namedtuple_class(time.struct_time))
779+
self.assertFalse(pytree.is_namedtuple_class(tuple))
780+
self.assertFalse(pytree.is_namedtuple_class(list))
781+
782+
@parametrize(
783+
"pytree",
784+
[
785+
subtest(py_pytree, name="py"),
786+
subtest(cxx_pytree, name="cxx"),
787+
],
788+
)
789+
def test_is_structseq(self, pytree):
790+
class FakeStructSeq(tuple):
791+
n_fields = 2
792+
n_sequence_fields = 2
793+
n_unnamed_fields = 0
794+
795+
__slots__ = ()
796+
__match_args__ = ("x", "y")
797+
798+
def __new__(cls, sequence):
799+
return super().__new__(cls, sequence)
800+
801+
@property
802+
def x(self):
803+
return self[0]
804+
805+
@property
806+
def y(self):
807+
return self[1]
808+
809+
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
810+
811+
class DirectNamedTuple2(NamedTuple):
812+
x: int
813+
y: int
814+
815+
self.assertFalse(pytree.is_structseq(FakeStructSeq((0, 1))))
816+
self.assertTrue(pytree.is_structseq(time.gmtime()))
817+
self.assertFalse(pytree.is_structseq(DirectNamedTuple1(0, 1)))
818+
self.assertFalse(pytree.is_structseq(DirectNamedTuple2(0, 1)))
819+
self.assertFalse(pytree.is_structseq((0, 1)))
820+
self.assertFalse(pytree.is_structseq([0, 1]))
821+
self.assertFalse(pytree.is_structseq({0: 1, 1: 2}))
822+
self.assertFalse(pytree.is_structseq({0, 1}))
823+
self.assertFalse(pytree.is_structseq(1))
824+
825+
self.assertFalse(pytree.is_structseq(FakeStructSeq))
826+
self.assertTrue(pytree.is_structseq(time.struct_time))
827+
self.assertFalse(pytree.is_structseq(DirectNamedTuple1))
828+
self.assertFalse(pytree.is_structseq(DirectNamedTuple2))
829+
self.assertFalse(pytree.is_structseq(tuple))
830+
self.assertFalse(pytree.is_structseq(list))
831+
832+
self.assertFalse(pytree.is_structseq_class(FakeStructSeq))
833+
self.assertTrue(
834+
pytree.is_structseq_class(time.struct_time),
835+
)
836+
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple1))
837+
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple2))
838+
self.assertFalse(pytree.is_structseq_class(tuple))
839+
self.assertFalse(pytree.is_structseq_class(list))
840+
841+
# torch.return_types.* are all PyStructSequence types
842+
for cls in vars(torch.return_types).values():
843+
if isinstance(cls, type) and issubclass(cls, tuple):
844+
self.assertTrue(pytree.is_structseq(cls))
845+
self.assertTrue(pytree.is_structseq_class(cls))
846+
self.assertFalse(pytree.is_namedtuple(cls))
847+
self.assertFalse(pytree.is_namedtuple_class(cls))
848+
849+
inst = cls(range(cls.n_sequence_fields))
850+
self.assertTrue(pytree.is_structseq(inst))
851+
self.assertTrue(pytree.is_structseq(type(inst)))
852+
self.assertFalse(pytree.is_structseq_class(inst))
853+
self.assertTrue(pytree.is_structseq_class(type(inst)))
854+
self.assertFalse(pytree.is_namedtuple(inst))
855+
self.assertFalse(pytree.is_namedtuple_class(inst))
856+
else:
857+
self.assertFalse(pytree.is_structseq(cls))
858+
self.assertFalse(pytree.is_structseq_class(cls))
859+
self.assertFalse(pytree.is_namedtuple(cls))
860+
self.assertFalse(pytree.is_namedtuple_class(cls))
861+
734862

735863
class TestPythonPytree(TestCase):
736864
def test_deprecated_register_pytree_node(self):
@@ -975,9 +1103,8 @@ def test_pytree_serialize_namedtuple(self):
9751103
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1",
9761104
)
9771105

978-
spec = py_pytree.TreeSpec(
979-
namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
980-
)
1106+
spec = py_pytree.tree_structure(Point1(1, 2))
1107+
self.assertIs(spec.type, namedtuple)
9811108
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
9821109
self.assertEqual(spec, roundtrip_spec)
9831110

@@ -990,18 +1117,28 @@ class Point2(NamedTuple):
9901117
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2",
9911118
)
9921119

993-
spec = py_pytree.TreeSpec(
994-
namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1120+
spec = py_pytree.tree_structure(Point2(1, 2))
1121+
self.assertIs(spec.type, namedtuple)
1122+
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
1123+
self.assertEqual(spec, roundtrip_spec)
1124+
1125+
class Point3(Point2):
1126+
pass
1127+
1128+
py_pytree._register_namedtuple(
1129+
Point3,
1130+
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point3",
9951131
)
1132+
1133+
spec = py_pytree.tree_structure(Point3(1, 2))
1134+
self.assertIs(spec.type, namedtuple)
9961135
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
9971136
self.assertEqual(spec, roundtrip_spec)
9981137

9991138
def test_pytree_serialize_namedtuple_bad(self):
10001139
DummyType = namedtuple("DummyType", ["x", "y"])
10011140

1002-
spec = py_pytree.TreeSpec(
1003-
namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1004-
)
1141+
spec = py_pytree.tree_structure(DummyType(1, 2))
10051142

10061143
with self.assertRaisesRegex(
10071144
NotImplementedError, "Please register using `_register_namedtuple`"
@@ -1020,9 +1157,7 @@ def __init__(self, x, y):
10201157
lambda xs, _: DummyType(*xs),
10211158
)
10221159

1023-
spec = py_pytree.TreeSpec(
1024-
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1025-
)
1160+
spec = py_pytree.tree_structure(DummyType(1, 2))
10261161
with self.assertRaisesRegex(
10271162
NotImplementedError, "No registered serialization name"
10281163
):
@@ -1042,9 +1177,7 @@ def __init__(self, x, y):
10421177
to_dumpable_context=lambda context: "moo",
10431178
from_dumpable_context=lambda dumpable_context: None,
10441179
)
1045-
spec = py_pytree.TreeSpec(
1046-
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1047-
)
1180+
spec = py_pytree.tree_structure(DummyType(1, 2))
10481181
serialized_spec = py_pytree.treespec_dumps(spec, 1)
10491182
self.assertIn("moo", serialized_spec)
10501183
roundtrip_spec = py_pytree.treespec_loads(serialized_spec)
@@ -1082,9 +1215,7 @@ def __init__(self, x, y):
10821215
from_dumpable_context=lambda dumpable_context: None,
10831216
)
10841217

1085-
spec = py_pytree.TreeSpec(
1086-
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1087-
)
1218+
spec = py_pytree.tree_structure(DummyType(1, 2))
10881219

10891220
with self.assertRaisesRegex(
10901221
TypeError, "Object of type type is not JSON serializable"
@@ -1095,9 +1226,7 @@ def test_pytree_serialize_bad_protocol(self):
10951226
import json
10961227

10971228
Point = namedtuple("Point", ["x", "y"])
1098-
spec = py_pytree.TreeSpec(
1099-
namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1100-
)
1229+
spec = py_pytree.tree_structure(Point(1, 2))
11011230
py_pytree._register_namedtuple(
11021231
Point,
11031232
serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",

torch/_dynamo/polyfills/pytree.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,10 @@ def _(*args: Any, **kwargs: Any) -> bool:
5656
"structseq_fields",
5757
):
5858
__func = getattr(optree, __name)
59-
substitute_in_graph(__func, can_constant_fold_through=True)(
59+
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
6060
__func.__python_implementation__
6161
)
62+
__all__ += [__name] # noqa: PLE0604
6263
del __func
6364
del __name
6465

torch/_export/serde/serialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1243,7 +1243,7 @@ def serialize_treespec(self, treespec):
12431243
def store_namedtuple_fields(ts):
12441244
if ts.type is None:
12451245
return
1246-
if ts.type == namedtuple:
1246+
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
12471247
serialized_type_name = pytree.SUPPORTED_SERIALIZED_TYPES[ts.context].serialized_type_name
12481248
if serialized_type_name in self.treespec_namedtuple_fields:
12491249
field_names = self.treespec_namedtuple_fields[serialized_type_name].field_names

torch/autograd/forward_ad.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# mypy: allow-untyped-defs
22
import os
3-
from collections import namedtuple
4-
from typing import Any
3+
from typing import Any, NamedTuple, Optional
54

65
import torch
76

@@ -129,16 +128,15 @@ def make_dual(tensor, tangent, *, level=None):
129128
return torch._VF._make_dual(tensor, tangent, level=level)
130129

131130

132-
_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"])
133-
134-
135-
class UnpackedDualTensor(_UnpackedDualTensor):
131+
class UnpackedDualTensor(NamedTuple):
136132
r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor.
137133
138134
See :func:`unpack_dual` for more details.
139-
140135
"""
141136

137+
primal: torch.Tensor
138+
tangent: Optional[torch.Tensor]
139+
142140

143141
def unpack_dual(tensor, *, level=None):
144142
r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient.

torch/testing/_internal/composite_compliance.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -552,8 +552,16 @@ def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs):
552552

553553
expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs)
554554
expected = tree_map(fwAD.unpack_dual, expected)
555-
expected_primals = tree_map(lambda x: x.primal, expected)
556-
expected_tangents = tree_map(lambda x: x.tangent, expected)
555+
expected_primals = tree_map(
556+
lambda x: x.primal,
557+
expected,
558+
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
559+
)
560+
expected_tangents = tree_map(
561+
lambda x: x.tangent,
562+
expected,
563+
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
564+
)
557565

558566
# Permutations of arg and kwargs in CCT.
559567
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
@@ -586,7 +594,15 @@ def unwrap(e):
586594
return e.elem if isinstance(e, CCT) else e
587595

588596
actual = tree_map(fwAD.unpack_dual, actual)
589-
actual_primals = tree_map(lambda x: unwrap(x.primal), actual)
590-
actual_tangents = tree_map(lambda x: unwrap(x.tangent), actual)
597+
actual_primals = tree_map(
598+
lambda x: unwrap(x.primal),
599+
actual,
600+
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
601+
)
602+
actual_tangents = tree_map(
603+
lambda x: unwrap(x.tangent),
604+
actual,
605+
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
606+
)
591607
assert_equal_fn(actual_primals, expected_primals, equal_nan=True)
592608
assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True)

torch/utils/_cxx_pytree.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,15 @@
2323
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
2424

2525
import torch.utils._pytree as python_pytree
26-
from torch.utils._pytree import KeyEntry as KeyEntry
26+
from torch.utils._pytree import (
27+
is_namedtuple as is_namedtuple,
28+
is_namedtuple_class as is_namedtuple_class,
29+
is_namedtuple_instance as is_namedtuple_instance,
30+
is_structseq as is_structseq,
31+
is_structseq_class as is_structseq_class,
32+
is_structseq_instance as is_structseq_instance,
33+
KeyEntry as KeyEntry,
34+
)
2735

2836

2937
__all__ = [
@@ -39,6 +47,7 @@
3947
"keystr",
4048
"key_get",
4149
"register_pytree_node",
50+
"tree_is_leaf",
4251
"tree_flatten",
4352
"tree_flatten_with_path",
4453
"tree_unflatten",
@@ -58,6 +67,12 @@
5867
"treespec_dumps",
5968
"treespec_loads",
6069
"treespec_pprint",
70+
"is_namedtuple",
71+
"is_namedtuple_class",
72+
"is_namedtuple_instance",
73+
"is_structseq",
74+
"is_structseq_class",
75+
"is_structseq_instance",
6176
]
6277

6378

0 commit comments

Comments
 (0)
0