8000 [pytree] implement key path APIs for CXX pytree · XuehaiPan/pytorch@b392c9c · GitHub
[go: up one dir, main page]

Skip to content

Commit b392c9c

Browse files
committed
[pytree] implement key path APIs for CXX pytree
ghstack-source-id: 3b4f24f Pull Request resolved: pytorch#130141
1 parent e75514c commit b392c9c

File tree

6 files changed

+367
-201
lines changed

6 files changed

+367
-201
lines changed

test/test_pytree.py

Lines changed: 230 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,236 @@ def test_broadcast_to_and_flatten(self, pytree_impl):
720720
result = pytree_impl._broadcast_to_and_flatten(pytree, to_spec)
721721
self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
722722

723+
@parametrize(
724+
"pytree_impl",
725+
[
726+
subtest(py_pytree, name="py"),
727+
subtest(cxx_pytree, name="cxx"),
728+
],
729+
)
730+
def test_tree_map_with_path(self, pytree_impl):
731+
tree = [{i: i for i in range(10)}]
732+
all_zeros = pytree_impl.tree_map_with_path(
733+
lambda kp, val: val - kp[1].key + kp[0].idx, tree
734+
)
735+
self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)])
736+
737+
@parametrize(
738+
"pytree_impl",
739+
[
740+
subtest(py_pytree, name="py"),
741+
subtest(cxx_pytree, name="cxx"),
742+
],
743+
)
744+
def test_tree_map_with_path_multiple_trees(self, pytree_impl):
745+
@dataclass
746+
class ACustomPytree:
747+
x: Any
748+
y: Any
749+
z: Any
750+
751+
tree1 = [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5]
752+
tree2 = [ACustomPytree(x=2, y={"cin": [2, 2, 2], "bar": 2}, z="leaf"), 2]
753+
754+
pytree_impl.register_pytree_node(
755+
ACustomPytree,
756+
flatten_fn=lambda f: ([f.x, f.y], f.z),
757+
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
758+
flatten_with_keys_fn=lambda f: (
759+
(
760+
(pytree_impl.GetAttrKey("x"), f.x),
761+
(pytree_impl.GetAttrKey("y"), f.y),
762+
),
763+
f.z,
764+
),
765+
)
766+
from_two_trees = pytree_impl.tree_map_with_path(
767+
lambda kp, a, b: a + b, tree1, tree2
768+
)
769+
from_one_tree = pytree_impl.tree_map(lambda a: a + 2, tree1)
770+
self.assertEqual(from_two_trees, from_one_tree)
771+
772+
@skipIfTorchDynamo("dynamo pytree tracing doesn't work here")
773+
@parametrize(
774+
"pytree_impl",
775+
[
776+
subtest(py_pytree, name="py"),
777+
subtest(cxx_pytree, name="cxx"),
778+
],
779+
)
780+
def test_tree_flatten_with_path_is_leaf(self, pytree_impl):
781+
leaf_dict = {"foo": [(3)]}
782+
pytree = (["hello", [1, 2], leaf_dict],)
783+
key_leaves, spec = pytree_impl.tree_flatten_with_path(
784+
pytree, is_leaf=lambda x: isinstance(x, dict)
785+
)
786+
self.assertTrue(key_leaves[-1][1] is leaf_dict)
787+
788+
@parametrize(
789+
"pytree_impl",
790+
[
791+
subtest(py_pytree, name="py"),
792+
subtest(cxx_pytree, name="cxx"),
793+
],
794+
)
795+
def test_tree_flatten_with_path_roundtrip(self, pytree_impl):
796+
class ANamedTuple(NamedTuple):
797+
x: torch.Tensor
798+
y: int
799+
z: str
800+
801+
@dataclass
802+
class ACustomPytree:
803+
x: Any
804+
y: Any
805+
z: Any
806+
807+
pytree_impl.register_pytree_node(
808+
ACustomPytree,
809+
flatten_fn=lambda f: ([f.x, f.y], f.z),
810+
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
811+
flatten_with_keys_fn=lambda f: (
812+
(
813+
(pytree_impl.GetAttrKey("x"), f.x),
814+
(pytree_impl.GetAttrKey("y"), f.y),
815+
),
816+
f.z,
817+
),
818+
)
819+
820+
SOME_PYTREES = [
821+
(None,),
822+
["hello", [1, 2], {"foo": [(3)]}],
823+
[ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
824+
[ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
825+
]
826+
for pytree in SOME_PYTREES:
827+
key_leaves, spec = pytree_impl.tree_flatten_with_path(pytree)
828+
actual = pytree_impl.tree_unflatten([leaf for _, leaf in key_leaves], spec)
829+
self.assertEqual(actual, pytree)
830+
831+
@parametrize(
832+
"pytree_impl",
833+
[
834+
subtest(py_pytree, name="py"),
835+
subtest(cxx_pytree, name="cxx"),
836+
],
837+
)
838+
def test_tree_leaves_with_path(self, pytree_impl):
839+
class ANamedTuple(NamedTuple):
840+
x: torch.Tensor
841+
y: int
842+
z: str
843+
844+
@dataclass
845+
class ACustomPytree:
846+
x: Any
847+
y: Any
848+
z: Any
849+
850+
pytree_impl.register_pytree_node(
851+
ACustomPytree,
852+
flatten_fn=lambda f: ([f.x, f.y], f.z),
853+
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
854+
flatten_with_keys_fn=lambda f: (
855+
(
856+
(pytree_impl.GetAttrKey("x"), f.x),
857+
(pytree_impl.GetAttrKey("y"), f.y),
858+
),
859+
f.z,
860+
),
861+
)
862+
863+
SOME_PYTREES = [
864+
(None,),
865+
["hello", [1, 2], {"foo": [(3)]}],
866+
[ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
867+
[ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
868+
]
869+
for pytree in F42D SOME_PYTREES:
870+
flat_out, _ = pytree_impl.tree_flatten_with_path(pytree)
871+
leaves_out = pytree_impl.tree_leaves_with_path(pytree)
872+
self.assertEqual(flat_out, leaves_out)
873+
874+
@parametrize(
875+
"pytree_impl",
876+
[
877+
subtest(py_pytree, name="py"),
878+
subtest(cxx_pytree, name="cxx"),
879+
],
880+
)
881+
def test_key_str(self, pytree_impl):
882+
class ANamedTuple(NamedTuple):
883+
x: str
884+
y: int
885+
886+
tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
887+
flat, _ = pytree_impl.tree_flatten_with_path(tree)
888+
paths = [f"{pytree_impl.keystr(kp)}: {val}" for kp, val in flat]
889+
self.assertEqual(
890+
paths,
891+
[
892+
"[0][0]: hello",
893+
"[0][1][0]: 1",
894+
"[0][1][1]: 2",
895+
"[0][2]['foo'][0]: 3",
896+
"[0][2]['bar'][0].x: baz",
897+
"[0][2]['bar'][0].y: 10",
898+
],
899+
)
900+
901+
@skipIfTorchDynamo("AssertionError in dynamo")
902+
@parametrize(
903+
"pytree_impl",
904+
[
905+
subtest(py_pytree, name="py"),
906+
subtest(cxx_pytree, name="cxx"),
907+
],
908+
)
909+
def test_flatten_flatten_with_key_consistency(self, pytree_impl):
910+
"""Check that flatten and flatten_with_key produces consistent leaves/context."""
911+
reg = py_pytree.SUPPORTED_NODES
912+
913+
EXAMPLE_TREE = {
914+
list: [1, 2, 3],
915+
tuple: (1, 2, 3),
916+
dict: {"foo": 1, "bar": 2},
917+
namedtuple: collections.namedtuple("ANamedTuple", ["x", "y"])(1, 2),
918+
OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
919+
defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
920+
deque: deque([1, 2, 3]),
921+
torch.Size: torch.Size([1, 2, 3]),
922+
immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
923+
immutable_list: immutable_list([1, 2, 3]),
924+
}
925+
926+
for typ in reg:
927+
example = EXAMPLE_TREE.get(typ)
928+
if example is None:
929+
continue
930+
flat_with_path, spec1 = pytree_impl.tree_flatten_with_path(example)
931+
flat, spec2 = pytree_impl.tree_flatten(example)
932+
933+
self.assertEqual(flat, [x[1] for x in flat_with_path])
934+
self.assertEqual(spec1, spec2)
935+
936+
@parametrize(
937+
"pytree_impl",
938+
[
939+
subtest(py_pytree, name="py"),
940+
subtest(cxx_pytree, name="cxx"),
941+
],
942+
)
943+
def test_key_access(self, pytree_impl):
944+
class ANamedTuple(NamedTuple):
945+
x: str
946+
y: int
947+
948+
tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
949+
flat, _ = pytree_impl.tree_flatten_with_path(tree)
950+
for kp, val in flat:
951+
self.assertEqual(pytree_impl.key_get(tree, kp), val)
952+
723953
@parametrize(
724954
"pytree_impl",
725955
[
@@ -1160,162 +1390,6 @@ def test_saved_serialized(self):
11601390
self.assertEqual(serialized_spec, saved_spec)
11611391
self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec))
11621392

1163-
def test_tree_map_with_path(self):
1164-
tree = [{i: i for i in range(10)}]
1165-
all_zeros = py_pytree.tree_map_with_path(
1166-
lambda kp, val: val - kp[1].key + kp[0].idx, tree
1167-
)
1168-
self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)])
1169-
1170-
def test_tree_map_with_path_multiple_trees(self):
1171-
@dataclass
1172-
class ACustomPytree:
1173-
x: Any
1174-
y: Any
1175-
z: Any
1176-
1177-
tree1 = [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5]
1178-
tree2 = [ACustomPytree(x=2, y={"cin": [2, 2, 2], "bar": 2}, z="leaf"), 2]
1179-
1180-
py_pytree.register_pytree_node(
1181-
ACustomPytree,
1182-
flatten_fn=lambda f: ([f.x, f.y], f.z),
1183-
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
1184-
flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
1185-
)
1186-
from_two_trees = py_pytree.tree_map_with_path(
1187-
lambda kp, a, b: a + b, tree1, tree2
1188-
)
1189-
from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1)
1190-
self.assertEqual(from_two_trees, from_one_tree)
1191-
1192-
@skipIfTorchDynamo("dynamo pytree tracing doesn't work here")
1193-
def test_tree_flatten_with_path_is_leaf(self):
1194-
leaf_dict = {"foo": [(3)]}
1195-
pytree = (["hello", [1, 2], leaf_dict],)
1196-
key_leaves, _ = py_pytree.tree_flatten_with_path(
1197-
pytree, is_leaf=lambda x: isinstance(x, dict)
1198-
)
1199-
self.assertTrue(key_leaves[-1][1] is leaf_dict)
1200-
1201-
def test_tree_flatten_with_path_roundtrip(self):
1202-
class ANamedTuple(NamedTuple):
1203-
x: torch.Tensor
1204-
y: int
1205-
z: str
1206-
1207-
@dataclass
1208-
class ACustomPytree:
1209-
x: Any
1210-
y: Any
1211-
z: Any
1212-
1213-
py_pytree.register_pytree_node(
1214-
ACustomPytree,
1215-
flatten_fn=lambda f: ([f.x, f.y], f.z),
1216-
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
1217-
flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
1218-
)
1219-
1220-
SOME_PYTREES = [
1221-
(None,),
1222-
["hello", [1, 2], {"foo": [(3)]}],
1223-
[ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
1224-
[ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
1225-
]
1226-
for pytree in SOME_PYTREES:
1227-
key_leaves, spec = py_pytree.tree_flatten_with_path(pytree)
1228-
actual = py_pytree.tree_unflatten([leaf for _, leaf in key_leaves], spec)
1229-
self.assertEqual(actual, pytree)
1230-
1231-
def test_tree_leaves_with_path(self):
1232-
class ANamedTuple(NamedTuple):
1233-
x: torch.Tensor
1234-
y: int
1235-
z: str
1236-
1237-
@dataclass
1238-
class ACustomPytree:
1239-
x: Any
1240-
y: Any
1241-
z: Any
1242-
1243-
py_pytree.register_pytree_node(
1244-
ACustomPytree,
1245-
flatten_fn=lambda f: ([f.x, f.y], f.z),
1246-
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
1247-
flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
1248-
)
1249-
1250-
SOME_PYTREES = [
1251-
(None,),
1252-
["hello", [1, 2], {"foo": [(3)]}],
1253-
[ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
1254-
[ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
1255-
]
1256-
for pytree in SOME_PYTREES:
1257-
flat_out, _ = py_pytree.tree_flatten_with_path(pytree)
1258-
leaves_out = py_pytree.tree_leaves_with_path(pytree)
1259-
self.assertEqual(flat_out, leaves_out)
1260-
1261-
def test_key_str(self):
1262-
class ANamedTuple(NamedTuple):
1263-
x: str
1264-
y: int
1265-
1266-
tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
1267-
flat, _ = py_pytree.tree_flatten_with_path(tree)
1268-
paths = [f"{py_pytree.keystr(kp)}: {val}" for kp, val in flat]
1269-
self.assertEqual(
1270-
paths,
1271-
[
1272-
"[0][0]: hello",
1273-
"[0][1][0]: 1",
1274-
"[0][1][1]: 2",
1275-
"[0][2]['foo'][0]: 3",
1276-
"[0][2]['bar'][0].x: baz",
1277-
"[0][2]['bar'][0].y: 10",
1278-
],
1279-
)
1280-
1281-
@skipIfTorchDynamo("AssertionError in dynamo")
1282-
def test_flatten_flatten_with_key_consistency(self):
1283-
"""Check that flatten and flatten_with_key produces consistent leaves/context."""
1284-
reg = py_pytree.SUPPORTED_NODES
1285-
1286-
EXAMPLE_TREE = {
1287-
list: [1, 2, 3],
1288-
tuple: (1, 2, 3),
1289-
dict: {"foo": 1, "bar": 2},
1290-
namedtuple: collections.namedtuple("ANamedTuple", ["x", "y"])(1, 2),
1291-
OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
1292-
defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
1293-
deque: deque([1, 2, 3]),
1294-
torch.Size: torch.Size([1, 2, 3]),
1295-
immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
1296-
immutable_list: immutable_list([1, 2, 3]),
1297-
}
1298-
1299-
for typ in reg:
1300-
example = EXAMPLE_TREE.get(typ)
1301-
if example is None:
1302-
continue
1303-
flat_with_path, spec1 = py_pytree.tree_flatten_with_path(example)
1304-
flat, spec2 = py_pytree.tree_flatten(example)
1305-
1306-
self.assertEqual(flat, [x[1] for x in flat_with_path])
1307-
self.assertEqual(spec1, spec2)
1308-
1309-
def test_key_access(self):
1310-
class ANamedTuple(NamedTuple):
1311-
x: str
1312-
y: int
1313-
1314-
tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
1315-
flat, _ = py_pytree.tree_flatten_with_path(tree)
1316-
for kp, val in flat:
1317-
self.assertEqual(py_pytree.key_get(tree, kp), val)
1318-
13191393

13201394
class TestCxxPytree(TestCase):
13211395
def setUp(self):

0 commit comments

Comments
 (0)
0