Update on "[pytree] add key path api" · pytorch/pytorch@1068392 · GitHub

Commit 1068392

Update on "[pytree] add key path api"
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).

I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type. `flatten_with_keys` is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.

Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this: https://fburl.com/code/06p7zrvr is a hand-rolled pass doing basically the same thing, but probably in a more fragile way.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.

I'm sure there are places it would be useful.

Some design notes:
- I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- The way my impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX: I raise an error, whereas JAX creates a fallback key entry.

Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)

[ghstack-poisoned]
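To make the design notes concrete, here is a minimal pure-Python sketch of a `flatten_with_keys` registry. It is not the code from this PR: `SequenceKey`, `MappingKey`, `REGISTRY`, `leaves_with_path`, `path_str`, and `Opaque` are hypothetical names chosen for illustration. It assumes that a type absent from the registry is a leaf, and that a type registered without a `flatten_with_keys` fn raises an error, which is the behavior the last design note describes.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple


# Hypothetical key entry types; stand-ins, not PyTorch's classes.
@dataclass(frozen=True)
class SequenceKey:
    idx: int

    def __str__(self) -> str:
        return f"[{self.idx!r}]"


@dataclass(frozen=True)
class MappingKey:
    key: Any

    def __str__(self) -> str:
        return f"[{self.key!r}]"


KeyPath = Tuple[Any, ...]
FlattenWithKeysFn = Callable[[Any], List[Tuple[Any, Any]]]

# Maps a node type to its flatten_with_keys fn.
# None means "registered as a node, but without a flatten_with_keys fn".
REGISTRY: Dict[type, Optional[FlattenWithKeysFn]] = {
    list: lambda xs: [(SequenceKey(i), x) for i, x in enumerate(xs)],
    tuple: lambda xs: [(SequenceKey(i), x) for i, x in enumerate(xs)],
    dict: lambda d: [(MappingKey(k), v) for k, v in d.items()],
}


def leaves_with_path(tree: Any, prefix: KeyPath = ()) -> List[Tuple[KeyPath, Any]]:
    """Return (key path, leaf) pairs in depth-first order."""
    node_type = type(tree)
    if node_type not in REGISTRY:
        # Unregistered types are treated as leaves.
        return [(prefix, tree)]
    flatten_with_keys = REGISTRY[node_type]
    if flatten_with_keys is None:
        # Assumed behavior: raise instead of inventing a fallback key entry.
        raise ValueError(f"{node_type.__name__} has no flatten_with_keys fn")
    pairs: List[Tuple[KeyPath, Any]] = []
    for key, child in flatten_with_keys(tree):
        pairs.extend(leaves_with_path(child, prefix + (key,)))
    return pairs


def path_str(path: KeyPath) -> str:
    """Render a key path as an accessor-like string."""
    return "".join(str(entry) for entry in path)


class Opaque:
    """A node type registered without a flatten_with_keys fn."""


REGISTRY[Opaque] = None

for path, leaf in leaves_with_path({"a": [1, 2], "b": (3,)}):
    print(path_str(path), "->", leaf)
```

Each key path is a tuple of key entries, so a caller doing input verification can report exactly which argument failed a check rather than only its flat index.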
2 parents 66f15fe + 97b2c9d commit 1068392

1 file changed: +30 -8 lines changed

torch/utils/_cxx_pytree.py

Lines changed: 30 additions & 8 deletions
@@ -840,36 +840,53 @@ def __new__(cls) -> "LeafSpec":
         return optree.treespec_leaf(none_is_leaf=True)  # type: ignore[return-value]
 
 
-def tree_flatten_with_path(tree: PyTree) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]:
-    """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.
+def tree_leaves_with_path(
+    tree: PyTree,
+    is_leaf: Optional[Callable[[PyTree], bool]] = None,
+) -> List[Tuple[KeyPath, Any]]:
+    """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
 
     Args:
-        tree: a pytree to flatten. If it contains a custom type, that type must be
+        tree: a pytree. If it contains a custom type, that type must be
             registered with an appropriate `tree_flatten_with_path_fn` when registered
             with :func:`register_pytree_node`.
+        is_leaf: An extra leaf predicate function that will be called at each
+            flattening step. The function should have a single argument with signature
+            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
+            as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
+            leaf or not. If the function is not specified, the default pytree registry will be used.
     Returns:
-        A tuple where the first element is a list of (key path, leaf) pairs, and the
-        second element is a :class:`TreeSpec` representing the structure of the flattened
-        tree.
+        A list of (key path, leaf) pairs.
     """
     raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
 
 
-def tree_leaves_with_path(tree: PyTree) -> List[Tuple[KeyPath, Any]]:
+def tree_leaves_with_path(
+    tree: PyTree,
+    is_leaf: Optional[Callable[[PyTree], bool]] = None,
+) -> List[Tuple[KeyPath, Any]]:
     """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
 
     Args:
         tree: a pytree. If it contains a custom type, that type must be
             registered with an appropriate `tree_flatten_with_path_fn` when registered
             with :func:`register_pytree_node`.
+        is_leaf: An extra leaf predicate function that will be called at each
+            flattening step. The function should have a single argument with signature
+            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
+            as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
+            leaf or not. If the function is not specified, the default pytree registry will be used.
     Returns:
         A list of (key path, leaf) pairs.
     """
     raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
 
 
 def tree_map_with_path(
-    func: Callable[..., Any], tree: PyTree, *rests: PyTree
+    func: Callable[..., Any],
+    tree: PyTree,
+    *rests: PyTree,
+    is_leaf: Optional[Callable[[PyTree], bool]] = None,
 ) -> PyTree:
     """Like :func:`tree_map`, but the provided callable takes an additional key path argument.
 
@@ -882,6 +899,11 @@ def tree_map_with_path(
             argument to function ``func``.
         rests: A tuple of pytrees, each of which has the same structure as
             ``tree`` or has ``tree`` as a prefix.
+        is_leaf: An extra leaf predicate function that will be called at each
+            flattening step. The function should have a single argument with signature
+            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
+            as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
+            leaf or not. If the function is not specified, the default pytree registry will be used.
 
     Returns
         A new pytree with the same structure as ``tree`` but with the value at each leaf given by
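
The functions in this diff only raise `NotImplementedError`, so they cannot show what `is_leaf` does. The following is a minimal pure-Python sketch of the semantics the docstrings describe, not the cxx_pytree or optree implementation. It makes several simplifying assumptions: only exact `dict`, `list`, and `tuple` types are containers, key paths are tuples of raw keys and indices rather than `KeyEntry` objects, and `rests` is omitted. The names `_children`, `leaves_with_path`, and `map_with_path` are hypothetical.

```python
from typing import Any, Callable, List, Optional, Tuple

KeyPath = Tuple[Any, ...]
LeafPredicate = Optional[Callable[[Any], bool]]


def _children(node: Any) -> Optional[List[Tuple[Any, Any]]]:
    """Return (key, child) pairs for built-in containers, or None for a leaf."""
    if type(node) is dict:
        return list(node.items())
    if type(node) in (list, tuple):
        return list(enumerate(node))
    return None


def leaves_with_path(
    tree: Any, is_leaf: LeafPredicate = None, _prefix: KeyPath = ()
) -> List[Tuple[KeyPath, Any]]:
    """Return (key path, leaf) pairs; is_leaf can stop descent early."""
    # The predicate is consulted first, at every flattening step.
    if is_leaf is not None and is_leaf(tree):
        return [(_prefix, tree)]
    children = _children(tree)
    if children is None:
        return [(_prefix, tree)]
    pairs: List[Tuple[KeyPath, Any]] = []
    for key, child in children:
        pairs.extend(leaves_with_path(child, is_leaf, _prefix + (key,)))
    return pairs


def map_with_path(
    func: Callable[[KeyPath, Any], Any],
    tree: Any,
    is_leaf: LeafPredicate = None,
    _prefix: KeyPath = (),
) -> Any:
    """Apply func(key_path, leaf) to each leaf and rebuild the same structure."""
    if (is_leaf is not None and is_leaf(tree)) or _children(tree) is None:
        return func(_prefix, tree)
    if type(tree) is dict:
        return {
            k: map_with_path(func, v, is_leaf, _prefix + (k,))
            for k, v in tree.items()
        }
    mapped = [
        map_with_path(func, v, is_leaf, _prefix + (i,)) for i, v in enumerate(tree)
    ]
    return type(tree)(mapped)


def is_tuple(node: Any) -> bool:
    """Example predicate: treat every tuple as a single leaf."""
    return type(node) is tuple


example = {"pt": (1, 2), "xs": [3, (4, 5)]}
print(leaves_with_path(example))
print(leaves_with_path(example, is_leaf=is_tuple))
```

With the predicate, each tuple is returned whole under a shorter key path instead of being flattened into its elements.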
