@@ -65,9 +65,6 @@ class TestEnum(enum.Enum):
65
65
A = auto ()
66
66
67
67
68
- python_leafspec = python_pytree .LeafSpec ()
69
-
70
-
71
68
class TestGenericPytree (TestCase ):
72
69
def test_aligned_public_apis (self ):
73
70
public_apis = python_pytree .__all__
@@ -197,7 +194,7 @@ def test_flatten_unflatten_leaf(self, pytree):
197
194
def run_test_with_leaf (leaf ):
198
195
values , treespec = pytree .tree_flatten (leaf )
199
196
self .assertEqual (values , [leaf ])
200
- self .assertEqual (treespec , pytree .LeafSpec ())
197
+ self .assertEqual (treespec , pytree .treespec_leaf ())
201
198
202
199
unflattened = pytree .tree_unflatten (values , treespec )
203
200
self .assertEqual (unflattened , leaf )
@@ -215,7 +212,7 @@ def run_test_with_leaf(leaf):
215
212
(
216
213
python_pytree ,
217
214
lambda tup : python_pytree .TreeSpec (
218
- tuple , None , [python_leafspec for _ in tup ]
215
+ tuple , None , [python_pytree . treespec_leaf () for _ in tup ]
219
216
),
220
217
),
221
218
name = "python" ,
@@ -250,7 +247,7 @@ def run_test(tup):
250
247
(
251
248
python_pytree ,
252
249
lambda lst : python_pytree .TreeSpec (
253
- list , None , [python_leafspec for _ in lst ]
250
+ list , None , [python_pytree . treespec_leaf () for _ in lst ]
254
251
),
255
252
),
256
253
name = "python" ,
@@ -286,7 +283,7 @@ def run_test(lst):
286
283
lambda dct : python_pytree .TreeSpec (
287
284
dict ,
288
285
list (dct .keys ()),
289
- [python_leafspec for _ in dct .values ()],
286
+ [python_pytree . treespec_leaf () for _ in dct .values ()],
290
287
),
291
288
),
292
289
name = "python" ,
@@ -327,7 +324,7 @@ def run_test(dct):
327
324
lambda odict : python_pytree .TreeSpec (
328
325
OrderedDict ,
329
326
list (odict .keys ()),
330
- [python_leafspec for _ in odict .values ()],
327
+ [python_pytree . treespec_leaf () for _ in odict .values ()],
331
328
),
332
329
),
333
330
name = "python" ,
@@ -371,7 +368,7 @@ def run_test(odict):
371
368
lambda ddct : python_pytree .TreeSpec (
372
369
defaultdict ,
373
370
[ddct .default_factory , list (ddct .keys ())],
374
- [python_leafspec for _ in ddct .values ()],
371
+ [python_pytree . treespec_leaf () for _ in ddct .values ()],
375
372
),
376
373
),
377
374
name = "python" ,
@@ -413,7 +410,7 @@ def run_test(ddct):
413
410
(
414
411
python_pytree ,
415
412
lambda deq : python_pytree .TreeSpec (
416
- deque , deq .maxlen , [python_leafspec for _ in deq ]
413
+ deque , deq .maxlen , [python_pytree . treespec_leaf () for _ in deq ]
417
414
),
418
415
),
419
416
name = "python" ,
@@ -453,7 +450,7 @@ def test_flatten_unflatten_namedtuple(self, pytree):
453
450
def run_test (tup ):
454
451
if pytree is python_pytree :
455
452
expected_spec = python_pytree .TreeSpec (
456
- namedtuple , Point , [python_leafspec for _ in tup ]
453
+ namedtuple , Point , [python_pytree . treespec_leaf () for _ in tup ]
457
454
)
458
455
else :
459
456
expected_spec = cxx_pytree .tree_structure (Point (0 , 1 ))
@@ -839,16 +836,16 @@ def test_import_pytree_doesnt_import_optree(self):
839
836
840
837
def test_treespec_equality (self ):
841
838
self .assertEqual (
842
- python_pytree .LeafSpec (),
843
- python_pytree .LeafSpec (),
839
+ python_pytree .treespec_leaf (),
840
+ python_pytree .treespec_leaf (),
844
841
)
845
842
self .assertEqual (
846
843
python_pytree .TreeSpec (list , None , []),
847
844
python_pytree .TreeSpec (list , None , []),
848
845
)
849
846
self .assertEqual (
850
- python_pytree .TreeSpec (list , None , [python_pytree .LeafSpec ()]),
851
- python_pytree .TreeSpec (list , None , [python_pytree .LeafSpec ()]),
847
+ python_pytree .TreeSpec (list , None , [python_pytree .treespec_leaf ()]),
848
+ python_pytree .TreeSpec (list , None , [python_pytree .treespec_leaf ()]),
852
849
)
853
850
self .assertFalse (
854
851
python_pytree .TreeSpec (tuple , None , [])
@@ -883,24 +880,32 @@ def test_treespec_repr(self):
883
880
# python_pytree.tree_structure({})
884
881
python_pytree .TreeSpec (dict , [], []),
885
882
# python_pytree.tree_structure([0])
886
- python_pytree .TreeSpec (list , None , [python_leafspec ]),
883
+ python_pytree .TreeSpec (list , None , [python_pytree . treespec_leaf () ]),
887
884
# python_pytree.tree_structure([0, 1])
888
885
python_pytree .TreeSpec (
889
886
list ,
890
887
None ,
891
- [python_leafspec , python_leafspec ],
888
+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
892
889
),
893
890
# python_pytree.tree_structure((0, 1, 2))
894
891
python_pytree .TreeSpec (
895
892
tuple ,
896
893
None ,
897
- [python_leafspec , python_leafspec , python_leafspec ],
894
+ [
895
+ python_pytree .treespec_leaf (),
896
+ python_pytree .treespec_leaf (),
897
+ python_pytree .treespec_leaf (),
898
+ ],
898
899
),
899
900
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
900
901
python_pytree .TreeSpec (
901
902
dict ,
902
903
["a" , "b" , "c" ],
903
- [python_leafspec , python_leafspec , python_leafspec ],
904
+ [
905
+ python_pytree .treespec_leaf (),
906
+ python_pytree .treespec_leaf (),
907
+ python_pytree .treespec_leaf (),
908
+ ],
904
909
),
905
910
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
906
911
python_pytree .TreeSpec (
@@ -910,13 +915,17 @@ def test_treespec_repr(self):
910
915
python_pytree .TreeSpec (
911
916
tuple ,
912
917
None ,
913
- [python_leafspec , python_leafspec ],
918
+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
914
919
),
915
- python_leafspec ,
920
+ python_pytree . treespec_leaf () ,
916
921
python_pytree .TreeSpec (
917
922
dict ,
918
923
["a" , "b" , "c" ],
919
- [python_leafspec , python_leafspec , python_leafspec ],
924
+ [
925
+ python_pytree .treespec_leaf (),
926
+ python_pytree .treespec_leaf (),
927
+ python_pytree .treespec_leaf (),
928
+ ],
920
929
),
921
930
],
922
931
),
@@ -929,12 +938,15 @@ def test_treespec_repr(self):
929
938
tuple ,
930
939
None ,
931
940
[
932
- python_leafspec ,
933
- python_leafspec ,
941
+ python_pytree . treespec_leaf () ,
942
+ python_pytree . treespec_leaf () ,
934
943
python_pytree .TreeSpec (
935
944
list ,
936
945
None ,
937
- [python_leafspec , python_leafspec ],
946
+ [
947
+ python_pytree .treespec_leaf (),
948
+ python_pytree .treespec_leaf (),
949
+ ],
938
950
),
939
951
],
940
952
),
@@ -948,12 +960,12 @@ def test_treespec_repr(self):
948
960
python_pytree .TreeSpec (
949
961
list ,
950
962
None ,
951
- [python_leafspec , python_leafspec ],
963
+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
952
964
),
953
965
python_pytree .TreeSpec (
954
966
list ,
955
967
None ,
956
- [python_leafspec , python_leafspec ],
968
+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
957
969
),
958
970
python_pytree .TreeSpec (dict , [], []),
959
971
],
@@ -962,7 +974,7 @@ def test_treespec_repr(self):
962
974
python_pytree .TreeSpec (
963
975
python_pytree .structseq ,
964
976
torch .return_types .sort ,
965
- [python_leafspec , python_leafspec ],
977
+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
966
978
),
967
979
],
968
980
)
@@ -988,7 +1000,7 @@ def test_pytree_serialize_defaultdict_enum(self):
988
1000
list ,
989
1001
None ,
990
1002
[
991
- python_leafspec ,
1003
+ python_pytree . treespec_leaf () ,
992
1004
],
993
1005
),
994
1006
],
@@ -997,7 +1009,7 @@ def test_pytree_serialize_defaultdict_enum(self):
997
1009
self .assertIsInstance (serialized_spec , str )
998
1010
999
1011
def test_pytree_serialize_enum (self ):
1000
- spec = python_pytree .TreeSpec (dict , TestEnum .A , [python_leafspec ])
1012
+ spec = python_pytree .TreeSpec (dict , TestEnum .A , [python_pytree . treespec_leaf () ])
1001
1013
1002
1014
serialized_spec = python_pytree .treespec_dumps (spec )
1003
1015
self .assertIsInstance (serialized_spec , str )
@@ -1160,12 +1172,20 @@ def test_saved_serialized(self):
1160
1172
OrderedDict ,
1161
1173
[1 , 2 , 3 ],
1162
1174
[
1163
- python_pytree .TreeSpec (tuple , None , [python_leafspec , python_leafspec ]),
1164
- python_leafspec ,
1175
+ python_pytree .TreeSpec (
1176
+ tuple ,
1177
+ None ,
1178
+ [python_pytree .treespec_leaf (), python_pytree .treespec_leaf ()],
1179
+ ),
1180
+ python_pytree .treespec_leaf (),
1165
1181
python_pytree .TreeSpec (
1166
1182
dict ,
1167
1183
[4 , 5 , 6 ],
1168
- [python_leafspec , python_leafspec , python_leafspec ],
1184
+ [
1185
+ python_pytree .treespec_leaf (),
1186
+ python_pytree .treespec_leaf (),
1187
+ python_pytree .treespec_leaf (),
1188
+ ],
1169
1189
),
1170
1190
],
1171
1191
)
@@ -1450,7 +1470,7 @@ def setUp(self):
1450
1470
raise unittest .SkipTest ("C++ pytree tests are not supported in fbcode" )
1451
1471
1452
1472
def test_treespec_equality (self ):
1453
- self .assertEqual (cxx_pytree .LeafSpec (), cxx_pytree .LeafSpec ())
1473
+ self .assertEqual (cxx_pytree .treespec_leaf (), cxx_pytree .treespec_leaf ())
1454
1474
1455
1475
def test_treespec_repr (self ):
1456
1476
# Check that it looks sane
0 commit comments