8000 Update on "[pytree] add key path api" · pytorch/pytorch@f642001 · GitHub
[go: up one dir, main page]

Skip to content

Commit f642001

Browse files
committed
Update on "[pytree] add key path api"
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/) [ghstack-poisoned]
2 parents ce206e9 + 784b56e commit f642001

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

test/test_pytree.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,36 @@ class ACustomPytree:
11341134
actual = py_pytree.tree_unflatten([leaf for _, leaf in key_leaves], spec)
11351135
self.assertEqual(actual, pytree)
11361136

1137+
def test_tree_leaves_with_path(self):
1138+
class ANamedTuple(NamedTuple):
1139+
x: torch.Tensor
1140+
y: int
1141+
z: str
1142+
1143+
@dataclass
1144+
class ACustomPytree:
1145+
x: Any
1146+
y: Any
1147+
z: Any
1148+
1149+
py_pytree.register_pytree_node(
1150+
ACustomPytree,
1151+
flatten_fn=lambda f: ([f.x, f.y], f.z),
1152+
unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
1153+
flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
1154+
)
1155+
1156+
SOME_PYTREES = [
1157+
(None,),
1158+
["hello", [1, 2], {"foo": [(3)]}],
1159+
[ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
1160+
[ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
1161+
]
1162+
for pytree in SOME_PYTREES:
1163+
flat_out, _ = py_pytree.tree_flatten_with_path(pytree)
1164+
leaves_out = py_pytree.tree_leaves_with_path(pytree)
1165+
self.assertEqual(flat_out, leaves_out)
1166+
11371167
def test_key_str(self):
11381168
class ANamedTuple(NamedTuple):
11391169
x: str

0 commit comments

Comments
 (0)
0