8000 [pytree] add key path api · pytorch/pytorch@b69cf5a · GitHub
[go: up one dir, main page]

Skip to content

Commit b69cf5a

Browse files
committed
[pytree] add key path api
Pull Request resolved: #116786 This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. ghstack-source-id: 211696842 @exported-using-ghexport Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)
1 parent 83e8a07 commit b69cf5a

File tree

3 files changed

+365
-0
lines changed

3 files changed

+365
-0
lines changed

test/test_pytree.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import re
55
import unittest
66
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
7+
from dataclasses import dataclass
8+
from typing import Any, NamedTuple
79

810
import torch
911
import torch.utils._cxx_pytree as cxx_pytree
@@ -1006,6 +1008,92 @@ def test_saved_serialized(self):
10061008
self.assertEqual(serialized_spec, saved_spec)
10071009
self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec))
10081010

1011+
def test_tree_map_with_path(self):
1012+
tree = [{i: i for i in range(10)}]
1013+
all_zeros = py_pytree.tree_map_with_path(
1014+
lambda kp, val: val - kp[1].key + kp[0].idx, tree
1015+
)
1016+
self.assertEqual(all_zeros, [{i: 0 for i in range(10)}])
1017+
1018+
def test_tree_map_with_path_multiple_trees(self):
1019+
@dataclass
1020+
class ACustomPytree:
1021+
x: Any
1022+
y: Any
1023+
z: Any
1024+
1025+
tree1 = [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5]
1026+
tree2 = [
1027+
ACustomPytree(
1028+
x=2,
1029+
y={"cin": [2, 2, 2], "bar": 2},
1030+
z="leaf",
1031+
),
1032+
2,
1033+
]
1034+
1035+
py_pytree.register_pytree_node(
1036+
ACustomPytree,
1037+
flatten_fn=lambda f: ([f.x, f.y], f.z),
1038+
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
1039+
flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
1040+
)
1041+
from_two_trees = py_pytree.tree_map_with_path(
1042+
lambda kp, a, b: a + b, tree1, tree2
1043+
)
1044+
from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1)
1045+
self.assertEqual(from_two_trees, from_one_tree)
1046+
1047+
def test_tree_flatten_with_path_roundtrip(self):
1048+
class ANamedTuple(NamedTuple):
1049+
x: torch.Tensor
1050+
y: int
1051+
z: str
1052+
1053+
@dataclass
1054+
class ACustomPytree:
1055+
x: Any
1056+
y: Any
1057+
z: Any
1058+
1059+
py_pytree.register_pytree_node(
1060+
ACustomPytree,
1061+
flatten_fn=lambda f: ([f.x, f.y], f.z),
1062+
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
1063+
flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
1064+
)
1065+
1066+
SOME_PYTREES = [
1067+
(None,),
1068+
["hello", [1, 2], {"foo": [(3)]}],
1069+
[ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
1070+
[ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
1071+
]
1072+
for pytree in SOME_PYTREES:
1073+
key_leaves, spec = py_pytree.tree_flatten_with_path(pytree)
1074+
actual = py_pytree.tree_unflatten([leaf for _, leaf in key_leaves], spec)
1075+
self.assertEqual(actual, pytree)
1076+
1077+
def test_key_str(self):
1078+
class ANamedTuple(NamedTuple):
1079+
x: str
1080+
y: int
1081+
1082+
tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
1083+
flat, _ = py_pytree.tree_flatten_with_path(tree)
1084+
paths = [f"{py_pytree.keystr(kp)}: {val}" for kp, val in flat]
1085+
self.assertEqual(
1086+
paths,
1087+
[
1088+
"[0][0]: hello",
1089+
"[0][1][0]: 1",
1090+
"[0][1][1]: 2",
1091+
"[0][2]['foo'][0]: 3",
1092+
"[0][2]['bar'][0].x: baz",
1093+
"[0][2]['bar'][0].y: 10",
1094+
],
1095+
)
1096+
10091097

10101098
class TestCxxPytree(TestCase):
10111099
def test_treespec_equality(self):

torch/utils/_cxx_pytree.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
import optree
3636
from optree import PyTreeSpec # direct import for type annotations
3737

38+
from torch.utils._pytree import PHashable
39+
3840

3941
__all__ = [
4042
"PyTree",
@@ -46,12 +48,16 @@
4648
"FromDumpableContextFn",
4749
"TreeSpec",
4850
"LeafSpec",
51+
"keystr",
4952
"register_pytree_node",
5053
"tree_flatten",
54+
"tree_flatten_with_path",
5155
"tree_unflatten",
5256
"tree_leaves",
57+
"tree_leaves_with_path",
5358
"tree_structure",
5459
"tree_map",
60+
"tree_map_with_path",
5561
"tree_map_",
5662
"tree_map_only",
5763
"tree_map_only_",
@@ -80,6 +86,9 @@
8086
DumpableContext = Any # Any json dumpable text
8187
ToDumpableContextFn = Callable[[Context], DumpableContext]
8288
FromDumpableContextFn = Callable[[DumpableContext], Context]
89+
KeyEntry = PHashable
90+
KeyPath = Tuple[KeyEntry, ...]
91+
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]
8392

8493

8594
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
@@ -98,6 +107,7 @@ def register_pytree_node(
98107
serialized_type_name: Optional[str] = None,
99108
to_dumpable_context: Optional[ToDumpableContextFn] = None,
100109
from_dumpable_context: Optional[FromDumpableContextFn] = None,
110+
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
101111
) -> None:
102112
"""Register a container-like type as pytree node.
103113
@@ -130,6 +140,9 @@ def register_pytree_node(
130140
... lambda children, _: set(children),
131141
... )
132142
"""
143+
if flatten_with_keys_fn is not None:
144+
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
145+
133146
_private_register_pytree_node(
134147
cls,
135148
flatten_fn,
@@ -738,3 +751,60 @@ def __instancecheck__(self, instance: object) -> bool:
738751
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
739752
def __new__(cls) -> "LeafSpec":
740753
return optree.treespec_leaf(none_is_leaf=True) # type: ignore[return-value]
754+
755+
756+
def tree_flatten_with_path(tree: PyTree) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]:
757+
"""Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.
758+
759+
Args:
760+
tree: a pytree to flatten. If it contains a custom type, that type must be
761+
registered with an appropriate `tree_flatten_with_path_fn` when registered
762+
with :func:`register_pytree_node`.
763+
Returns:
764+
A tuple where the first element is a list of (key path, leaf) pairs, and the
765+
second element is a :class:`TreeSpec` representing the structure of the flattened
766+
tree.
767+
"""
768+
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
769+
770+
771+
def tree_leaves_with_path(tree: PyTree) -> List[Tuple[KeyPath, Any]]:
772+
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
773+
774+
Args:
775+
tree: a pytree. If it contains a custom type, that type must be
776+
registered with an appropriate `tree_flatten_with_path_fn` when registered
777+
with :func:`register_pytree_node`.
778+
Returns:
779+
A list of (key path, leaf) pairs.
780+
"""
781+
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
782+
783+
784+
def tree_map_with_path(
785+
func: Callable[..., Any], tree: PyTree, *rests: PyTree
786+
) -> PyTree:
787+
"""Like :func:`tree_map`, but the provided callable takes an additional key path argument.
788+
789+
Args:
790+
func: A function that takes ``2 + len(rests)`` arguments, to be applied at the
791+
corresponding leaves of the pytrees. The first positional argument
792+
to ``func`` is the key path of the leaf in question. The second
793+
positional argument is the value of the leaf.
794+
tree: A pytree to be mapped over, with each leaf providing the first positional
795+
argument to function ``func``.
796+
rests: A tuple of pytrees, each of which has the same structure as
797+
``tree`` or has ``tree`` as a prefix.
798+
799+
Returns
800+
A new pytree with the same structure as ``tree`` but with the value at each leaf given by
801+
``func(keypath, x, *xs)`` where ``keypath`` is the key path at the
802+
corresponding leaf in ``tree``, ``x`` is the value at that leaf, and
803+
``xs`` is the tuple of values at corresponding nodes in ``rests``.
804+
"""
805+
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
806+
807+
808+
def keystr(kp: KeyPath) -> str:
809+
"""Give a key path, return a pretty-printed representation."""
810+
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")

0 commit comments

Comments
 (0)
0