diff --git a/test/test_pytree.py b/test/test_pytree.py index 03cae00b4c7d0b..1fe4423a1775b9 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -43,18 +43,10 @@ def test_aligned_public_apis(self): cxx_signature = inspect.signature(cxx_api) py_signature = inspect.signature(py_api) - # The C++ pytree APIs provide more features than the Python APIs. - # The Python APIs are a subset of the C++ APIs. - # Check the signature of the Python API is a subset of the C++ API. + # Check the parameter names are the same. cxx_param_names = list(cxx_signature.parameters) py_param_names = list(py_signature.parameters) - self.assertTrue( - set(cxx_param_names).issuperset(py_param_names), - msg=( - f"C++ parameter(s) ({cxx_param_names}) " - f"not in Python parameter(s) ({py_param_names})" - ), - ) + self.assertEqual(cxx_param_names, py_param_names) # Check the positional parameters are the same. cxx_positional_param_names = [ diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index a25694fea50ae5..7cb19f36c19fb3 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -98,20 +98,9 @@ def register_pytree_node( serialized_type_name: Optional[str] = None, to_dumpable_context: Optional[ToDumpableContextFn] = None, from_dumpable_context: Optional[FromDumpableContextFn] = None, - namespace: str = "torch", ) -> None: """Register a container-like type as pytree node. - The ``namespace`` argument is used to avoid collisions that occur when different libraries - register the same Python type with different behaviors. It is recommended to add a unique prefix - to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify - the same class in different namespaces for different use cases. - - .. warning:: - For safety reasons, a ``namespace`` must be specified while registering a custom type. It is - used to isolate the behavior of flattening and unflattening a pytree node type. This is to - prevent accidental collisions between different libraries that may register the same type. - Args: cls (type): A Python type to treat as an internal pytree node. flatten_fn (callable): A function to be used during flattening, taking an instance of @@ -130,9 +119,6 @@ def register_pytree_node( how to convert the custom json dumpable representation of the context back to the original context. This is used for json deserialization, which is being used in :mod:`torch.export` right now. - namespace (str, optional): A non-empty string that uniquely identifies the namespace of the - type registry. This is used to isolate the registry from other modules that might - register a different custom behavior for the same type. (default: :const:`"torch"`) Example:: @@ -142,77 +128,7 @@ def register_pytree_node( ... set, ... lambda s: (sorted(s), None, None), ... lambda children, _: set(children), - ... namespace='set', ... ) - - >>> # xdoctest: +SKIP - >>> # Register a Python type into a namespace - >>> import torch - >>> register_pytree_node( - ... torch.Tensor, - ... flatten_func=lambda tensor: ( - ... (tensor.cpu().detach().numpy(),), - ... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad}, - ... ), - ... unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata), - ... namespace='torch2numpy', - ... ) - - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) - >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))} - >>> tree - {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])} - - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) - >>> # Flatten without specifying the namespace - >>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes # xdoctest: +SKIP - ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *})) - - >>> # xdoctest: +SKIP - >>> # Flatten with the namespace - >>> tree_flatten(tree, namespace='torch2numpy') # xdoctest: +SKIP - ( - [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)], - PyTreeSpec( - { - 'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]), - 'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]) - }, - namespace='torch2numpy' - ) - ) - - >>> # xdoctest: +SKIP - >>> # Register the same type with a different namespace for different behaviors - >>> def tensor2flatparam(tensor): - ... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None - ... - >>> def flatparam2tensor(children, metadata): - ... return children[0].reshape(metadata) - ... - >>> register_pytree_node( - ... torch.Tensor, - ... flatten_func=tensor2flatparam, - ... unflatten_func=flatparam2tensor, - ... namespace='tensor2flatparam', - ... ) - - >>> # xdoctest: +SKIP - >>> # Flatten with the new namespace - >>> tree_flatten(tree, namespace='tensor2flatparam') # xdoctest: +SKIP - ( - [ - Parameter containing: tensor([0., 0.], requires_grad=True), - Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True) - ], - PyTreeSpec( - { - 'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]), - 'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*]) - }, - namespace='tensor2flatparam' - ) - ) """ _private_register_pytree_node( cls, @@ -221,7 +137,6 @@ def register_pytree_node( serialized_type_name=serialized_type_name, to_dumpable_context=to_dumpable_context, from_dumpable_context=from_dumpable_context, - namespace=namespace, ) from . import _pytree as python @@ -244,7 +159,6 @@ def _register_pytree_node( serialized_type_name: Optional[str] = None, to_dumpable_context: Optional[ToDumpableContextFn] = None, from_dumpable_context: Optional[FromDumpableContextFn] = None, - namespace: str = "torch", ) -> None: """Register a container-like type as pytree node for the C++ pytree only. @@ -276,89 +190,6 @@ def _register_pytree_node( how to convert the custom json dumpable representation of the context back to the original context. This is used for json deserialization, which is being used in :mod:`torch.export` right now. - namespace (str, optional): A non-empty string that uniquely identifies the namespace of the - type registry. This is used to isolate the registry from other modules that might - register a different custom behavior for the same type. (default: :const:`"torch"`) - - Example:: - - >>> # xdoctest: +SKIP - >>> # Registry a Python type with lambda functions - >>> register_pytree_node( - ... set, - ... lambda s: (sorted(s), None, None), - ... lambda children, _: set(children), - ... namespace='set', - ... ) - - >>> # xdoctest: +SKIP - >>> # Register a Python type into a namespace - >>> import torch - >>> register_pytree_node( - ... torch.Tensor, - ... flatten_func=lambda tensor: ( - ... (tensor.cpu().detach().numpy(),), - ... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad}, - ... ), - ... unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata), - ... namespace='torch2numpy', - ... ) - - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) - >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))} - >>> tree - {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])} - - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) - >>> # Flatten without specifying the namespace - >>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes # xdoctest: +SKIP - ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *})) - - >>> # xdoctest: +SKIP - >>> # Flatten with the namespace - >>> tree_flatten(tree, namespace='torch2numpy') # xdoctest: +SKIP - ( - [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)], - PyTreeSpec( - { - 'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]), - 'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]) - }, - namespace='torch2numpy' - ) - ) - - >>> # xdoctest: +SKIP - >>> # Register the same type with a different namespace for different behaviors - >>> def tensor2flatparam(tensor): - ... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None - ... - >>> def flatparam2tensor(children, metadata): - ... return children[0].reshape(metadata) - ... - >>> register_pytree_node( - ... torch.Tensor, - ... flatten_func=tensor2flatparam, - ... unflatten_func=flatparam2tensor, - ... namespace='tensor2flatparam', - ... ) - - >>> # xdoctest: +SKIP - >>> # Flatten with the new namespace - >>> tree_flatten(tree, namespace='tensor2flatparam') # xdoctest: +SKIP - ( - [ - Parameter containing: tensor([0., 0.], requires_grad=True), - Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True) - ], - PyTreeSpec( - { - 'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]), - 'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*]) - }, - namespace='tensor2flatparam' - ) - ) """ warnings.warn( "torch.utils._cxx_pytree._register_pytree_node is deprecated. " @@ -373,7 +204,6 @@ def _register_pytree_node( serialized_type_name=serialized_type_name, to_dumpable_context=to_dumpable_context, from_dumpable_context=from_dumpable_context, - namespace=namespace, ) @@ -385,7 +215,6 @@ def _private_register_pytree_node( serialized_type_name: Optional[str] = None, to_dumpable_context: Optional[ToDumpableContextFn] = None, from_dumpable_context: Optional[FromDumpableContextFn] = None, - namespace: str = "torch", ) -> None: """This is an internal function that is used to register a pytree node type for the C++ pytree only. End-users should use :func:`register_pytree_node` @@ -398,16 +227,11 @@ def _private_register_pytree_node( cls, flatten_fn, _reverse_args(unflatten_fn), - namespace=namespace, + namespace="torch", ) -def tree_flatten( - tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", -) -> Tuple[List[Any], TreeSpec]: +def tree_flatten(tree: PyTree) -> Tuple[List[Any], TreeSpec]: """Flatten a pytree. See also :func:`tree_unflatten`. @@ -418,14 +242,10 @@ def tree_flatten( >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} >>> tree_flatten(tree) ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)) - >>> tree_flatten(tree, none_is_leaf=False) - ([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})) >>> tree_flatten(1) ([1], PyTreeSpec(*, NoneIsLeaf)) >>> tree_flatten(None) ([None], PyTreeSpec(*, NoneIsLeaf)) - >>> tree_flatten(None, none_is_leaf=False) - ([], PyTreeSpec(None)) For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict` @@ -435,16 +255,9 @@ def tree_flatten( >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)]) >>> tree_flatten(tree) ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf)) - >>> tree_flatten(tree, none_is_leaf=False) - ([2, 3, 4, 1, 5], PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', None), ('d', *)]))) Args: tree (pytree): A pytree to flatten. - none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`, - :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the - treespec rather than in the leaves list. (default: :data:`True`) - namespace (str, optional): The registry namespace used for custom pytree node types. - (default: :const:`"torch"`) Returns: A pair ``(leaves, treespec)`` where the first element is a list of leaf values and the @@ -452,8 +265,8 @@ def tree_flatten( """ return optree.tree_flatten( # type: ignore[return-value] tree, - none_is_leaf=none_is_leaf, - namespace=namespace, + none_is_leaf=True, + namespace="torch", ) @@ -484,12 +297,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type] -def tree_leaves( - tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", -) -> List[Any]: +def tree_leaves(tree: PyTree) -> List[Any]: """Get the leaves of a pytree. See also :func:`tree_flatten`. @@ -497,35 +305,21 @@ def tree_leaves( >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} >>> tree_leaves(tree) [1, 2, 3, 4, None, 5] - >>> tree_leaves(tree, none_is_leaf=False) - [1, 2, 3, 4, 5] >>> tree_leaves(1) [1] >>> tree_leaves(None) [None] - >>> tree_leaves(None, none_is_leaf=False) - [] Args: tree (pytree): A pytree to flatten. - none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`, - :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the - treespec rather than in the leaves list. (default: :data:`True`) - namespace (str, optional): The registry namespace used for custom pytree node types. - (default: :const:`"torch"`) Returns: A list of leaf values. """ - return optree.tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) + return optree.tree_leaves(tree, none_is_leaf=True, namespace="torch") -def tree_structure( - tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", -) -> TreeSpec: +def tree_structure(tree: PyTree) -> TreeSpec: """Get the treespec for a pytree. See also :func:`tree_flatten`. @@ -533,41 +327,26 @@ def tree_structure( >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} >>> tree_structure(tree) PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf) - >>> tree_structure(tree, none_is_leaf=False) - PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}) >>> tree_structure(1) PyTreeSpec(*, NoneIsLeaf) >>> tree_structure(None) PyTreeSpec(*, NoneIsLeaf) - >>> tree_structure(None, none_is_leaf=False) - PyTreeSpec(None) Args: tree (pytree): A pytree to flatten. - none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`, - :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the - treespec rather than in the leaves list. (default: :data:`True`) - namespace (str, optional): The registry namespace used for custom pytree node types. - (default: :const:`"torch"`) Returns: A treespec object representing the structure of the pytree. """ return optree.tree_structure( # type: ignore[return-value] tree, - none_is_leaf=none_is_leaf, - namespace=namespace, + none_is_leaf=True, + namespace="torch", ) -def tree_map( - func: Callable[..., Any], - tree: PyTree, - *rests: PyTree, - none_is_leaf: bool = True, - namespace: str = "torch", -) -> PyTree: - """Map a multi-input function over pytree args to produce a new pytree. +def tree_map(func: Callable[..., Any], tree: PyTree) -> PyTree: + """Map a function over leaves in a pytree to produce a new pytree. See also :func:`tree_map_`. @@ -575,79 +354,46 @@ def tree_map( {'x': 8, 'y': (43, 65)} >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) {'x': False, 'y': (False, False), 'z': True} - >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=False) - {'x': 8, 'y': (43, 65), 'z': None} - >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=False) - {'x': False, 'y': (False, False), 'z': None} - - If multiple inputs are given, the structure of the tree is taken from the first input; - subsequent inputs need only have ``tree`` as a prefix: - - >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) - [[5, 7, 9], [6, 1, 2]] Args: - func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the - corresponding leaves of the pytrees. - tree (pytree): A pytree to be mapped over, with each leaf providing the first positional - argument to function ``func``. - rests (tuple of pytrees): A tuple of pytrees, each of which has the same structure as - ``tree`` or has ``tree`` as a prefix. - none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`, - :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the - treespec rather than in the leaves list. (default: :data:`True`) - namespace (str, optional): The registry namespace used for custom pytree node types. - (default: :const:`"torch"`) + func (callable): A function that takes a single argument, to be applied at the corresponding + leaves of the pytree. + tree (pytree): A pytree to be mapped over, with each leaf providing the argument to function + ``func``. Returns: A new pytree with the same structure as ``tree`` but with the value at each leaf given by - ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs`` - is the tuple of values at corresponding nodes in ``rests``. + ``func(x)`` where ``x`` is the value at the corresponding leaf in ``tree``. """ return optree.tree_map( func, tree, - *rests, - none_is_leaf=none_is_leaf, - namespace=namespace, + none_is_leaf=True, + namespace="torch", ) -def tree_map_( - func: Callable[..., Any], - tree: PyTree, - *rests: PyTree, - none_is_leaf: bool = True, - namespace: str = "torch", -) -> PyTree: +def tree_map_(func: Callable[..., Any], tree: PyTree) -> PyTree: """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree. See also :func:`tree_map`. Args: - func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the - corresponding leaves of the pytrees. - tree (pytree): A pytree to be mapped over, with each leaf providing the first positional - argument to function ``func``. - rests (tuple of pytrees): A tuple of pytrees, each of which has the same structure as - ``tree`` or has ``tree`` as a prefix. - none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`, - :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the - treespec rather than in the leaves list. (default: :data:`True`) - namespace (str, optional): The registry namespace used for custom pytree node types. - (default: :const:`"torch"`) + func (callable): A function that takes a single argument, to be applied at the corresponding + leaves of the pytree. + tree (pytree): A pytree to be mapped over, with each leaf providing the argument to function + ``func``. Returns: The original ``tree`` with the value at each leaf is given by the side-effect of function - ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf - in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``. + ``func(x)`` (not the return value) where ``x`` is the value at the corresponding leaf in + ``tree``. """ return optree.tree_map_( func, tree, - *rests, - none_is_leaf=none_is_leaf, - namespace=namespace, + none_is_leaf=True, + namespace="torch", ) @@ -723,9 +469,6 @@ def tree_map_only( __type_or_types: Type[T], func: Fn[T, Any], tree: PyTree, - *rests: PyTree, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> PyTree: ... @@ -735,9 +478,6 @@ def tree_map_only( __type_or_types: Type2[T, S], func: Fn2[T, S, Any], tree: PyTree, - *rests: PyTree, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> PyTree: ... @@ -747,9 +487,6 @@ def tree_map_only( __type_or_types: Type3[T, S, U], func: Fn3[T, S, U, Any], tree: PyTree, - *rests: PyTree, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> PyTree: ... @@ -758,17 +495,8 @@ def tree_map_only( __type_or_types: TypeAny, func: FnAny[Any], tree: PyTree, - *rests: PyTree, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> PyTree: - return tree_map( - map_only(__type_or_types)(func), - tree, - *rests, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) + return tree_map(map_only(__type_or_types)(func), tree) @overload @@ -776,9 +504,6 @@ def tree_map_only_( __type_or_types: Type[T], func: Fn[T, Any], tree: PyTree, - *rests: PyTree, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> PyTree: ... @@ -788,9 +513,6 @@ def tree_map_only_( __type_or_types: Type2[T, S], func: Fn2[T, S, Any], tree: PyTree, - *rests: PyTree, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> PyTree: ... @@ -800,9 +522,6 @@ def tree_map_only_( __type_or_types: Type3[T, S, U], func: Fn3[T, S, U, Any], tree: PyTree, - *rests: PyTree, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> PyTree: ... @@ -811,38 +530,17 @@ def tree_map_only_( __type_or_types: TypeAny, func: FnAny[Any], tree: PyTree, - *rests: PyTree, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> PyTree: - return tree_map_( - map_only(__type_or_types)(func), - tree, - *rests, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) + return tree_map_(map_only(__type_or_types)(func), tree) -def tree_all( - pred: Callable[[Any], bool], - tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", -) -> bool: - flat_args = tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) +def tree_all(pred: Callable[[Any], bool], tree: PyTree) -> bool: + flat_args = tree_leaves(tree) return all(map(pred, flat_args)) -def tree_any( - pred: Callable[[Any], bool], - tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", -) -> bool: - flat_args = tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) +def tree_any(pred: Callable[[Any], bool], tree: PyTree) -> bool: + flat_args = tree_leaves(tree) return any(map(pred, flat_args)) @@ -851,9 +549,6 @@ def tree_all_only( __type_or_types: Type[T], pred: Fn[T, bool], tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> bool: ... @@ -863,9 +558,6 @@ def tree_all_only( __type_or_types: Type2[T, S], pred: Fn2[T, S, bool], tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> bool: ... @@ -875,9 +567,6 @@ def tree_all_only( __type_or_types: Type3[T, S, U], pred: Fn3[T, S, U, bool], tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> bool: ... @@ -886,11 +575,8 @@ def tree_all_only( __type_or_types: TypeAny, pred: FnAny[bool], tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> bool: - flat_args = tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) + flat_args = tree_leaves(tree) return all(pred(x) for x in flat_args if isinstance(x, __type_or_types)) @@ -899,9 +585,6 @@ def tree_any_only( __type_or_types: Type[T], pred: Fn[T, bool], tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> bool: ... @@ -911,9 +594,6 @@ def tree_any_only( __type_or_types: Type2[T, S], pred: Fn2[T, S, bool], tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> bool: ... @@ -923,9 +603,6 @@ def tree_any_only( __type_or_types: Type3[T, S, U], pred: Fn3[T, S, U, bool], tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> bool: ... @@ -934,21 +611,12 @@ def tree_any_only( __type_or_types: TypeAny, pred: FnAny[bool], tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", ) -> bool: - flat_args = tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) + flat_args = tree_leaves(tree) return any(pred(x) for x in flat_args if isinstance(x, __type_or_types)) -def broadcast_prefix( - prefix_tree: PyTree, - full_tree: PyTree, - *, - none_is_leaf: bool = True, - namespace: str = "torch", -) -> List[Any]: +def broadcast_prefix(prefix_tree: PyTree, full_tree: PyTree) -> List[Any]: """Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``. If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be @@ -970,8 +638,6 @@ def broadcast_prefix( [1, 2, 3, 3] >>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}]) [1, 2, 3, 3, 3, 3] - >>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], none_is_leaf=False) - [1, 2, 3, 3, 3] Args: prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``. @@ -980,11 +646,6 @@ def broadcast_prefix( flattening step. It should return a boolean, with :data:`True` stopping the traversal and the whole subtree being treated as a leaf, and :data:`False` indicating the flattening should traverse the current object. - none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`, - :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the - treespec rather than in the leaves list. (default: :data:`True`) - namespace (str, optional): The registry namespace used for custom pytree node types. - (default: :const:`"torch"`) Returns: A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``. @@ -992,8 +653,8 @@ def broadcast_prefix( return optree.broadcast_prefix( prefix_tree, full_tree, - none_is_leaf=none_is_leaf, - namespace=namespace, + none_is_leaf=True, + namespace="torch", ) @@ -1005,22 +666,11 @@ def broadcast_prefix( # a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be # broadcastable to the tree structure of `inputs` and we use # _broadcast_to_and_flatten to check this. -def _broadcast_to_and_flatten( - tree: PyTree, - treespec: TreeSpec, - *, - none_is_leaf: bool = True, - namespace: str = "torch", -) -> Optional[List[Any]]: +def _broadcast_to_and_flatten(tree: PyTree, treespec: TreeSpec) -> Optional[List[Any]]: assert isinstance(treespec, TreeSpec) full_tree = tree_unflatten([0] * treespec.num_leaves, treespec) try: - return broadcast_prefix( - tree, - full_tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) + return broadcast_prefix(tree, full_tree) except ValueError: return None @@ -1073,5 +723,5 @@ def __instancecheck__(self, instance: object) -> bool: class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): - def __new__(cls, none_is_leaf: bool = True) -> "LeafSpec": - return optree.treespec_leaf(none_is_leaf=none_is_leaf) # type: ignore[return-value] + def __new__(cls) -> "LeafSpec": + return optree.treespec_leaf(none_is_leaf=True) # type: ignore[return-value]