8000 Update base for Update on "[pytree] add key path api" · pytorch/pytorch@82f083a · GitHub
[go: up one dir, main page]

Skip to content

Commit 82f083a

Browse files
committed
Update base for Update on "[pytree] add key path api"
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/) [ghstack-poisoned]
1 parent 940ed46 commit 82f083a

File tree

0 file changed

+0
-0
lines changed

    0 file changed

    +0
    -0
    lines changed

    0 commit comments

    Comments
     (0)
    0