6
6
import re
7
7
import subprocess
8
8
import sys
9
+ import time
9
10
import unittest
10
11
from collections import defaultdict , deque , namedtuple , OrderedDict , UserDict
11
12
from dataclasses import dataclass
@@ -731,6 +732,133 @@ def test_pytree_serialize_bad_input(self, pytree_impl):
731
732
with self .assertRaises (TypeError ):
732
733
pytree_impl .treespec_dumps ("random_blurb" )
733
734
735
+ @parametrize (
736
+ "pytree" ,
737
+ [
738
+ subtest (py_pytree , name = "py" ),
739
+ subtest (cxx_pytree , name = "cxx" ),
740
+ ],
741
+ )
742
+ def test_is_namedtuple (self , pytree ):
743
+ DirectNamedTuple1 = namedtuple ("DirectNamedTuple1" , ["x" , "y" ])
744
+
745
+ class DirectNamedTuple2 (NamedTuple ):
746
+ x : int
747
+ y : int
748
+
749
+ class IndirectNamedTuple1 (DirectNamedTuple1 ):
750
+ pass
751
+
752
+ class IndirectNamedTuple2 (DirectNamedTuple2 ):
753
+ pass
754
+
755
+ self .assertTrue (pytree .is_namedtuple (DirectNamedTuple1 (0 , 1 )))
756
+ self .assertTrue (pytree .is_namedtuple (DirectNamedTuple2 (0 , 1 )))
757
+ self .assertTrue (pytree .is_namedtuple (IndirectNamedTuple1 (0 , 1 )))
758
+ self .assertTrue (pytree .is_namedtuple (IndirectNamedTuple2 (0 , 1 )))
759
+ self .assertFalse (pytree .is_namedtuple (time .gmtime ()))
760
+ self .assertFalse (pytree .is_namedtuple ((0 , 1 )))
761
+ self .assertFalse (pytree .is_namedtuple ([0 , 1 ]))
762
+ self .assertFalse (pytree .is_namedtuple ({0 : 1 , 1 : 2 }))
763
+ self .assertFalse (pytree .is_namedtuple ({0 , 1 }))
764
+ self .assertFalse (pytree .is_namedtuple (1 ))
765
+
766
+ self .assertTrue (pytree .is_namedtuple (DirectNamedTuple1 ))
767
+ self .assertTrue (pytree .is_namedtuple (DirectNamedTuple2 ))
768
+ self .assertTrue (pytree .is_namedtuple (IndirectNamedTuple1 ))
769
+ self .assertTrue (pytree .is_namedtuple (IndirectNamedTuple2 ))
770
+ self .assertFalse (pytree .is_namedtuple (time .struct_time ))
771
+ self .assertFalse (pytree .is_namedtuple (tuple ))
772
+ self .assertFalse (pytree .is_namedtuple (list ))
773
+
774
+ self .assertTrue (pytree .is_namedtuple_class (DirectNamedTuple1 ))
775
+ self .assertTrue (pytree .is_namedtuple_class (DirectNamedTuple2 ))
776
+ self .assertTrue (pytree .is_namedtuple_class (IndirectNamedTuple1 ))
777
+ self .assertTrue (pytree .is_namedtuple_class (IndirectNamedTuple2 ))
778
+ self .assertFalse (pytree .is_namedtuple_class (time .struct_time ))
779
+ self .assertFalse (pytree .is_namedtuple_class (tuple ))
780
+ self .assertFalse (pytree .is_namedtuple_class (list ))
781
+
782
+ @parametrize (
783
+ "pytree" ,
784
+ [
785
+ subtest (py_pytree , name = "py" ),
786
+ subtest (cxx_pytree , name = "cxx" ),
787
+ ],
788
+ )
789
+ def test_is_structseq (self , pytree ):
790
+ class FakeStructSeq (tuple ):
791
+ n_fields = 2
792
+ n_sequence_fields = 2
793
+ n_unnamed_fields = 0
794
+
795
+ __slots__ = ()
796
+ __match_args__ = ("x" , "y" )
797
+
798
+ def __new__ (cls , sequence ):
799
+ return super ().__new__ (cls , sequence )
800
+
801
+ @property
802
+ def x (self ):
803
+ return self [0 ]
804
+
805
+ @property
806
+ def y (self ):
807
+ return self [1 ]
808
+
809
+ DirectNamedTuple1 = namedtuple ("DirectNamedTuple1" , ["x" , "y" ])
810
+
811
+ class DirectNamedTuple2 (NamedTuple ):
812
+ x : int
813
+ y : int
814
+
815
+ self .assertFalse (pytree .is_structseq (FakeStructSeq ((0 , 1 ))))
816
+ self .assertTrue (pytree .is_structseq (time .gmtime ()))
817
+ self .assertFalse (pytree .is_structseq (DirectNamedTuple1 (0 , 1 )))
818
+ self .assertFalse (pytree .is_structseq (DirectNamedTuple2 (0 , 1 )))
819
+ self .assertFalse (pytree .is_structseq ((0 , 1 )))
820
+ self .assertFalse (pytree .is_structseq ([0 , 1 ]))
821
+ self .assertFalse (pytree .is_structseq ({0 : 1 , 1 : 2 }))
822
+ self .assertFalse (pytree .is_structseq ({0 , 1 }))
823
+ self .assertFalse (pytree .is_structseq (1 ))
824
+
825
+ self .assertFalse (pytree .is_structseq (FakeStructSeq ))
826
+ self .assertTrue (pytree .is_structseq (time .struct_time ))
827
+ self .assertFalse (pytree .is_structseq (DirectNamedTuple1 ))
828
+ self .assertFalse (pytree .is_structseq (DirectNamedTuple2 ))
829
+ self .assertFalse (pytree .is_structseq (tuple ))
830
+ self .assertFalse (pytree .is_structseq (list ))
831
+
832
+ self .assertFalse (pytree .is_structseq_class (FakeStructSeq ))
833
+ self .assertTrue (
834
+ pytree .is_structseq_class (time .struct_time ),
835
+ )
836
+ self .assertFalse (pytree .is_structseq_class (DirectNamedTuple1 ))
837
+ self .assertFalse (pytree .is_structseq_class (DirectNamedTuple2 ))
838
+ self .assertFalse (pytree .is_structseq_class (tuple ))
839
+ self .assertFalse (pytree .is_structseq_class (list ))
840
+
841
+ # torch.return_types.* are all PyStructSequence types
842
+ for cls in vars (torch .return_types ).values ():
843
+ if isinstance (cls , type ) and issubclass (cls , tuple ):
844
+ self .assertTrue (pytree .is_structseq (cls ))
845
+ self .assertTrue (pytree .is_structseq_class (cls ))
846
+ self .assertFalse (pytree .is_namedtuple (cls ))
847
+ self .assertFalse (pytree .is_namedtuple_class (cls ))
848
+
849
+ inst = cls (range (cls .n_sequence_fields ))
850
+ self .assertTrue (pytree .is_structseq (inst ))
851
+ self .assertTrue (pytree .is_structseq (type (inst )))
852
+ self .assertFalse (pytree .is_structseq_class (inst ))
853
+ self .assertTrue (pytree .is_structseq_class (type (inst )))
854
+ self .assertFalse (pytree .is_namedtuple (inst ))
855
+ self .assertFalse (pytree .is_namedtuple_class (inst ))
856
+ else :
857
+ self .assertFalse (pytree .is_structseq (cls ))
858
+ self .assertFalse (pytree .is_structseq_class (cls ))
859
+ self .assertFalse (pytree .is_namedtuple (cls ))
860
+ self .assertFalse (pytree .is_namedtuple_class (cls ))
861
+
734
862
735
863
class TestPythonPytree (TestCase ):
736
864
def test_deprecated_register_pytree_node (self ):
@@ -975,9 +1103,8 @@ def test_pytree_serialize_namedtuple(self):
975
1103
serialized_type_name = "test_pytree.test_pytree_serialize_namedtuple.Point1" ,
976
1104
)
977
1105
978
- spec = py_pytree .TreeSpec (
979
- namedtuple , Point1 , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
980
- )
1106
+ spec = py_pytree .tree_structure (Point1 (1 , 2 ))
1107
+ self .assertIs (spec .type , namedtuple )
981
1108
roundtrip_spec = py_pytree .treespec_loads (py_pytree .treespec_dumps (spec ))
982
1109
self .assertEqual (spec , roundtrip_spec )
983
1110
@@ -990,18 +1117,28 @@ class Point2(NamedTuple):
990
1117
serialized_type_name = "test_pytree.test_pytree_serialize_namedtuple.Point2" ,
991
1118
)
992
1119
993
- spec = py_pytree .TreeSpec (
994
- namedtuple , Point2 , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
1120
+ spec = py_pytree .tree_structure (Point2 (1 , 2 ))
1121
+ self .assertIs (spec .type , namedtuple )
1122
+ roundtrip_spec = py_pytree .treespec_loads (py_pytree .treespec_dumps (spec ))
1123
+ self .assertEqual (spec , roundtrip_spec )
1124
+
1125
+ class Point3 (Point2 ):
1126
+ pass
1127
+
1128
+ py_pytree ._register_namedtuple (
1129
+ Point3 ,
1130
+ serialized_type_name = "test_pytree.test_pytree_serialize_namedtuple.Point3" ,
995
1131
)
1132
+
1133
+ spec = py_pytree .tree_structure (Point3 (1 , 2 ))
1134
+ self .assertIs (spec .type , namedtuple )
996
1135
roundtrip_spec = py_pytree .treespec_loads (py_pytree .treespec_dumps (spec ))
997
1136
self .assertEqual (spec , roundtrip_spec )
998
1137
999
1138
def test_pytree_serialize_namedtuple_bad (self ):
1000
1139
DummyType = namedtuple ("DummyType" , ["x" , "y" ])
1001
1140
1002
- spec = py_pytree .TreeSpec (
1003
- namedtuple , DummyType , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
1004
- )
1141
+ spec = py_pytree .tree_structure (DummyType (1 , 2 ))
1005
1142
1006
1143
with self .assertRaisesRegex (
1007
1144
NotImplementedError , "Please register using `_register_namedtuple`"
@@ -1020,9 +1157,7 @@ def __init__(self, x, y):
1020
1157
lambda xs , _ : DummyType (* xs ),
1021
1158
)
1022
1159
1023
- spec = py_pytree .TreeSpec (
1024
- DummyType , None , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
1025
- )
1160
+ spec = py_pytree .tree_structure (DummyType (1 , 2 ))
1026
1161
with self .assertRaisesRegex (
1027
1162
NotImplementedError , "No registered serialization name"
1028
1163
):
@@ -1042,9 +1177,7 @@ def __init__(self, x, y):
1042
1177
to_dumpable_context = lambda context : "moo" ,
1043
1178
from_dumpable_context = lambda dumpable_context : None ,
1044
1179
)
1045
- spec = py_pytree .TreeSpec (
1046
- DummyType , None , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
1047
- )
1180
+ spec = py_pytree .tree_structure (DummyType (1 , 2 ))
1048
1181
serialized_spec = py_pytree .treespec_dumps (spec , 1 )
1049
1182
self .assertIn ("moo" , serialized_spec )
1050
1183
roundtrip_spec = py_pytree .treespec_loads (serialized_spec )
@@ -1082,9 +1215,7 @@ def __init__(self, x, y):
1082
1215
from_dumpable_context = lambda dumpable_context : None ,
1083
1216
)
1084
1217
1085
- spec = py_pytree .TreeSpec (
1086
- DummyType , None , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
1087
- )
1218
+ spec = py_pytree .tree_structure (DummyType (1 , 2 ))
1088
1219
1089
1220
with self .assertRaisesRegex (
1090
1221
TypeError , "Object of type type is not JSON serializable"
@@ -1095,9 +1226,7 @@ def test_pytree_serialize_bad_protocol(self):
1095
1226
import json
1096
1227
1097
1228
Point = namedtuple ("Point" , ["x" , "y" ])
1098
- spec = py_pytree .TreeSpec (
1099
- namedtuple , Point , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
1100
- )
1229
+ spec = py_pytree .tree_structure (Point (1 , 2 ))
1101
1230
py_pytree ._register_namedtuple (
1102
1231
Point ,
1103
1232
serialized_type_name = "test_pytree.test_pytree_serialize_bad_protocol.Point" ,
0 commit comments