[pytree] implement key path API #113378
Comments
Signed this up as a HAM project for one of our internal transfer candidates. |
@XuehaiPan Curious your thoughts on this. It looks like optree already has the main operations we want here: If we're hoping to use optree as the backend for pytorch's pytree, would it make sense to focus on that first? |
The key path APIs in JAX pytree and optree differ:
In [1]: from collections import *
In [2]: MyTuple = namedtuple('MyTuple', ['x', 'y', 'z'])
In [3]: tree = OrderedDict([('a', (1, 2)), ('b', {'c': [3, 4], 'd': 5}), ('e', MyTuple(6, 7, 8))])
In [4]: import jax.tree_util as jaxtree
In [5]: jaxtree.tree_flatten_with_path(tree)
Out[5]:
(
[
((DictKey(key='a'), SequenceKey(idx=0)), 1),
((DictKey(key='a'), SequenceKey(idx=1)), 2),
((DictKey(key='b'), DictKey(key='c'), SequenceKey(idx=0)), 3),
((DictKey(key='b'), DictKey(key='c'), SequenceKey(idx=1)), 4),
((DictKey(key='b'), DictKey(key='d')), 5),
((DictKey(key='e'), GetAttrKey(name='x')), 6),
((DictKey(key='e'), GetAttrKey(name='y')), 7),
((DictKey(key='e'), GetAttrKey(name='z')), 8)
],
PyTreeDef(CustomNode(OrderedDict[('a', 'b', 'e')], [(*, *), {'c': [*, *], 'd': *}, CustomNode(namedtuple[MyTuple], [*, *, *])]))
)
In [6]: import optree
In [7]: optree.tree_flatten_with_path(tree)
Out[7]:
(
[('a', 0), ('a', 1), ('b', 'c', 0), ('b', 'c', 1), ('b', 'd'), ('e', 0), ('e', 1), ('e', 2)],
[1, 2, 3, 4, 5, 6, 7, 8],
PyTreeSpec(OrderedDict([('a', (*, *)), ('b', {'c': [*, *], 'd': *}), ('e', MyTuple(x=*, y=*, z=*))]))
)
In JAX's pytree, leaf nodes are identified using a tuple of KeyPath objects, including types like DictKey, SequenceKey, and GetAttrKey.
In [8]: jaxtree.tree_flatten_with_path(tree)[0][0][0]
Out[8]: (DictKey(key='a'), SequenceKey(idx=0))
In [9]: optree.tree_flatten_with_path(tree)[0][0]
Out[9]: ('a', 0)
In [10]: tree['a'][0]
Out[10]: 1
Most container classes implement the `__getitem__` method:
In [11]: class itemgetter:
    ...:     def __init__(self, path):
    ...:         self.path = tuple(path)
    ...:
    ...:     def __call__(self, obj):
    ...:         for p in self.path:
    ...:             obj = obj[p]
    ...:         return obj
In [12]: getter = itemgetter(('a', 0))
In [13]: getter(tree)
Out[13]: 1
In [14]: paths = optree.tree_paths(tree)
In [15]: paths
Out[15]:
[
('a', 0),
('a', 1),
('b', 'c', 0),
('b', 'c', 1),
('b', 'd'),
('e', 0),
('e', 1),
('e', 2)
]
In [16]: getters = [itemgetter(path) for path in paths]
In [17]: [getter(tree) for getter in getters]
Out[17]: [1, 2, 3, 4, 5, 6, 7, 8]
One thing I want to point out is that the JAX pytree uses GetAttrKey for namedtuple fields, while optree uses integer indices:
In [18]: jaxtree.tree_flatten_with_path(MyTuple(1, 2, 3))
Out[18]:
(
[
((GetAttrKey(name='x'),), 1),
((GetAttrKey(name='y'),), 2),
((GetAttrKey(name='z'),), 3)
],
PyTreeDef(CustomNode(namedtuple[MyTuple], [*, *, *]))
)
In [19]: optree.tree_flatten_with_path(MyTuple(1, 2, 3))
Out[19]:
(
[(0,), (1,), (2,)],
[1, 2, 3],
PyTreeSpec(MyTuple(x=*, y=*, z=*))
)
Any thoughts about the API design for the key paths? |
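The namedtuple difference above comes down to the fact that a namedtuple field can be reached both by position and by attribute name. A minimal stdlib-only sketch (it does not call JAX or optree) showing that both addressing styles reach the same leaves:

```python
from collections import namedtuple

MyTuple = namedtuple("MyTuple", ["x", "y", "z"])
t = MyTuple(1, 2, 3)

# Index-style access, as in the optree paths (0,), (1,), (2,).
by_index = [t[i] for i in range(len(t))]

# Attribute-style access, as in the JAX GetAttrKey(name=...) entries.
by_name = [getattr(t, name) for name in MyTuple._fields]

assert by_index == by_name == [1, 2, 3]
```

So the choice between the two is about what the path communicates to the reader or to tooling, not about which leaves it can reach.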
Thanks for the detailed answer! Personally I find the JAX approach of wrapping the key in a semantic type a bit more future-proof. It's probably partly my preferences as someone who works mainly on a type-checker, but encoding the meaning of the key (in particular so that attribute keys, which are necessary to, for example, add support for some kinds of dataclasses, can be distinguished from map index keys) generally means the types are more expressive, and the system is easier to extend to new cases. If you prefer optree's current API, I suspect it wouldn't be all that hard to add a wrapper layer for this. |
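To make the attribute-versus-index distinction concrete, here is a minimal sketch. The `DictKey`, `SequenceKey`, and `GetAttrKey` classes below are hypothetical stand-ins written for this example (named after the entries in the JAX output earlier in the thread); they are not the JAX or optree implementations, and `resolve` and `Config` are likewise made up for illustration.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class DictKey:
    key: Any

    def get(self, obj):
        return obj[self.key]  # mapping lookup


@dataclass(frozen=True)
class SequenceKey:
    idx: int

    def get(self, obj):
        return obj[self.idx]  # positional lookup


@dataclass(frozen=True)
class GetAttrKey:
    name: str

    def get(self, obj):
        return getattr(obj, self.name)  # attribute lookup


def resolve(tree, path):
    """Follow a tuple of key entries from the root down to a leaf."""
    for entry in path:
        tree = entry.get(tree)
    return tree


@dataclass
class Config:
    lr: float


tree = {"lr": 0.5, "cfg": Config(lr=0.1)}

# The same string "lr" means a dict lookup in one path and an attribute
# lookup in the other; the entry type removes the ambiguity.
assert resolve(tree, (DictKey("lr"),)) == 0.5
assert resolve(tree, (DictKey("cfg"), GetAttrKey("lr"))) == 0.1
```

With untyped paths such as `('cfg', 'lr')`, a consumer would have to inspect the container at each step to decide between `obj['lr']` and `obj.lr`.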
@stroxler I agree that the keys with semantic types in JAX are more expressive. In
Would adding the node type in the path result resolve your concern? For example:
I think this would be more general than the semantic key in JAX. For example, we can distinguish the node type of
In [1]: from collections import *
In [2]: MyTuple = namedtuple('MyTuple', ['x', 'y', 'z'])
In [3]: tree = OrderedDict([('a', (1, 2)), ('b', {'c': [3, 4], 'd': 5}), ('e', MyTuple(6, 7, 8))])
In [4]: import optree
In [5]: optree.tree_flatten_with_path(tree)
Out[5]:
(
[('a', 0), ('a', 1), ('b', 'c', 0), ('b', 'c', 1), ('b', 'd'), ('e', 0), ('e', 1), ('e', 2)],
[1, 2, 3, 4, 5, 6, 7, 8],
PyTreeSpec(OrderedDict([('a', (*, *)), ('b', {'c': [*, *], 'd': *}), ('e', MyTuple(x=*, y=*, z=*))]))
)
In [6]: optree.tree_flatten_with_typed_path(tree)
Out[6]:
(
[
((<class 'collections.OrderedDict'>, 'a'), (<class 'tuple'>, 0)),
((<class 'collections.OrderedDict'>, 'a'), (<class 'tuple'>, 1)),
((<class 'collections.OrderedDict'>, 'b'), (<class 'dict'>, 'c'), (<class 'list'>, 0)),
((<class 'collections.OrderedDict'>, 'b'), (<class 'dict'>, 'c'), (<class 'list'>, 1)),
((<class 'collections.OrderedDict'>, 'b'), (<class 'dict'>, 'd')),
((<class 'collections.OrderedDict'>, 'e'), (<class '__main__.MyTuple'>, 0)),
((<class 'collections.OrderedDict'>, 'e'), (<class '__main__.MyTuple'>, 1)),
((<class 'collections.OrderedDict'>, 'e'), (<class '__main__.MyTuple'>, 2))
],
[1, 2, 3, 4, 5, 6, 7, 8],
PyTreeSpec(OrderedDict([('a', (*, *)), ('b', {'c': [*, *], 'd': *}), ('e', MyTuple(x=*, y=*, z=*))]))
)
An initial implementation for this using pure Python:
def typed_paths(treespec):
    def gen_type_paths(spec):
        if spec.is_leaf():
            yield ()
            return
        node_type = spec.type
        for entry, child_spec in zip(spec.entries(), spec.children()):
            for child_typed_path in gen_type_paths(child_spec):
                yield ((node_type, entry), *child_typed_path)
    return list(gen_type_paths(treespec))

def tree_flatten_with_typed_path(tree):
    leaves, treespec = optree.tree_flatten(tree)
    return typed_paths(treespec), leaves, treespec
We can define a member method for
def typed_paths(treespec):
    stack = []
    def gen_type_paths(spec):
        if spec.is_leaf():
            yield tuple(stack)
            return
        node_type = spec.type
        for entry, child_spec in zip(spec.entries(), spec.children()):
            stack.append((node_type, entry))
            yield from gen_type_paths(child_spec)
            stack.pop()
    return list(gen_type_paths(treespec))
In [7]: typed_paths(treespec)
Out[7]:
[
((<class 'collections.OrderedDict'>, 'a'), (<class 'tuple'>, 0)),
((<class 'collections.OrderedDict'>, 'a'), (<class 'tuple'>, 1)),
((<class 'collections.OrderedDict'>, 'b'), (<class 'dict'>, 'c'), (<class 'list'>, 0)),
((<class 'collections.OrderedDict'>, 'b'), (<class 'dict'>, 'c'), (<class 'list'>, 1)),
((<class 'collections.OrderedDict'>, 'b'), (<class 'dict'>, 'd')),
((<class 'collections.OrderedDict'>, 'e'), (<class '__main__.MyTuple'>, 0)),
((<class 'collections.OrderedDict'>, 'e'), (<class '__main__.MyTuple'>, 1)),
((<class 'collections.OrderedDict'>, 'e'), (<class '__main__.MyTuple'>, 2))
]
An alternative is to add a new function entry
class PyTreeNodeRegistryEntry(NamedTuple):
    flatten_func: FlattenFunc
    unflatten_func: UnflattenFunc
    getitem_func: GetitemFunc
where the signatures are:
flatten_func(container) -> (children, metadata, entries)
unflatten_func(metadata, children) -> container
getitem_func(container, entry) -> child
and |
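A rough sketch of how such a registry entry could be used to resolve a path, following the three signatures above. Everything here is an assumption made for illustration: the `REGISTRY` dict, the `Point` class, and the `lookup` helper are hypothetical and are not part of optree or PyTorch.

```python
from typing import Any, Callable, Dict, NamedTuple


class PyTreeNodeRegistryEntry(NamedTuple):
    flatten_func: Callable[[Any], tuple]     # container -> (children, metadata, entries)
    unflatten_func: Callable[[Any, list], Any]  # (metadata, children) -> container
    getitem_func: Callable[[Any, Any], Any]  # (container, entry) -> child


# Hypothetical global registry keyed by node type.
REGISTRY: Dict[type, PyTreeNodeRegistryEntry] = {}

REGISTRY[dict] = PyTreeNodeRegistryEntry(
    flatten_func=lambda d: (list(d.values()), list(d.keys()), list(d.keys())),
    unflatten_func=lambda keys, children: dict(zip(keys, children)),
    getitem_func=lambda d, key: d[key],
)


class Point:
    """Hypothetical custom node whose children are reached by attribute."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


REGISTRY[Point] = PyTreeNodeRegistryEntry(
    flatten_func=lambda p: ([p.x, p.y], None, ["x", "y"]),
    unflatten_func=lambda _metadata, children: Point(*children),
    getitem_func=getattr,  # entries are attribute names for this type
)


def lookup(tree, path):
    """Resolve a plain path by asking each node type how to index itself."""
    for entry in path:
        tree = REGISTRY[type(tree)].getitem_func(tree, entry)
    return tree


tree = {"p": Point(1, 2)}
assert lookup(tree, ("p", "y")) == 2
```

In this design the path entries stay plain values, and the knowledge of whether an entry is an index or an attribute name lives in the per-type `getitem_func`.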
Yes, I think that having the lookup keys tagged with the node type would probably suffice; it should be enough to resolve the question of, for example, how to know when to use a string key as an attribute name versus dict index. |
Any progress on this? |
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.

Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this; https://fburl.com/code/06p7zrvr is a hand-rolled pass doing basically the same thing, but probably in a more fragile way.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.

I'm sure there are other places it would be useful.

Some design notes:
- I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- The way my impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX: I raise an error, whereas JAX creates a fallback `KeyEntry`.

Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)
ghstack-source-id: 211428949
Pull Request resolved: #116786
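The registry design described in the notes above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the PR's code: `register_node`, `leaves_with_path`, `FLATTEN_WITH_KEYS`, and `Opaque` are hypothetical names, and the keys are plain values rather than `KeyEntry` objects to keep the example short.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical registry: node type -> flatten_with_keys fn, or None if the
# type was registered without one.
FLATTEN_WITH_KEYS: Dict[type, Optional[Callable[[Any], list]]] = {}


def register_node(cls, flatten_with_keys_fn=None):
    """Register a node type; the keyed flatten fn is an optional extra."""
    FLATTEN_WITH_KEYS[cls] = flatten_with_keys_fn


# Each fn returns a list of (key, child) pairs for one level of the tree.
register_node(dict, lambda d: list(d.items()))
register_node(list, lambda xs: list(enumerate(xs)))
register_node(tuple, lambda xs: list(enumerate(xs)))


def leaves_with_path(tree, prefix=()):
    """Return [(key_path, leaf), ...] in depth-first order."""
    cls = type(tree)
    if cls not in FLATTEN_WITH_KEYS:
        return [(prefix, tree)]  # unregistered types are leaves
    fn = FLATTEN_WITH_KEYS[cls]
    if fn is None:
        # Mirrors the design note: raise instead of inventing a fallback key.
        raise ValueError(f"no flatten_with_keys fn registered for {cls.__name__}")
    out = []
    for key, child in fn(tree):
        out.extend(leaves_with_path(child, prefix + (key,)))
    return out


class Opaque:
    """Hypothetical node type registered without a keyed flatten fn."""


register_node(Opaque)

assert leaves_with_path({"a": [1, 2], "b": 3}) == [
    (("a", 0), 1),
    (("a", 1), 2),
    (("b",), 3),
]
```

Calling `leaves_with_path({"k": Opaque()})` raises `ValueError` in this sketch, which is the error-instead-of-fallback behavior the last design note describes.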
Pull Request resolved: #116786 This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. ghstack-source-id: 211692851 @exported-using-ghexport Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/) [ghstack-poisoned]
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/) [ghstack-poisoned]
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/) [ghstack-poisoned]
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/) [ghstack-poisoned]
Pull Request resolved: #116786 This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. ghstack-source-id: 211696842 @exported-using-ghexport Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/) [ghstack-poisoned]
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/) [ghstack-poisoned]
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/) [ghstack-poisoned]
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/) [ghstack-poisoned]
Pull Request resolved: #116786 This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths). I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree. Current use cases for this API: - Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely. - In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo. I'm sure there are places it would be useful. Some design notes: - I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see #113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately. - The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy. - My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry. ghstack-source-id: 212052474 @exported-using-ghexport Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)
Pull Request resolved: #116786
Approved by: https://github.com/voznesenskym
JAX has a key path API, which makes it easy to index into a pytree. This would definitely be useful for more advanced users manipulating pytrees; we have at least one internal use case for it already.
cc @zou3519
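As a rough illustration of what "indexing into a pytree" with key paths means, the sketch below uses only built-in containers. It is not JAX's or PyTorch's API: `leaves_with_path`, `path_str`, and `get_by_path` are hypothetical helpers, and a path is assumed to be a plain tuple of dict keys and sequence indices.

```python
from typing import Any, Iterator, Tuple

Path = Tuple[Any, ...]


def leaves_with_path(tree: Any, path: Path = ()) -> Iterator[Tuple[Path, Any]]:
    """Yield (path, leaf) pairs; dicts, lists, and tuples are the only nodes."""
    if isinstance(tree, dict):
        for key, value in tree.items():
            yield from leaves_with_path(value, path + (key,))
    elif isinstance(tree, (list, tuple)):
        for idx, value in enumerate(tree):
            yield from leaves_with_path(value, path + (idx,))
    else:
        yield path, tree


def path_str(path: Path) -> str:
    """Render a path as the subscript expression that would reach the leaf."""
    return "".join(f"[{step!r}]" for step in path)


def get_by_path(tree: Any, path: Path) -> Any:
    """Follow a path from the root of the tree down to the value it names."""
    for step in path:
        tree = tree[step]
    return tree


tree = {"a": (1, 2), "b": {"c": [3, 4], "d": 5}}
paths = [path for path, _ in leaves_with_path(tree)]
```

Every yielded path round-trips: `get_by_path(tree, path)` returns the same leaf that was yielded alongside it, which is the property that makes key paths useful for error messages and input validation.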