@@ -65,9 +65,6 @@ class TestEnum(enum.Enum):
65
65
A = auto ()
66
66
67
67
68
- python_leafspec = python_pytree .LeafSpec ()
69
-
70
-
71
68
class TestGenericPytree (TestCase ):
72
69
def test_aligned_public_apis (self ):
73
70
public_apis = python_pytree .__all__
@@ -197,7 +194,7 @@ def test_flatten_unflatten_leaf(self, pytree):
197
194
def run_test_with_leaf (leaf ):
198
195
values , treespec = pytree .tree_flatten (leaf )
199
196
self .assertEqual (values , [leaf ])
200
- self .assertEqual (treespec , pytree .LeafSpec ())
197
+ self .assertEqual (treespec , pytree .treespec_leaf ())
201
198
202
199
unflattened = pytree .tree_unflatten (values , treespec )
203
200
self .assertEqual (unflattened , leaf )
@@ -215,7 +212,7 @@ def run_test_with_leaf(leaf):
215
212
(
216
213
python_pytree ,
217
214
lambda tup : python_pytree .TreeSpec (
218
- tuple , None , [python_leafspec for _ in tup ]
215
+ tuple , None , [python_pytree . treespec_leaf () for _ in tup ]
219
216
),
220
217
),
221
218
name = "python" ,
@@ -254,7 +251,7 @@ def run_test(tup):
254
251
(
255
252
python_pytree ,
256
253
lambda lst : python_pytree .TreeSpec (
257
- list , None , [python_leafspec for _ in lst ]
254
+ list , None , [python_pytree . treespec_leaf () for _ in lst ]
258
255
),
259
256
),
260
257
name = "python" ,
@@ -294,7 +291,7 @@ def run_test(lst):
294
291
lambda dct : python_pytree .TreeSpec (
295
292
dict ,
296
293
list (dct .keys ()),
297
- [python_leafspec for _ in dct .values ()],
294
+ [python_pytree . treespec_leaf () for _ in dct .values ()],
298
295
),
299
296
),
300
297
name = "python" ,
@@ -342,7 +339,7 @@ def run_test(dct):
342
339
lambda odict : python_pytree .TreeSpec (
343
340
OrderedDict ,
344
341
list (odict .keys ()),
345
- [python_leafspec for _ in odict .values ()],
342
+ [python_pytree . treespec_leaf () for _ in odict .values ()],
346
343
),
347
344
),
348
345
name = "python" ,
@@ -393,7 +390,7 @@ def run_test(odict):
393
390
lambda ddct : python_pytree .TreeSpec (
394
391
defaultdict ,
395
392
[ddct .default_factory , list (ddct .keys ())],
396
- [python_leafspec for _ in ddct .values ()],
393
+ [python_pytree . treespec_leaf () for _ in ddct .values ()],
397
394
),
398
395
),
399
396
name = "python" ,
@@ -444,7 +441,7 @@ def run_test(ddct):
444
441
(
445
442
python_pytree ,
446
443
lambda deq : python_pytree .TreeSpec (
447
- deque , deq .maxlen , [python_leafspec for _ in deq ]
444
+ deque , deq .maxlen , [python_pytree . treespec_leaf () for _ in deq ]
448
445
),
449
446
),
450
447
name = "python" ,
@@ -491,7 +488,7 @@ def test_flatten_unflatten_namedtuple(self, pytree):
491
488
def run_test (tup ):
492
489
if pytree is python_pytree :
493
490
expected_spec = python_pytree .TreeSpec (
494
- namedtuple , Point , [python_leafspec for _ in tup ]
491
+ namedtuple , Point , [python_pytree . treespec_leaf () for _ in tup ]
495
492
)
496
493
else :
497
494
expected_spec = cxx_pytree .tree_structure (Point (0 , 1 ))
@@ -877,16 +874,16 @@ def test_import_pytree_doesnt_import_optree(self):
877
874
878
875
def test_treespec_equality (self ):
879
876
self .assertEqual (
880
- python_pytree .LeafSpec (),
881
- python_pytree .LeafSpec (),
877
+ python_pytree .treespec_leaf (),
878
+ python_pytree .treespec_leaf (),
882
879
)
883
880
self .assertEqual (
884
881
python_pytree .TreeSpec (list , None , []),
885
882
python_pytree .TreeSpec (list , None , []),
886
883
)
887
884
self .assertEqual (
888
- python_pytree .TreeSpec (list , None , [python_pytree .LeafSpec ()]),
889
- python_pytree .TreeSpec (list , None , [python_pytree .LeafSpec ()]),
885
+ python_pytree .TreeSpec (list , None , [python_pytree .treespec_leaf ()]),
886
+ python_pytree .TreeSpec (list , None , [python_pytree .treespec_leaf ()]),
890
887
)
891
888
self .assertFalse (
892
889
python_pytree .TreeSpec (tuple , None , [])
@@ -921,24 +918,32 @@ def test_treespec_repr(self):
921
918
# python_pytree.tree_structure({})
922
919
python_pytree .TreeSpec (dict , [], []),
923
920
# python_pytree.tree_structure([0])
924
- python_pytree .TreeSpec (list , None , [python_leafspec ]),
921
+ python_pytree .TreeSpec (list , None , [python_pytree . treespec_leaf () ]),
925
922
# python_pytree.tree_structure([0, 1])
926
923
python_pytree .TreeSpec (
927
924
list ,
928
925
None ,
929
- [python_leafspec , python_leafspec ],
926
+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
930
927
),
931
928
# python_pytree.tree_structure((0, 1, 2))
932
929
python_pytree .TreeSpec (
933
930
tuple ,
934
931
None ,
935
- [python_leafspec , python_leafspec , python_leafspec ],
932
+ [
933
+ python_pytree .treespec_leaf (),
934
+ python_pytree .treespec_leaf (),
935
+ python_pytree .treespec_leaf (),
936
+ ],
936
937
),
937
938
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
938
939
python_pytree .TreeSpec (
939
940
dict ,
940
941
["a" , "b" , "c" ],
941
- [python_leafspec , python_leafspec , python_leafspec ],
942
+ [
943
+ python_pytree .treespec_leaf (),
944
+ python_pytree .treespec_leaf (),
945
+ python_pytree .treespec_leaf (),
946
+ ],
942
947
),
943
948
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
944
949
python_pytree .TreeSpec (
@@ -948,13 +953,17 @@ def test_treespec_repr(self):
948
953
python_pytree .TreeSpec (
949
954
tuple ,
950
955
None ,
951
- [python_leafspec , python_leafspec ],
956
+ [python_pytree . treespec_leaf (), python_pytree . treespec_l
10000
eaf () ],
952
957
),
953
- python_leafspec ,
958
+ python_pytree . treespec_leaf () ,
954
959
python_pytree .TreeSpec (
955
960
dict ,
956
961
["a" , "b" , "c" ],
957
- [python_leafspec , python_leafspec , python_leafspec ],
962
+ [
963
+ python_pytree .treespec_leaf (),
964
+ python_pytree .treespec_leaf (),
965
+ python_pytree .treespec_leaf (),
966
+ ],
958
967
),
959
968
],
960
969
),
@@ -967,12 +976,15 @@ def test_treespec_repr(self):
967
976
tuple ,
968
977
None ,
969
978
[
970
- python_leafspec ,
971
- python_leafspec ,
979
+ python_pytree . treespec_leaf () ,
980
+ python_pytree . treespec_leaf () ,
972
981
python_pytree .TreeSpec (
973
982
list ,
974
983
None ,
975
- [python_leafspec , python_leafspec ],
984
+ [
985
+ python_pytree .treespec_leaf (),
986
+ python_pytree .treespec_leaf (),
987
+ ],
976
988
),
977
989
],
978
990
),
@@ -986,12 +998,12 @@ def test_treespec_repr(self):
986
998
python_pytree .TreeSpec (
987
999
list ,
988
1000
None ,
989
- [python_leafspec , python_leafspec ],
1001
+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
990
1002
),
991
1003
python_pytree .TreeSpec (
992
1004
list ,
993
1005
None ,
994
- [python_leafspec , python_leafspec ],
1006
+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
995
1007
),
996
1008
python_pytree .TreeSpec (dict , [], []),
997
1009
],
@@ -1000,7 +1012,7 @@ def test_treespec_repr(self):
1000
1012
python_pytree .TreeSpec (
1001
1013
python_pytree .structseq ,
1002
1014
torch .return_types .sort ,
1003
- [python_leafspec , python_leafspec ],
1015
+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
1004
1016
),
1005
1017
],
1006
1018
)
@@ -1026,7 +1038,7 @@ def test_pytree_serialize_defaultdict_enum(self):
1026
1038
list ,
1027
1039
None ,
1028
1040
[
1029
- python_leafspec ,
1041
+ python_pytree . treespec_leaf () ,
1030
1042
],
1031
1043
),
1032
1044
],
@@ -1035,7 +1047,7 @@ def test_pytree_serialize_defaultdict_enum(self):
1035
1047
self .assertIsInstance (serialized_spec , str )
1036
1048
1037
1049
def test_pytree_serialize_enum (self ):
1038
- spec = python_pytree .TreeSpec (dict , TestEnum .A , [python_leafspec ])
1050
+ spec = python_pytree .TreeSpec (dict , TestEnum .A , [python_pytree . treespec_leaf () ])
1039
1051
1040
1052
serialized_spec = python_pytree .treespec_dumps (spec )
1041
1053
self .assertIsInstance (serialized_spec , str )
@@ -1198,12 +1210,20 @@ def test_saved_serialized(self):
1198
1210
OrderedDict ,
1199
1211
[1 , 2 , 3 ],
1200
1212
[
1201
- python_pytree .TreeSpec (tuple , None , [python_leafspec , python_leafspec ]),
1202
- python_leafspec ,
1213
+ python_pytree .TreeSpec (
1214
+ tuple ,
1215
+ None ,
1216
+ [python_pytree .treespec_leaf (), python_pytree .treespec_leaf ()],
1217
+ ),
1218
+ python_pytree .treespec_leaf (),
1203
1219
python_pytree .TreeSpec (
1204
1220
dict ,
1205
1221
[4 , 5 , 6 ],
1206
- [python_leafspec , python_leafspec , python_leafspec ],
1222
+ [
1223
+ python_pytree .treespec_leaf (),
1224
+ python_pytree .treespec_leaf (),
1225
+ python_pytree .treespec_leaf (),
1226
+ ],
1207
1227
),
1208
1228
],
1209
1229
)
@@ -1488,7 +1508,7 @@ def setUp(self):
1488
1508
raise unittest .SkipTest ("C++ pytree tests are not supported in fbcode" )
1489
1509
1490
1510
def test_treespec_equality (self ):
1491
- self .assertEqual (cxx_pytree .LeafSpec (), cxx_pytree .LeafSpec ())
1511
+ self .assertEqual (cxx_pytree .treespec_leaf (), cxx_pytree .treespec_leaf ())
1492
1512
1493
1513
def test_treespec_repr (self ):
1494
1514
# Check that it looks sane
0 commit comments