[pytree] implement key path API · Issue #113378 · pytorch/pytorch · GitHub

[pytree] implement key path API #113378


Open
suo opened this issue Nov 9, 2023 · 7 comments
Labels: `module: pytree`, `triaged`

Comments

@suo (Member) commented Nov 9, 2023

JAX has a key path API that makes it easy to index into a pytree. This would definitely be useful for more advanced users manipulating pytrees; we already have at least one internal use case for it.

cc @zou3519

suo added the `triaged` and `module: pytree` labels Nov 9, 2023
@yanboliang (Contributor) commented:

Signed this up as a HAM project for one of our internal transfer candidates.

@stroxler (Contributor) commented:

@XuehaiPan, I'm curious about your thoughts on this.

It looks like optree already has the main operations we want here:
map_with_path: https://github.com/metaopt/optree/blob/main/optree/ops.py#L566-L573
flatten_with_path: https://github.com/metaopt/optree/blob/main/optree/ops.py#L167-L173

If we're hoping to use optree as the backend for PyTorch's pytree, would it make sense to focus on that first?

@XuehaiPan (Collaborator) commented Nov 22, 2023

> It looks like optree already has the main operations we want here:
> map_with_path: metaopt/optree@main/optree/ops.py#L566-L573
> flatten_with_path: metaopt/optree@main/optree/ops.py#L167-L173
>
> If we're hoping to use optree as the backend for pytorch's pytree, would it make sense to focus on that first?

> JAX has a key path API which makes it easy to index into a pytree. This would definitely be useful for more advanced users manipulation Pytrees; we have at least one internal use case for it already.

The key path APIs in JAX pytree and optree are slightly different.

In [1]: from collections import *

In [2]: MyTuple = namedtuple('MyTuple', ['x', 'y', 'z'])

In [3]: tree = OrderedDict([('a', (1, 2)), ('b', {'c': [3, 4], 'd': 5}), ('e', MyTuple(6, 7, 8))])

In [4]: import jax.tree_util as jaxtree

In [5]: jaxtree.tree_flatten_with_path(tree)
Out[5]: 
(
    [
        ((DictKey(key='a'), SequenceKey(idx=0)), 1),
        ((DictKey(key='a'), SequenceKey(idx=1)), 2),
        ((DictKey(key='b'), DictKey(key='c'), SequenceKey(idx=0)), 3),
        ((DictKey(key='b'), DictKey(key='c'), SequenceKey(idx=1)), 4),
        ((DictKey(key='b'), DictKey(key='d')), 5),
        ((DictKey(key='e'), GetAttrKey(name='x')), 6),
        ((DictKey(key='e'), GetAttrKey(name='y')), 7),
        ((DictKey(key='e'), GetAttrKey(name='z')), 8)
    ],
    PyTreeDef(CustomNode(OrderedDict[('a', 'b', 'e')], [(*, *), {'c': [*, *], 'd': *}, CustomNode(namedtuple[MyTuple], [*, *, *])]))
)

In [6]: import optree

In [7]: optree.tree_flatten_with_path(tree)
Out[7]: 
(
    [('a', 0), ('a', 1), ('b', 'c', 0), ('b', 'c', 1), ('b', 'd'), ('e', 0), ('e', 1), ('e', 2)],
    [1, 2, 3, 4, 5, 6, 7, 8],
    PyTreeSpec(OrderedDict([('a', (*, *)), ('b', {'c': [*, *], 'd': *}), ('e', MyTuple(x=*, y=*, z=*))]))
)

In JAX's pytree, each leaf is identified by a key path: a tuple of key entries with semantic types such as `SequenceKey`, `DictKey`, and `GetAttrKey`. In optree, by contrast, a path is a plain tuple of raw indices and dictionary keys, with no wrapper type to distinguish a sequence index from a dictionary key.
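To make the optree convention concrete, here is a minimal pure-Python sketch (not optree's implementation) that records raw paths for `dict`, `list`, and `tuple` nodes. It assumes dict entries are visited in insertion order (optree itself sorts the keys of a plain `dict`, while `OrderedDict` keeps insertion order) and treats namedtuples as ordinary tuples; everything else, including `None`, is a leaf. `flatten_with_raw_path` is a hypothetical name.

```python
from collections import OrderedDict, namedtuple
from typing import Any, List, Tuple

Path = Tuple[Any, ...]


def flatten_with_raw_path(tree: Any, prefix: Path = ()) -> Tuple[List[Path], List[Any]]:
    """Flatten dicts, lists and tuples, recording optree-style raw paths.

    Simplified sketch: dict entries are visited in insertion order, and
    namedtuples are treated as tuples, so members get integer indices.
    """
    if isinstance(tree, dict):
        entries = list(tree.items())
    elif isinstance(tree, (list, tuple)):
        entries = list(enumerate(tree))
    else:
        # Anything else (including None) is treated as a leaf in this sketch.
        return [prefix], [tree]
    paths: List[Path] = []
    leaves: List[Any] = []
    for entry, child in entries:
        child_paths, child_leaves = flatten_with_raw_path(child, prefix + (entry,))
        paths.extend(child_paths)
        leaves.extend(child_leaves)
    return paths, leaves


MyTuple = namedtuple('MyTuple', ['x', 'y', 'z'])
tree = OrderedDict([('a', (1, 2)), ('b', {'c': [3, 4], 'd': 5}), ('e', MyTuple(6, 7, 8))])
paths, leaves = flatten_with_raw_path(tree)
```

Every path produced this way can be replayed with repeated `[]` lookups on the original tree, since the raw entries carry no information about how they should be applied.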

In [8]: jaxtree.tree_flatten_with_path(tree)[0][0][0]
Out[8]: (DictKey(key='a'), SequenceKey(idx=0))

In [9]: optree.tree_flatten_with_path(tree)[0][0]
Out[9]: ('a', 0)

In [10]: tree['a'][0]
Out[10]: 1

Most container classes implement `__getitem__`, so we can easily access a leaf via its path:

In [11]: class itemgetter:
    ...:     def __init__(self, path):
    ...:         self.path = tuple(path)
    ...:
    ...:     def __call__(self, obj):
    ...:         for p in self.path:
    ...:             obj = obj[p]
    ...:         return obj

In [12]: getter = itemgetter(('a', 0))

In [13]: getter(tree)
Out[13]: 1

In [14]: paths = optree.tree_paths(tree)

In [15]: paths
Out[15]: 
[
    ('a', 0),
    ('a', 1),
    ('b', 'c', 0),
    ('b', 'c', 1),
    ('b', 'd'),
    ('e', 0),
    ('e', 1),
    ('e', 2)
]

In [16]: getters = [itemgetter(path) for path in paths]

In [17]: [getter(tree) for getter in getters]
Out[17]: [1, 2, 3, 4, 5, 6, 7, 8]

One more difference: JAX's pytree addresses namedtuple members by attribute name (`GetAttrKey`), while optree uses their raw integer index.
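The two conventions reach the same member, because a namedtuple supports both `t[i]` and `getattr(t, t._fields[i])`. A small sketch of the index-to-name mapping; `index_to_attr_name` is a hypothetical helper, not part of either library:

```python
from collections import namedtuple
from typing import Any

MyTuple = namedtuple('MyTuple', ['x', 'y', 'z'])


def index_to_attr_name(node: Any, index: int) -> str:
    """Map an optree-style integer entry to a JAX-style attribute name (namedtuples only)."""
    return type(node)._fields[index]


t = MyTuple(6, 7, 8)
names = [index_to_attr_name(t, i) for i in range(len(t))]
# Index access and attribute access reach the same member.
consistent = all(getattr(t, name) == t[i] for i, name in enumerate(names))
```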

In [18]: jaxtree.tree_flatten_with_path(MyTuple(1, 2, 3))
Out[18]: 
(
    [
        ((GetAttrKey(name='x'),), 1),
        ((GetAttrKey(name='y'),), 2),
        ((GetAttrKey(name='z'),), 3)
    ],
    PyTreeDef(CustomNode(namedtuple[MyTuple], [*, *, *]))
)

In [19]: optree.tree_flatten_with_path(MyTuple(1, 2, 3))
Out[19]:
(
    [(0,), (1,), (2,)],
    [1, 2, 3],
    PyTreeSpec(MyTuple(x=*, y=*, z=*))
)

Any thoughts about the API design for the key paths?


@stroxler (Contributor) commented Nov 27, 2023

Thanks for the detailed answer!

Personally I find the JAX approach of wrapping each key in a semantic type a bit more future-proof.

This is probably partly my preference as someone who works mainly on a type checker, but encoding the meaning of the key generally makes the types more expressive and the system easier to extend to new cases. In particular, it lets attribute keys be distinguished from mapping index keys; attribute keys are necessary, for example, to support some kinds of dataclasses.
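To illustrate the ambiguity, the sketch below uses minimal stand-ins named after JAX's `DictKey`, `SequenceKey`, and `GetAttrKey` (defined locally, not imported from `jax.tree_util`). With a typed key the resolver knows to use `getattr` for a dataclass field, whereas the same string used as a raw key fails under plain `[]` lookup. `get_by_key_path` and `Config` are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Tuple, Union


# Local stand-ins modeled on JAX's key entry types.
@dataclass(frozen=True)
class DictKey:
    key: Any


@dataclass(frozen=True)
class SequenceKey:
    idx: int


@dataclass(frozen=True)
class GetAttrKey:
    name: str


KeyEntry = Union[DictKey, SequenceKey, GetAttrKey]


def get_by_key_path(tree: Any, path: Tuple[KeyEntry, ...]) -> Any:
    """Follow a key path, choosing item or attribute access from each key's type."""
    for entry in path:
        if isinstance(entry, GetAttrKey):
            tree = getattr(tree, entry.name)
        elif isinstance(entry, DictKey):
            tree = tree[entry.key]
        elif isinstance(entry, SequenceKey):
            tree = tree[entry.idx]
        else:
            raise TypeError(f"unknown key entry: {entry!r}")
    return tree


@dataclass
class Config:
    lr: float


tree = {"cfg": Config(lr=0.1)}
value = get_by_key_path(tree, (DictKey("cfg"), GetAttrKey("lr")))

# With a raw, untyped path ("cfg", "lr"), naive [] lookup fails on the dataclass.
try:
    tree["cfg"]["lr"]
    raw_lookup_ok = True
except TypeError:
    raw_lookup_ok = False
```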

If you prefer optree's current API, I suspect it wouldn't be all that hard to add a wrapper layer for this.

@XuehaiPan (Collaborator) commented Nov 28, 2023

> Personally I find the Jax approach of wrapping the key in a semantic type a bit more future proof.

@stroxler I agree that JAX's semantically typed keys are more expressive. In optree, the API design philosophy is to prioritize native Python types. This simplifies maintenance, makes extension easier, and makes backward compatibility easier to preserve.


> If you prefer optree's current API, I suspect it wouldn't be all that hard to add a wrapper layer for this

Would adding the node type to the path result resolve your concern?

For example:

- `path`: a tuple of raw indices, keys, or custom entries
- `typed_path`: a tuple of `(node_type, entry)` pairs, where each entry is a raw index, key, or custom entry

I think this would be more general than JAX's semantic keys. For example, it can distinguish a `dict` node from an `OrderedDict` node, whereas JAX always produces the same semantic key, `DictKey(key='xxx')`, for both.
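Under this proposal the plain `path` is recoverable from `typed_path` by dropping the node types. A small sketch using plain tuples; `typed_to_raw_path` and `path_node_types` are hypothetical helpers, not optree API:

```python
from collections import OrderedDict, namedtuple
from typing import Any, Tuple

MyTuple = namedtuple('MyTuple', ['x', 'y', 'z'])
TypedPath = Tuple[Tuple[type, Any], ...]


def typed_to_raw_path(typed_path: TypedPath) -> Tuple[Any, ...]:
    """Drop the node types, keeping only the raw entries (the plain `path`)."""
    return tuple(entry for _node_type, entry in typed_path)


def path_node_types(typed_path: TypedPath) -> Tuple[type, ...]:
    """Keep only the node types along the path."""
    return tuple(node_type for node_type, _entry in typed_path)


typed = ((OrderedDict, 'e'), (MyTuple, 1))
raw = typed_to_raw_path(typed)
types = path_node_types(typed)
# A dict step and an OrderedDict step with the same key stay distinguishable.
distinct = (dict, 'k') != (OrderedDict, 'k')
```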

In [1]: from collections import *

In [2]: MyTuple = namedtuple('MyTuple', ['x', 'y', 'z'])

In [3]: tree = OrderedDict([('a', (1, 2)), ('b', {'c': [3, 4], 'd': 5}), ('e', MyTuple(6, 7, 8))])

In [4]: import optree

In [5]: optree.tree_flatten_with_path(tree)
Out[5]: 
(
    [('a', 0), ('a', 1), ('b', 'c', 0), ('b', 'c', 1), ('b', 'd'), ('e', 0), ('e', 1), ('e', 2)],
    [1, 2, 3, 4, 5, 6, 7, 8],
    PyTreeSpec(OrderedDict([('a', (*, *)), ('b', {'c': [*, *], 'd': *}), ('e', MyTuple(x=*, y=*, z=*))]))
)

In [6]: optree.tree_flatten_with_typed_path(tree)
Out[6]: 
(
    [
        ((<class 'collections.OrderedDict'>, 'a'), (<class 'tuple'>, 0)),
        ((<class 'collections.OrderedDict'>, 'a'), (<class 'tuple'>, 1)),
        ((<class 'collections.OrderedDict'>, 'b'), (<class 'dict'>, 'c'), (<class 'list'>, 0)),
        ((<class 'collections.OrderedDict'>, 'b'), (<class 'dict'>, 'c'), (<class 'list'>, 1)),
        ((<class 'collections.OrderedDict'>, 'b'), (<class 'dict'>, 'd')),
        ((<class 'collections.OrderedDict'>, 'e'), (<class '__main__.MyTuple'>, 0)),
        ((<class 'collections.OrderedDict'>, 'e'), (<class '__main__.MyTuple'>, 1)),
        ((<class 'collections.OrderedDict'>, 'e'), (<class '__main__.MyTuple'>, 2))
    ],
    [1, 2, 3, 4, 5, 6, 7, 8],
    PyTreeSpec(OrderedDict([('a', (*, *)), ('b', {'c': [*, *], 'd': *}), ('e', MyTuple(x=*, y=*, z=*))]))
)

An initial implementation for this using pure Python:

import optree


def typed_paths(treespec):
    def gen_type_paths(spec):
        # A leaf contributes a single empty path.
        if spec.is_leaf():
            yield ()
            return

        node_type = spec.type
        # Prefix each child's typed paths with this node's (type, entry) pair.
        for entry, child_spec in zip(spec.entries(), spec.children()):
            for child_typed_path in gen_type_paths(child_spec):
                yield ((node_type, entry), *child_typed_path)

    return list(gen_type_paths(treespec))


def tree_flatten_with_typed_path(tree):
    leaves, treespec = optree.tree_flatten(tree)
    return typed_paths(treespec), leaves, treespec

We can define a member method for PyTreeSpec in C++ (metaopt/optree#108).

def typed_paths(treespec):
    # Variant of the version above: share one stack of (node_type, entry)
    # pairs instead of rebuilding prefix tuples at every level.
    stack = []

    def gen_type_paths(spec):
        if spec.is_leaf():
            yield tuple(stack)
            return

        node_type = spec.type
        for entry, child_spec in zip(spec.entries(), spec.children()):
            stack.append((node_type, entry))
            yield from gen_type_paths(child_spec)
            stack.pop()

    return list(gen_type_paths(treespec))

In [7]: typed_paths(treespec)
Out[7]:
[
    ((<class 'collections.OrderedDict'>, 'a'), (<class 'tuple'>, 0)),
    ((<class 'collections.OrderedDict'>, 'a'), (<class 'tuple'>, 1)),
    ((<class 'collections.OrderedDict'>, 'b'), (<class 'dict'>, 'c'), (<class 'list'>, 0)),
    ((<class 'collections.OrderedDict'>, 'b'), (<class 'dict'>, 'c'), (<class 'list'>, 1)),
    ((<class 'collections.OrderedDict'>, 'b'), (<class 'dict'>, 'd')),
    ((<class 'collections.OrderedDict'>, 'e'), (<class '__main__.MyTuple'>, 0)),
    ((<class 'collections.OrderedDict'>, 'e'), (<class '__main__.MyTuple'>, 1)),
    ((<class 'collections.OrderedDict'>, 'e'), (<class '__main__.MyTuple'>, 2))
]

An alternative is to add a new function, `getitem_func`, to the type registration entry:

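from typing import Any, Callable, NamedTuple
FlattenFunc = Callable[[Any], Any]  # container -> (children, metadata, entries)
UnflattenFunc = Callable[[Any, Any], Any]  # (metadata, children) -> container
GetitemFunc = Callable[[Any, Any], Any]  # (container, entry) -> child
class PyTreeNodeRegistryEntry(NamedTuple):
    flatten_func: FlattenFunc
    unflatten_func: UnflattenFunc
    getitem_func: GetitemFunc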

where the signatures are:

flatten_func(container) -> (children, metadata, entries)
unflatten_func(metadata, children) -> container
getitem_func(container, entry) -> child

and `getitem_func` defaults to `lambda container, entry: container.__getitem__(entry)`.
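A minimal sketch of how lookups could dispatch on the node type recorded in a typed path, assuming a per-type table of `getitem_func`s that falls back to subscripting. The registry, `register_getitem`, `get_by_typed_path`, and the dataclass `Point` (assumed to be flattened with its field names as entries) are all hypothetical, not optree API:

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple

GetitemFunc = Callable[[Any, Any], Any]

# Hypothetical per-type table of getitem_funcs.
GETITEM_REGISTRY: Dict[type, GetitemFunc] = {}


def default_getitem(container: Any, entry: Any) -> Any:
    # The proposed default: container.__getitem__(entry).
    return container.__getitem__(entry)


def register_getitem(node_type: type, getitem_func: GetitemFunc) -> None:
    GETITEM_REGISTRY[node_type] = getitem_func


def get_by_typed_path(tree: Any, typed_path: Tuple[Tuple[type, Any], ...]) -> Any:
    """Resolve a typed path, dispatching on each recorded node type."""
    for node_type, entry in typed_path:
        getitem = GETITEM_REGISTRY.get(node_type, default_getitem)
        tree = getitem(tree, entry)
    return tree


@dataclass
class Point:
    x: int
    y: int


# Assume Point is flattened with field names as entries, so attribute
# access is the natural getitem_func.
register_getitem(Point, getattr)

tree = {"p": Point(1, 2), "xs": [10, 20]}
y = get_by_typed_path(tree, ((dict, "p"), (Point, "y")))
second = get_by_typed_path(tree, ((dict, "xs"), (list, 1)))
```

The node type in each step is what tells the resolver whether the string `"y"` is an attribute name or a mapping key.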

@stroxler (Contributor) commented:

Yes, I think tagging the lookup keys with the node type would probably suffice; it should be enough to resolve questions such as when to use a string key as an attribute name versus a dict index.

@suo (Member, Author) commented Dec 21, 2023

Any progress on this?

suo added a commit that referenced this issue Jan 4, 2024
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).

I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.
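A standalone sketch of the idea, assuming a registry that maps each node type to a `flatten_with_keys` function returning `(key_entry, child)` pairs plus a context. The names `MappingKey`, `SequenceKey`, `FLATTEN_WITH_KEYS`, `Pair`, and `flatten_with_key_paths` are hypothetical here and are not taken from `torch.utils._pytree`. Following the PR's stated choice for undefined `flatten_with_keys` functions, it raises for a registered node type that lacks one instead of synthesizing a fallback key.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass(frozen=True)
class MappingKey:
    """Hypothetical key entry for a mapping lookup: node[key]."""
    key: Any


@dataclass(frozen=True)
class SequenceKey:
    """Hypothetical key entry for a positional lookup: node[idx]."""
    idx: int


KeyPath = Tuple[Any, ...]
# flatten_with_keys(node) -> ([(key_entry, child), ...], context)
FlattenWithKeysFn = Callable[[Any], Tuple[List[Tuple[Any, Any]], Any]]


@dataclass
class Pair:
    """A node type registered without a flatten_with_keys fn."""
    first: Any
    second: Any


# All registered node types; anything else is a leaf.
REGISTERED_NODE_TYPES = {dict, list, tuple, Pair}

FLATTEN_WITH_KEYS: Dict[type, FlattenWithKeysFn] = {
    dict: lambda d: ([(MappingKey(k), v) for k, v in d.items()], list(d)),
    list: lambda seq: ([(SequenceKey(i), v) for i, v in enumerate(seq)], None),
    tuple: lambda seq: ([(SequenceKey(i), v) for i, v in enumerate(seq)], None),
}


def flatten_with_key_paths(tree: Any, prefix: KeyPath = ()) -> List[Tuple[KeyPath, Any]]:
    """Return (key_path, leaf) pairs in depth-first order."""
    node_type = type(tree)
    if node_type not in REGISTERED_NODE_TYPES:
        return [(prefix, tree)]
    flatten_with_keys = FLATTEN_WITH_KEYS.get(node_type)
    if flatten_with_keys is None:
        # Raise rather than synthesize a fallback key entry.
        raise ValueError(f"no flatten_with_keys fn registered for {node_type.__name__}")
    children, _context = flatten_with_keys(tree)
    result: List[Tuple[KeyPath, Any]] = []
    for key, child in children:
        result.extend(flatten_with_key_paths(child, prefix + (key,)))
    return result


pairs = flatten_with_key_paths({"a": [1, 2], "b": (3,)})
```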

Current use cases for this API:
- Folks would like to traverse input pytrees to verify arguments and enforce compatibility. Key paths are useful for this: https://fburl.com/code/06p7zrvr is a hand-rolled pass that does essentially the same thing, but probably less robustly.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.

I'm sure there are other places where it would be useful.

Some design notes:
- I only implemented the API for the Python pytree impl. optree's key path APIs are designed somewhat differently (see #113378 for discussion). I have some issues with the `typed_path` solution proposed in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- My impl handles pytrees with an undefined `flatten_with_keys` fn differently from JAX: I raise an error, whereas JAX creates a fallback key entry.

Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)

[ghstack-poisoned]
suo added a commit that referenced this issue Jan 4, 2024
ghstack-source-id: 211428949
Pull Request resolved: #116786
suo added a commit that referenced this issue Jan 9, 2024
Pull Request resolved: #116786

ghstack-source-id: 211692851
@exported-using-ghexport

suo added a commit that referenced this issue Jan 9, 2024
Pull Request resolved: #116786

ghstack-source-id: 211696842
@exported-using-ghexport

suo added a commit that referenced this issue Jan 13, 2024
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).

I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.

Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.

I'm sure there are places it would be useful.

Some design notes:
- I only implemented the API for  the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry.

Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)

[ghstack-poisoned]
suo added a commit that referenced this issue Jan 13, 2024
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).

I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.

Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.

I'm sure there are places it would be useful.

Some design notes:
- I only implemented the API for  the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry.

Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)

[ghstack-poisoned]
suo added a commit that referenced this issue Jan 13, 2024
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).

I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.

Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.

I'm sure there are places it would be useful.

Some design notes:
- I only implemented the API for  the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry.

Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)

[ghstack-poisoned]
suo added a commit that referenced this issue Jan 13, 2024
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).

I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.

Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.

I'm sure there are places it would be useful.

Some design notes:
- I only implemented the API for the Python pytree impl. optree designs its keypath APIs somewhat differently (see #113378 for discussion). I have some issues with the `typed_path` solution proposed in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- My impl handles pytrees with an undefined `flatten_with_keys` fn differently from JAX: I raise an error, whereas JAX creates a fallback key entry.
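The registration shape from the second design note can be sketched as follows. The keyed variant is an optional keyword argument passed alongside the plain flatten/unflatten pair. `register_node`, `GetAttrKey`, `flatten_with_path`, and `Point` are hypothetical stand-ins, not PyTorch's `register_pytree_node`. In this sketch, unregistered types are leaves, and a registered type with no keyed fn raises, mirroring the error behavior the PR describes.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple


@dataclass(frozen=True)
class GetAttrKey:
    name: str

    def __str__(self) -> str:
        return f".{self.name}"


@dataclass
class _NodeEntry:
    flatten_fn: Callable
    unflatten_fn: Callable
    flatten_with_keys_fn: Optional[Callable] = None


_REGISTRY: Dict[type, _NodeEntry] = {}


def register_node(cls: type, flatten_fn: Callable, unflatten_fn: Callable,
                  *, flatten_with_keys_fn: Optional[Callable] = None) -> None:
    """Hypothetical registration API; the keyed fn is an optional kwarg."""
    _REGISTRY[cls] = _NodeEntry(flatten_fn, unflatten_fn, flatten_with_keys_fn)


def flatten_with_path(tree: Any, prefix: Tuple[Any, ...] = ()) -> List[Tuple[Tuple[Any, ...], Any]]:
    entry = _REGISTRY.get(type(tree))
    if entry is None:
        return [(prefix, tree)]  # unregistered types are leaves
    if entry.flatten_with_keys_fn is None:
        raise ValueError(f"{type(tree).__name__} has no flatten_with_keys fn")
    keyed_children, _context = entry.flatten_with_keys_fn(tree)
    out = []
    for key, child in keyed_children:
        out.extend(flatten_with_path(child, prefix + (key,)))
    return out


class Point:
    def __init__(self, x: Any, y: Any) -> None:
        self.x, self.y = x, y


register_node(
    Point,
    flatten_fn=lambda p: ([p.x, p.y], None),
    unflatten_fn=lambda children, ctx: Point(*children),
    flatten_with_keys_fn=lambda p: ([(GetAttrKey("x"), p.x), (GetAttrKey("y"), p.y)], None),
)
```

Flattening `Point(Point(1, 2), 3)` produces paths that render as `.x.x`, `.x.y`, and `.y`. Every additional per-node fn becomes another keyword argument on the same call, which is how the argument list grows unwieldy.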

Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this issue Jan 17, 2024
Pull Request resolved: #116786
Approved by: https://github.com/voznesenskym
@XuehaiPan XuehaiPan linked a pull request Jul 5, 2024 that will close this issue