8000 [dynamo][pytree][3/N] make CXX pytree traceable: `tree_map` / `tree_m… · XuehaiPan/pytorch@7798a73 · GitHub
  • [go: up one dir, main page]

    Skip to content

    Commit 7798a73

    Browse files
    committed
    [dynamo][pytree][3/N] make CXX pytree traceable: tree_map / tree_map_ (pytorch#137399)
    Pull Request resolved: pytorch#137399 Approved by: https://github.com/jansel ghstack dependencies: pytorch#137398
    1 parent 1f90567 commit 7798a73

    File tree

    4 files changed

    +220
    -89
    lines changed

    4 files changed

    +220
    -89
    lines changed

    test/dynamo/test_misc.py

    Lines changed: 2 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -10153,6 +10153,8 @@ def fn(x, y):
    1015310153

    1015410154
    def test_pytree_tree_map(self):
    1015510155
    implemtations = [("python", python_pytree)]
    10156+
    if cxx_pytree is not None:
    10157+
    implemtations.append(("cxx", cxx_pytree))
    1015610158

    1015710159
    for name, module in implemtations:
    1015810160
    with self.subTest(f"pytree implement: {name}"):

    torch/_dynamo/polyfills/pytree.py

    Lines changed: 118 additions & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -4,12 +4,13 @@
    44

    55
    from __future__ import annotations
    66

    7+
    from collections import deque
    78
    from dataclasses import dataclass, field
    89
    from typing import Any, Callable, Iterable, Literal, TYPE_CHECKING
    910
    from typing_extensions import TypeIs
    1011

    1112
    import torch.utils._pytree as python_pytree
    12-
    from torch.utils._pytree import BUILTIN_TYPES
    13+
    from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
    1314

    1415
    from ..decorators import substitute_in_graph
    1516

    @@ -200,6 +201,95 @@ def entries(self) -> list[Any]:
    200201
    def entry(self, index: int) -> Any:
    201202
    return self._entries[index]
    202203

    204+
    def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
    205+
    def helper(
    206+
    treespec: PyTreeSpec,
    207+
    node: PyTree,
    208+
    subtrees: list[PyTree],
    209+
    ) -> None:
    210+
    if treespec.is_leaf():
    211+
    subtrees.append(node)
    212+
    return
    213+
    214+
    node_type = type(node)
    215+
    if treespec.type not in BUILTIN_TYPES:
    216+
    # Always require custom node types to match exactly
    217+
    if node_type != treespec.type:
    218+
    raise ValueError(
    219+
    f"Type mismatch; "
    220+
    f"expected {treespec.type!r}, but got {node_type!r}.",
    221+
    )
    222+
    223+
    children, metadata, *_ = optree.tree_flatten_one_level(
    224+
    node,
    225+
    none_is_leaf=True,
    226+
    namespace="torch",
    227+
    )
    228+
    if len(children) != treespec.num_children:
    229+
    raise ValueError(
    230+
    f"Node arity mismatch; "
    231+
    f"expected {treespec.num_children}, but got {len(children)}.",
    232+
    )
    233+
    if metadata != treespec._metadata:
    234+
    raise ValueError(
    235+
    f"Node context mismatch for custom node type {treespec.type!r}.",
    236+
    )
    237+
    else:
    238+
    # For builtin dictionary types, we allow some flexibility
    239+
    # Otherwise, we require exact matches
    240+
    both_standard_dict = (
    241+
    treespec.type in STANDARD_DICT_TYPES
    242+
    and node_type in STANDARD_DICT_TYPES
    243+
    )
    244+
    if not both_standard_dict and node_type != treespec.type:
    245+
    raise ValueError(
    246+
    f"Node type mismatch; "
    247+
    f"expected {treespec.type!r}, but got {node_type!r}.",
    248+
    )
    249+
    if len(node) != treespec.num_children:
    250+
    raise ValueError(
    251+
    f"Node arity mismatch; "
    252+
    f"expected {treespec.num_children}, but got {len(node)}.",
    253+
    )
    254+
    255+
    if both_standard_dict:
    256+
    # dictionary types are compatible with each other
    257+
    expected_keys = treespec.entries()
    258+
    got_key_set = set(node)
    259+
    expected_key_set = set(expected_keys)
    260+
    if got_key_set != expected_key_set:
    261+
    missing_keys = expected_key_set.difference(got_key_set)
    262+
    extra_keys = got_key_set.difference(expected_key_set)
    263+
    message = ""
    264+
    if missing_keys:
    265+
    message += f"; missing key(s): {missing_keys}"
    266+
    if extra_keys:
    267+
    message += f"; extra key(s): {extra_keys}"
    268+
    raise ValueError(f"Node keys mismatch{message}.")
    269+
    children = [node[key] for key in expected_keys]
    270+
    else:
    271+
    # node_type is treespec.type
    272+
    children, metadata, *_ = optree.tree_flatten_one_level(
    273+
    node,
    274+
    none_is_leaf=True,
    275+
    namespace="torch",
    276+
    )
    277+
    if (
    278+
    node_type
    279+
    is not deque # ignore mismatch of `maxlen` for deque
    280+
    ) and metadata != treespec._metadata:
    281+
    raise ValueError(
    282+
    f"Node metadata mismatch for node type {treespec.type!r}; "
    283+
    f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
    284+
    )
    285+
    286+
    for subtree, subspec in zip(children, treespec._children):
    287+
    helper(subspec, subtree, subtrees)
    288+
    289+
    subtrees: list[PyTree] = []
    290+
    helper(self, tree, subtrees)
    291+
    return subtrees
    292+
    203293
    def unflatten(self, leaves: Iterable[Any]) -> PyTree:
    204294
    if not isinstance(leaves, (list, tuple)):
    205295
    leaves = list(leaves)
    @@ -295,3 +385,30 @@ def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
    295385
    return treespec.unflatten(leaves)
    296386

    297387
    __all__ += ["tree_unflatten"]
    388+
    389+
    @substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
    390+
    def tree_map(
    391+
    func: Callable[..., Any],
    392+
    tree: PyTree,
    393+
    *rests: PyTree,
    394+
    is_leaf: Callable[[PyTree], bool] | None = None,
    395+
    ) -> PyTree:
    396+
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    397+
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    398+
    return treespec.unflatten(map(func, *flat_args))
    399+
    400+
    __all__ += ["tree_map"]
    401+
    402+
    @substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
    403+
    def tree_map_(
    404+
    func: Callable[..., Any],
    405+
    tree: PyTree,
    406+
    *rests: PyTree,
    407+
    is_leaf: Callable[[PyTree], bool] | None = None,
    408+
    ) -> PyTree:
    409+
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    410+
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    411+
    deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
    412+
    return tree
    413+
    414+
    __all__ += ["tree_map_"]

    torch/utils/_cxx_pytree.py

    Lines changed: 25 additions & 15 deletions
    Original file line numberDiff line numberDiff line change
    @@ -293,7 +293,7 @@ def tree_flatten(
    293293
    The flattening order (i.e., the order of elements in the output list) is deterministic,
    294294
    corresponding to a left-to-right depth-first tree traversal.
    295295
    296-
    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    296+
    >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
    297297
    >>> tree_flatten(tree)
    298298
    ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
    299299
    >>> tree_flatten(1)
    @@ -306,7 +306,7 @@ def tree_flatten(
    306306
    if you want to keep the keys in the insertion order.
    307307
    308308
    >>> from collections import OrderedDict
    309-
    >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
    309+
    >>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)])
    310310
    >>> tree_flatten(tree)
    311311
    ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))
    312312
    @@ -335,7 +335,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
    335335
    336336
    The inverse of :func:`tree_flatten`.
    337337
    338-
    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    338+
    >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
    339339
    >>> leaves, treespec = tree_flatten(tree)
    340340
    >>> tree == tree_unflatten(leaves, treespec)
    341341
    True
    @@ -365,7 +365,7 @@ def tree_iter(
    365365
    366366
    See also :func:`tree_flatten`.
    367367
    368-
    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    368+
    >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
    369369
    >>> list(tree_iter(tree))
    370370
    [1, 2, 3, 4, None, 5]
    371371
    >>> list(tree_iter(1))
    @@ -400,7 +400,7 @@ def tree_leaves(
    400400
    401401
    See also :func:`tree_flatten`.
    402402
    403-
    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    403+
    >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
    404404
    >>> tree_leaves(tree)
    405405
    [1, 2, 3, 4, None, 5]
    406406
    >>> tree_leaves(1)
    @@ -435,7 +435,7 @@ def tree_structure(
    435435
    436436
    See also :func:`tree_flatten`.
    437437
    438-
    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    438+
    >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
    439439
    >>> tree_structure(tree)
    440440
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
    441441
    >>> tree_structure(1)
    @@ -472,9 +472,9 @@ def tree_map(
    472472
    473473
    See also :func:`tree_map_`.
    474474
    475-
    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
    475+
    >>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)})
    476476
    {'x': 8, 'y': (43, 65)}
    477-
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
    477+
    >>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None})
    478478
    {'x': False, 'y': (False, False), 'z': True}
    479479
    480480
    If multiple inputs are given, the structure of the tree is taken from the first input;
    @@ -572,7 +572,9 @@ def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
    572572

    573573

    574574
    @overload
    575-
    def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
    575+
    def map_only(
    576+
    __type_or_types_or_pred: Type3[T, S, U],
    577+
    ) -> MapOnlyFn[Fn3[T, S, U, Any]]:
    576578
    ...
    577579

    578580

    @@ -588,12 +590,14 @@ def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    588590

    589591

    590592
    @overload
    591-
    def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
    593+
    def map_only(
    594+
    __type_or_types_or_pred: Callable[[Any], bool],
    595+
    ) -> MapOnlyFn[FnAny[Any]]:
    592596
    ...
    593597

    594598

    595599
    def map_only(
    596-
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
    600+
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    597601
    ) -> MapOnlyFn[FnAny[Any]]:
    598602
    """
    599603
    Suppose you are writing a tree_map over tensors, leaving everything
    @@ -858,7 +862,7 @@ def broadcast_prefix(
    858862
    ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
    859863
    >>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
    860864
    [1, 2, 3, 3]
    861-
    >>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
    865+
    >>> broadcast_prefix([1, 2, 3], [1, 2, {"a": 3, "b": 4, "c": (None, 5)}])
    862866
    [1, 2, 3, 3, 3, 3]
    863867
    864868
    Args:
    @@ -873,13 +877,19 @@ def broadcast_prefix(
    873877
    Returns:
    874878
    A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
    875879
    """
    876-
    return optree.broadcast_prefix(
    880+
    result: List[Any] = []
    881+
    882+
    def add_leaves(x: Any, subtree: PyTree) -> None:
    883+
    subtreespec = tree_structure(subtree, is_leaf=is_leaf)
    884+
    result.extend([x] * subtreespec.num_leaves)
    885+
    886+
    tree_map_(
    887+
    add_leaves,
    877888
    prefix_tree,
    878889
    full_tree,
    879890
    is_leaf=is_leaf,
    880-
    none_is_leaf=True,
    881-
    namespace="torch",
    882891
    )
    892+
    return result
    883893

    884894

    885895
    # Broadcasts a pytree to the provided TreeSpec and returns the flattened

    0 commit comments

    Comments
     (0)
    0