8000 [Dynamo][pytree] handle `isinstance(...)` check for polyfilled class · pytorch/pytorch@e1741e3 · GitHub
[go: up one dir, main page]

Skip to content

Commit e1741e3

Browse files
committed
[Dynamo][pytree] handle isinstance(...) check for polyfilled class
ghstack-source-id: c2486d4 Pull Request resolved: #146921
1 parent 48ca6b9 commit e1741e3

File tree

4 files changed

+65
-6
lines changed
  • variables
  • utils
  • 4 files changed

    +65
    -6
    lines changed

    test/dynamo/test_misc.py

    Lines changed: 32 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -10250,6 +10250,38 @@ def fn(x, y):
    1025010250

    1025110251
    self.assertEqual(actual, expected)
    1025210252

    10253+
    @unittest.skipIf(cxx_pytree is None, "Test for C++ pytree polyfill infra")
    10254+
    def test_pytreespec_isinstance_check(self):
    10255+
    from tor 8000 ch._dynamo.polyfills import pytree as polyfilled_cxx_pytree
    10256+
    10257+
    @torch.compile(fullgraph=True)
    10258+
    def fn(x, y):
    10259+
    leaves, treespec = cxx_pytree.tree_flatten(x)
    10260+
    return leaves, treespec, y.sin()
    10261+
    10262+
    y = torch.randn(3)
    10263+
    x = [1, [2, [3, 4]]]
    10264+
    leaves, treespec, _ = fn(x, y)
    10265+
    # Compiled function returns an instance of the polyfilled class instead of the original class
    10266+
    self.assertIsInstance(treespec, polyfilled_cxx_pytree.PyTreeSpec)
    10267+
    # Must not raise exceptions that allow partially compiled programs to mix polyfilled classes
    10268+
    # with original classes in different parts of the program
    10269+
    reconstructed = cxx_pytree.tree_unflatten(leaves, treespec)
    10270+
    self.assertEqual(x, reconstructed)
    10271+
    10272+
    def fn(x, y):
    10273+
    treespec = cxx_pytree.tree_structure(x)
    10274+
    if isinstance(treespec, cxx_pytree.PyTreeSpec):
    10275+
    return y.sin()
    10276+
    else:
    10277+
    return y.cos()
    10278+
    10279+
    expected = fn(x, y)
    10280+
    fn_opt = torch.compile(fullgraph=True)(fn)
    10281+
    actual = fn_opt(x, y)
    10282+
    10283+
    self.assertEqual(actual, expected)
    10284+
    1025310285
    def test_shape_env_no_recording(self):
    1025410286
    main = ShapeEnv(should_record_events=False)
    1025510287

    torch/_dynamo/polyfills/pytree.py

    Lines changed: 7 additions & 4 deletions
    Original file line numberDiff line numberDiff line change
    @@ -7,12 +7,12 @@
    77
    from collections import deque
    88
    from dataclasses import dataclass, field
    99
    from typing import Any, Callable, Literal, TYPE_CHECKING
    10-
    from typing_extensions import TypeIs
    1110

    1211
    import torch.utils._pytree as python_pytree
    1312
    from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
    1413

    1514
    from ..decorators import substitute_in_graph
    15+
    from ..variables.builtin import polyfill_class_mapping
    1616

    1717

    1818
    if TYPE_CHECKING:
    @@ -317,10 +317,13 @@ def unflatten(self, leaves: Iterable[Any]) -> PyTree:
    317317
    assert callable(self._unflatten_func)
    318318
    return self._unflatten_func(self._metadata, subtrees)
    319319

    320-
    _LEAF_SPEC = PyTreeSpec((), None, None, (), None)
    320+
    _pytreespec_types = (PyTreeSpec, cxx_pytree.PyTreeSpec)
    321+
    cxx_pytree._pytreespec_types = _pytreespec_types
    322+
    polyfill_class_mapping[PyTreeSpec] = _pytreespec_types
    323+
    polyfill_class_mapping[cxx_pytree.PyTreeSpec] = _pytreespec_types
    324+
    _is_pytreespec_instance = cxx_pytree._is_pytreespec_instance
    321325

    322-
    def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
    323-
    return isinstance(obj, PyTreeSpec)
    326+
    _LEAF_SPEC = PyTreeSpec((), None, None, (), None)
    324327

    325328
    @substitute_in_graph( # type: ignore[arg-type]
    326329
    cxx_pytree.tree_flatten,

    torch/_dynamo/variables/builtin.py

    Lines changed: 17 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -119,6 +119,13 @@
    119119
    operator.gt: polyfills.cmp_gt,
    120120
    operator.ge: polyfills.cmp_ge,
    121121
    }
    122+
    # A mapping from polyfilled class and original class to both classes for `isinstance` check
    123+
    # Insert two key-value pairs for each polyfilled class and original class
    124+
    # {
    125+
    # polyfilled_class: (polyfilled_class, original_class),
    126+
    # original_class: (polyfilled_class, original_class),
    127+
    # }
    128+
    polyfill_class_mapping: dict[type, tuple[type, ...]] = {}
    122129

    123130

    124131
    class BuiltinVariable(VariableTracker):
    @@ -1691,6 +1698,16 @@ def check_type(ty):
    16911698
    ],
    16921699
    )
    16931700

    1701+
    if any(tp in polyfill_class_mapping for tp in isinstance_type_tuple):
    1702+
    isinstance_type_tuple = tuple(
    1703+
    dict.fromkeys(
    1704+
    itertools.chain.from_iterable(
    1705+
    polyfill_class_mapping.get(tp, (tp,))
    1706+
    for tp in isinstance_type_tuple
    1707+
    )
    1708+
    )
    1709+
    )
    1710+
    16941711
    try:
    16951712
    val = issubclass(arg_type, isinstance_type_tuple)
    16961713
    except TypeError:

    torch/utils/_cxx_pytree.py

    Lines changed: 9 additions & 2 deletions
    Original file line numberDiff line numberDiff line change
    @@ -20,7 +20,10 @@
    2020
    from typing_extensions import deprecated, TypeIs
    2121

    2222
    import optree
    23-
    from optree import PyTreeSpec as TreeSpec # direct import for type annotations
    23+
    from optree import ( # direct import for type annotations
    24+
    PyTreeSpec,
    25+
    PyTreeSpec as TreeSpec,
    26+
    )
    2427

    2528
    import torch.utils._pytree as python_pytree
    2629
    from torch.utils._pytree import KeyEntry as KeyEntry
    @@ -230,8 +233,12 @@ def _private_register_pytree_node(
    230233
    )
    231234

    232235

    236+
    # Will be updated in torch._dynamo.polyfilles.pytree
    237+
    _pytreespec_types: tuple[type, ...] = (PyTreeSpec,)
    238+
    239+
    233240
    def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
    234-
    return isinstance(obj, TreeSpec)
    241+
    return isinstance(obj, _pytreespec_types)
    235242

    236243

    237244
    def tree_is_leaf(

    0 commit comments

    Comments
     (0)
    0