@@ -720,6 +720,236 @@ def test_broadcast_to_and_flatten(self, pytree_impl):
720
720
result = pytree_impl ._broadcast_to_and_flatten (pytree , to_spec )
721
721
self .assertEqual (result , expected , msg = str ([pytree , to_spec , expected ]))
722
722
723
+ @parametrize (
724
+ "pytree_impl" ,
725
+ [
726
+ subtest (py_pytree , name = "py" ),
727
+ subtest (cxx_pytree , name = "cxx" ),
728
+ ],
729
+ )
730
+ def test_tree_map_with_path (self , pytree_impl ):
731
+ tree = [{i : i for i in range (10 )}]
732
+ all_zeros = pytree_impl .tree_map_with_path (
733
+ lambda kp , val : val - kp [1 ].key + kp [0 ].idx , tree
734
+ )
735
+ self .assertEqual (all_zeros , [dict .fromkeys (range (10 ), 0 )])
736
+
737
+ @parametrize (
738
+ "pytree_impl" ,
739
+ [
740
+ subtest (py_pytree , name = "py" ),
741
+ subtest (cxx_pytree , name = "cxx" ),
742
+ ],
743
+ )
744
+ def test_tree_map_with_path_multiple_trees (self , pytree_impl ):
745
+ @dataclass
746
+ class ACustomPytree :
747
+ x : Any
748
+ y : Any
749
+ z : Any
750
+
751
+ tree1 = [ACustomPytree (x = 12 , y = {"cin" : [1 , 4 , 10 ], "bar" : 18 }, z = "leaf" ), 5 ]
752
+ tree2 = [ACustomPytree (x = 2 , y = {"cin" : [2 , 2 , 2 ], "bar" : 2 }, z = "leaf" ), 2 ]
753
+
754
+ pytree_impl .register_pytree_node (
755
+ ACustomPytree ,
756
+ flatten_fn = lambda f : ([f .x , f .y ], f .z ),
757
+ unflatten_fn = lambda xy , z : ACustomPytree (xy [0 ], xy [1 ], z ),
758
+ flatten_with_keys_fn = lambda f : (
759
+ (
760
+ (pytree_impl .GetAttrKey ("x" ), f .x ),
761
+ (pytree_impl .GetAttrKey ("y" ), f .y ),
762
+ ),
763
+ f .z ,
764
+ ),
765
+ )
766
+ from_two_trees = pytree_impl .tree_map_with_path (
767
+ lambda kp , a , b : a + b , tree1 , tree2
768
+ )
769
+ from_one_tree = pytree_impl .tree_map (lambda a : a + 2 , tree1 )
770
+ self .assertEqual (from_two_trees , from_one_tree )
771
+
772
+ @skipIfTorchDynamo ("dynamo pytree tracing doesn't work here" )
773
+ @parametrize (
774
+ "pytree_impl" ,
775
+ [
776
+ subtest (py_pytree , name = "py" ),
777
+ subtest (cxx_pytree , name = "cxx" ),
778
+ ],
779
+ )
780
+ def test_tree_flatten_with_path_is_leaf (self , pytree_impl ):
781
+ leaf_dict = {"foo" : [(3 )]}
782
+ pytree = (["hello" , [1 , 2 ], leaf_dict ],)
783
+ key_leaves , spec = pytree_impl .tree_flatten_with_path (
784
+ pytree , is_leaf = lambda x : isinstance (x , dict )
785
+ )
786
+ self .assertTrue (key_leaves [- 1 ][1 ] is leaf_dict )
787
+
788
+ @parametrize (
789
+ "pytree_impl" ,
790
+ [
791
+ subtest (py_pytree , name = "py" ),
792
+ subtest (cxx_pytree , name = "cxx" ),
793
+ ],
794
+ )
795
+ def test_tree_flatten_with_path_roundtrip (self , pytree_impl ):
796
+ class ANamedTuple (NamedTuple ):
797
+ x : torch .Tensor
798
+ y : int
799
+ z : str
800
+
801
+ @dataclass
802
+ class ACustomPytree :
803
+ x : Any
804
+ y : Any
805
+ z : Any
806
+
807
+ pytree_impl .register_pytree_node (
808
+ ACustomPytree ,
809
+ flatten_fn = lambda f : ([f .x , f .y ], f .z ),
810
+ unflatten_fn = lambda xy , z : ACustomPytree (xy [0 ], xy [1 ], z ),
811
+ flatten_with_keys_fn = lambda f : (
812
+ (
813
+ (pytree_impl .GetAttrKey ("x" ), f .x ),
814
+ (pytree_impl .GetAttrKey ("y" ), f .y ),
815
+ ),
816
+ f .z ,
817
+ ),
818
+ )
819
+
820
+ SOME_PYTREES = [
821
+ (None ,),
822
+ ["hello" , [1 , 2 ], {"foo" : [(3 )]}],
823
+ [ANamedTuple (x = torch .rand (2 , 3 ), y = 1 , z = "foo" )],
824
+ [ACustomPytree (x = 12 , y = {"cin" : [1 , 4 , 10 ], "bar" : 18 }, z = "leaf" ), 5 ],
825
+ ]
826
+ for pytree in SOME_PYTREES :
827
+ key_leaves , spec = pytree_impl .tree_flatten_with_path (pytree )
828
+ actual = pytree_impl .tree_unflatten ([leaf for _ , leaf in key_leaves ], spec )
829
+ self .assertEqual (actual , pytree )
830
+
831
+ @parametrize (
832
+ "pytree_impl" ,
833
+ [
834
+ subtest (py_pytree , name = "py" ),
835
+ subtest (cxx_pytree , name = "cxx" ),
836
+ ],
837
+ )
838
+ def test_tree_leaves_with_path (self , pytree_impl ):
839
+ class ANamedTuple (NamedTuple ):
840
+ x : torch .Tensor
841
+ y : int
842
+ z : str
843
+
844
+ @dataclass
845
+ class ACustomPytree :
846
+ x : Any
847
+ y : Any
848
+ z : Any
849
+
850
+ pytree_impl .register_pytree_node (
851
+ ACustomPytree ,
852
+ flatten_fn = lambda f : ([f .x , f .y ], f .z ),
853
+ unflatten_fn = lambda xy , z : ACustomPytree (xy [0 ], xy [1 ], z ),
854
+ flatten_with_keys_fn = lambda f : (
855
+ (
856
+ (pytree_impl .GetAttrKey ("x" ), f .x ),
857
+ (pytree_impl .GetAttrKey ("y" ), f .y ),
858
+ ),
859
+ f .z ,
860
+ ),
861
+ )
862
+
863
+ SOME_PYTREES = [
864
+ (None ,),
865
+ ["hello" , [1 , 2 ], {"foo" : [(3 )]}],
866
+ [ANamedTuple (x = torch .rand (2 , 3 ), y = 1 , z = "foo" )],
867
+ [ACustomPytree (x = 12 , y = {"cin" : [1 , 4 , 10 ], "bar" : 18 }, z = "leaf" ), 5 ],
868
+ ]
869
+ for pytree in
F42D
SOME_PYTREES :
870
+ flat_out , _ = pytree_impl .tree_flatten_with_path (pytree )
871
+ leaves_out = pytree_impl .tree_leaves_with_path (pytree )
872
+ self .assertEqual (flat_out , leaves_out )
873
+
874
+ @parametrize (
875
+ "pytree_impl" ,
876
+ [
877
+ subtest (py_pytree , name = "py" ),
878
+ subtest (cxx_pytree , name = "cxx" ),
879
+ ],
880
+ )
881
+ def test_key_str (self , pytree_impl ):
882
+ class ANamedTuple (NamedTuple ):
883
+ x : str
884
+ y : int
885
+
886
+ tree = (["hello" , [1 , 2 ], {"foo" : [(3 )], "bar" : [ANamedTuple (x = "baz" , y = 10 )]}],)
887
+ flat , _ = pytree_impl .tree_flatten_with_path (tree )
888
+ paths = [f"{ pytree_impl .keystr (kp )} : { val } " for kp , val in flat ]
889
+ self .assertEqual (
890
+ paths ,
891
+ [
892
+ "[0][0]: hello" ,
893
+ "[0][1][0]: 1" ,
894
+ "[0][1][1]: 2" ,
895
+ "[0][2]['foo'][0]: 3" ,
896
+ "[0][2]['bar'][0].x: baz" ,
897
+ "[0][2]['bar'][0].y: 10" ,
898
+ ],
899
+ )
900
+
901
+ @skipIfTorchDynamo ("AssertionError in dynamo" )
902
+ @parametrize (
903
+ "pytree_impl" ,
904
+ [
905
+ subtest (py_pytree , name = "py" ),
906
+ subtest (cxx_pytree , name = "cxx" ),
907
+ ],
908
+ )
909
+ def test_flatten_flatten_with_key_consistency (self , pytree_impl ):
910
+ """Check that flatten and flatten_with_key produces consistent leaves/context."""
911
+ reg = py_pytree .SUPPORTED_NODES
912
+
913
+ EXAMPLE_TREE = {
914
+ list : [1 , 2 , 3 ],
915
+ tuple : (1 , 2 , 3 ),
916
+ dict : {"foo" : 1 , "bar" : 2 },
917
+ namedtuple : collections .namedtuple ("ANamedTuple" , ["x" , "y" ])(1 , 2 ),
918
+ OrderedDict : OrderedDict ([("foo" , 1 ), ("bar" , 2 )]),
919
+ defaultdict : defaultdict (int , {"foo" : 1 , "bar" : 2 }),
920
+ deque : deque ([1 , 2 , 3 ]),
921
+ torch .Size : torch .Size ([1 , 2 , 3 ]),
922
+ immutable_dict : immutable_dict ({"foo" : 1 , "bar" : 2 }),
923
+ immutable_list : immutable_list ([1 , 2 , 3 ]),
924
+ }
925
+
926
+ for typ in reg :
927
+ example = EXAMPLE_TREE .get (typ )
928
+ if example is None :
929
+ continue
930
+ flat_with_path , spec1 = pytree_impl .tree_flatten_with_path (example )
931
+ flat , spec2 = pytree_impl .tree_flatten (example )
932
+
933
+ self .assertEqual (flat , [x [1 ] for x in flat_with_path ])
934
+ self .assertEqual (spec1 , spec2 )
935
+
936
+ @parametrize (
937
+ "pytree_impl" ,
938
+ [
939
+ subtest (py_pytree , name = "py" ),
940
+ subtest (cxx_pytree , name = "cxx" ),
941
+ ],
942
+ )
943
+ def test_key_access (self , pytree_impl ):
944
+ class ANamedTuple (NamedTuple ):
945
+ x : str
946
+ y : int
947
+
948
+ tree = (["hello" , [1 , 2 ], {"foo" : [(3 )], "bar" : [ANamedTuple (x = "baz" , y = 10 )]}],)
949
+ flat , _ = pytree_impl .tree_flatten_with_path (tree )
950
+ for kp , val in flat :
951
+ self .assertEqual (pytree_impl .key_get (tree , kp ), val )
952
+
723
953
@parametrize (
724
954
"pytree_impl" ,
725
955
[
@@ -1160,162 +1390,6 @@ def test_saved_serialized(self):
1160
1390
self .assertEqual (serialized_spec , saved_spec )
1161
1391
self .assertEqual (complicated_spec , py_pytree .treespec_loads (saved_spec ))
1162
1392
1163
- def test_tree_map_with_path (self ):
1164
- tree = [{i : i for i in range (10 )}]
1165
- all_zeros = py_pytree .tree_map_with_path (
1166
- lambda kp , val : val - kp [1 ].key + kp [0 ].idx , tree
1167
- )
1168
- self .assertEqual (all_zeros , [dict .fromkeys (range (10 ), 0 )])
1169
-
1170
- def test_tree_map_with_path_multiple_trees (self ):
1171
- @dataclass
1172
- class ACustomPytree :
1173
- x : Any
1174
- y : Any
1175
- z : Any
1176
-
1177
- tree1 = [ACustomPytree (x = 12 , y = {"cin" : [1 , 4 , 10 ], "bar" : 18 }, z = "leaf" ), 5 ]
1178
- tree2 = [ACustomPytree (x = 2 , y = {"cin" : [2 , 2 , 2 ], "bar" : 2 }, z = "leaf" ), 2 ]
1179
-
1180
- py_pytree .register_pytree_node (
1181
- ACustomPytree ,
1182
- flatten_fn = lambda f : ([f .x , f .y ], f .z ),
1183
- unflatten_fn = lambda xy , z : ACustomPytree (xy [0 ], xy [1 ], z ),
1184
- flatten_with_keys_fn = lambda f : ((("x" , f .x ), ("y" , f .y )), f .z ),
1185
- )
1186
- from_two_trees = py_pytree .tree_map_with_path (
1187
- lambda kp , a , b : a + b , tree1 , tree2
1188
- )
1189
- from_one_tree = py_pytree .tree_map (lambda a : a + 2 , tree1 )
1190
- self .assertEqual (from_two_trees , from_one_tree )
1191
-
1192
- @skipIfTorchDynamo ("dynamo pytree tracing doesn't work here" )
1193
- def test_tree_flatten_with_path_is_leaf (self ):
1194
- leaf_dict = {"foo" : [(3 )]}
1195
- pytree = (["hello" , [1 , 2 ], leaf_dict ],)
1196
- key_leaves , _ = py_pytree .tree_flatten_with_path (
1197
- pytree , is_leaf = lambda x : isinstance (x , dict )
1198
- )
1199
- self .assertTrue (key_leaves [- 1 ][1 ] is leaf_dict )
1200
-
1201
- def test_tree_flatten_with_path_roundtrip (self ):
1202
- class ANamedTuple (NamedTuple ):
1203
- x : torch .Tensor
1204
- y : int
1205
- z : str
1206
-
1207
- @dataclass
1208
- class ACustomPytree :
1209
- x : Any
1210
- y : Any
1211
- z : Any
1212
-
1213
- py_pytree .register_pytree_node (
1214
- ACustomPytree ,
1215
- flatten_fn = lambda f : ([f .x , f .y ], f .z ),
1216
- unflatten_fn = lambda xy , z : ACustomPytree (xy [0 ], xy [1 ], z ),
1217
- flatten_with_keys_fn = lambda f : ((("x" , f .x ), ("y" , f .y )), f .z ),
1218
- )
1219
-
1220
- SOME_PYTREES = [
1221
- (None ,),
1222
- ["hello" , [1 , 2 ], {"foo" : [(3 )]}],
1223
- [ANamedTuple (x = torch .rand (2 , 3 ), y = 1 , z = "foo" )],
1224
- [ACustomPytree (x = 12 , y = {"cin" : [1 , 4 , 10 ], "bar" : 18 }, z = "leaf" ), 5 ],
1225
- ]
1226
- for pytree in SOME_PYTREES :
1227
- key_leaves , spec = py_pytree .tree_flatten_with_path (pytree )
1228
- actual = py_pytree .tree_unflatten ([leaf for _ , leaf in key_leaves ], spec )
1229
- self .assertEqual (actual , pytree )
1230
-
1231
- def test_tree_leaves_with_path (self ):
1232
- class ANamedTuple (NamedTuple ):
1233
- x : torch .Tensor
1234
- y : int
1235
- z : str
1236
-
1237
- @dataclass
1238
- class ACustomPytree :
1239
- x : Any
1240
- y : Any
1241
- z : Any
1242
-
1243
- py_pytree .register_pytree_node (
1244
- ACustomPytree ,
1245
- flatten_fn = lambda f : ([f .x , f .y ], f .z ),
1246
- unflatten_fn = lambda xy , z : ACustomPytree (xy [0 ], xy [1 ], z ),
1247
- flatten_with_keys_fn = lambda f : ((("x" , f .x ), ("y" , f .y )), f .z ),
1248
- )
1249
-
1250
- SOME_PYTREES = [
1251
- (None ,),
1252
- ["hello" , [1 , 2 ], {"foo" : [(3 )]}],
1253
- [ANamedTuple (x = torch .rand (2 , 3 ), y = 1 , z = "foo" )],
1254
- [ACustomPytree (x = 12 , y = {"cin" : [1 , 4 , 10 ], "bar" : 18 }, z = "leaf" ), 5 ],
1255
- ]
1256
- for pytree in SOME_PYTREES :
1257
- flat_out , _ = py_pytree .tree_flatten_with_path (pytree )
1258
- leaves_out = py_pytree .tree_leaves_with_path (pytree )
1259
- self .assertEqual (flat_out , leaves_out )
1260
-
1261
- def test_key_str (self ):
1262
- class ANamedTuple (NamedTuple ):
1263
- x : str
1264
- y : int
1265
-
1266
- tree = (["hello" , [1 , 2 ], {"foo" : [(3 )], "bar" : [ANamedTuple (x = "baz" , y = 10 )]}],)
1267
- flat , _ = py_pytree .tree_flatten_with_path (tree )
1268
- paths = [f"{ py_pytree .keystr (kp )} : { val } " for kp , val in flat ]
1269
- self .assertEqual (
1270
- paths ,
1271
- [
1272
- "[0][0]: hello" ,
1273
- "[0][1][0]: 1" ,
1274
- "[0][1][1]: 2" ,
1275
- "[0][2]['foo'][0]: 3" ,
1276
- "[0][2]['bar'][0].x: baz" ,
1277
- "[0][2]['bar'][0].y: 10" ,
1278
- ],
1279
- )
1280
-
1281
- @skipIfTorchDynamo ("AssertionError in dynamo" )
1282
- def test_flatten_flatten_with_key_consistency (self ):
1283
- """Check that flatten and flatten_with_key produces consistent leaves/context."""
1284
- reg = py_pytree .SUPPORTED_NODES
1285
-
1286
- EXAMPLE_TREE = {
1287
- list : [1 , 2 , 3 ],
1288
- tuple : (1 , 2 , 3 ),
1289
- dict : {"foo" : 1 , "bar" : 2 },
1290
- namedtuple : collections .namedtuple ("ANamedTuple" , ["x" , "y" ])(1 , 2 ),
1291
- OrderedDict : OrderedDict ([("foo" , 1 ), ("bar" , 2 )]),
1292
- defaultdict : defaultdict (int , {"foo" : 1 , "bar" : 2 }),
1293
- deque : deque ([1 , 2 , 3 ]),
1294
- torch .Size : torch .Size ([1 , 2 , 3 ]),
1295
- immutable_dict : immutable_dict ({"foo" : 1 , "bar" : 2 }),
1296
- immutable_list : immutable_list ([1 , 2 , 3 ]),
1297
- }
1298
-
1299
- for typ in reg :
1300
- example = EXAMPLE_TREE .get (typ )
1301
- if example is None :
1302
- continue
1303
- flat_with_path , spec1 = py_pytree .tree_flatten_with_path (example )
1304
- flat , spec2 = py_pytree .tree_flatten (example )
1305
-
1306
- self .assertEqual (flat , [x [1 ] for x in flat_with_path ])
1307
- self .assertEqual (spec1 , spec2 )
1308
-
1309
- def test_key_access (self ):
1310
- class ANamedTuple (NamedTuple ):
1311
- x : str
1312
- y : int
1313
-
1314
- tree = (["hello" , [1 , 2 ], {"foo" : [(3 )], "bar" : [ANamedTuple (x = "baz" , y = 10 )]}],)
1315
- flat , _ = py_pytree .tree_flatten_with_path (tree )
1316
- for kp , val in flat :
1317
- self .assertEqual (py_pytree .key_get (tree , kp ), val )
1318
-
1319
1393
1320
1394
class TestCxxPytree (TestCase ):
1321
1395
def setUp (self ):
0 commit comments