8000 [Dynamo][pytree] handle `isinstance(...)` check for polyfilled class · pytorch/pytorch@8bbb190 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8bbb190

Browse files
committed
[Dynamo][pytree] handle isinstance(...) check for polyfilled class
ghstack-source-id: 0246bf1 Pull Request resolved: #146921
1 parent c740414 commit 8bbb190

File tree

4 files changed

+65
-6
lines changed

4 files changed

+65
-6
lines changed

test/dynamo/test_misc.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10250,6 +10250,38 @@ def fn(x, y):
1025010250

1025110251
self.assertEqual(actual, expected)
1025210252

10253+
@unittest.skipIf(cxx_pytree is None, "Test for C++ pytree polyfill infra")
10254+
def test_pytreespec_isinstance_check(self):
10255+
from torch._dynamo.polyfills import pytree as polyfilled_cxx_pytree
10256+
10257+
@torch.compile(fullgraph=True)
10258+
def fn(x, y):
10259+
leaves, treespec = cxx_pytree.tree_flatten(x)
10260+
return leaves, treespec, y.sin()
10261+
10262+
y = torch.randn(3)
10263+
x = [1, [2, [3, 4]]]
10264+
leaves, treespec, _ = fn(x, y)
10265+
# Compiled function returns an instance of the polyfilled class instead of the original class
10266+
self.assertIsInstance(treespec, polyfilled_cxx_pytree.PyTreeSpec)
10267+
# Must not raise exceptions that allow partially compiled programs to mix polyfilled classes
10268+
# with original classes in different parts of the program
10269+
reconstructed = cxx_pytree.tree_unflatten(leaves, treespec)
10270+
self.assertEqual(x, reconstructed)
10271+
10272+
def fn(x, y):
10273+
treespec = cxx_pytree.tree_structure(x)
10274+
if isinstance(treespec, cxx_pytree.PyTreeSpec):
10275+
return y.sin()
10276+
else:
10277+
return y.cos()
10278+
10279+
expected = fn(x, y)
10280+
fn_opt = torch.compile(fullgraph=True)(fn)
10281+
actual = fn_opt(x, y)
10282+
10283+
self.assertEqual(actual, expected)
10284+
1025310285
def test_shape_env_no_recording(self):
1025410286
main = ShapeEnv(should_record_events=False)
1025510287

torch/_dynamo/polyfills/pytree.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
from collections import deque
88
from dataclasses import dataclass, field
99
from typing import Any, Callable, Literal, TYPE_CHECKING
10-
from typing_extensions import TypeIs
1110

1211
import torch.utils._pytree as python_pytree
1312
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
1413

1514
from ..decorators import substitute_in_graph
15+
from ..variables.builtin import polyfill_class_mapping
1616

1717

1818
if TYPE_CHECKING:
@@ -317,10 +317,13 @@ def unflatten(self, leaves: Iterable[Any]) -> PyTree:
317317
assert callable(self._unflatten_func)
318318
return self._unflatten_func(self._metadata, subtrees)
319319

320-
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
320+
_pytreespec_types = (PyTreeSpec, cxx_pytree.PyTreeSpec)
321+
cxx_pytree._pytreespec_types = _pytreespec_types
322+
polyfill_class_mapping[PyTreeSpec] = _pytreespec_types
323+
polyfill_class_mapping[cxx_pytree.PyTreeSpec] = _pytreespec_types
324+
_is_pytreespec_instance = cxx_pytree._is_pytreespec_instance
321325

322-
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
323-
return isinstance(obj, PyTreeSpec)
326+
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
324327

325328
@substitute_in_graph( # type: ignore[arg-type]
326329
cxx_pytree.tree_flatten,

torch/_dynamo/variables/builtin.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,13 @@
119119
operator.gt: polyfills.cmp_gt,
120120
operator.ge: polyfills.cmp_ge,
121121
}
122+
# A mapping from polyfilled class and original class to both classes for `isinstance` check
123+
# Insert two key-value pairs for each polyfilled class and original class
124+
# {
125+
# polyfilled_class: (polyfilled_class, original_class),
126+
# original_class: (polyfilled_class, original_class),
127+
# }
128+
polyfill_class_mapping: dict[type, tuple[type, ...]] = {}
122129

123130

124131
class BuiltinVariable(VariableTracker):
@@ -1691,6 +1698,16 @@ def check_type(ty):
16911698
],
16921699
)
16931700

1701+
if any(tp in polyfill_class_mapping for tp in isinstance_type_tuple):
1702+
isinstance_type_tuple = tuple(
1703+
dict.fromkeys(
1704+
itertools.chain.from_iterable(
1705+
polyfill_class_mapping.get(tp, (tp,))
1706+
for tp in isinstance_type_tuple
1707+
)
1708+
)
1709+
)
1710+
16941711
try:
16951712
val = issubclass(arg_type, isinstance_type_tuple)
16961713
except TypeError:

torch/utils/_cxx_pytree.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
from typing_extensions import deprecated, TypeIs
2121

2222
import optree
23-
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
23+
from optree import ( # direct import for type annotations
24+
PyTreeSpec,
25+
PyTreeSpec as TreeSpec,
26+
)
2427

2528
import torch.utils._pytree as python_pytree
2629
from torch.utils._pytree import KeyEntry as KeyEntry
@@ -230,8 +233,12 @@ def _private_register_pytree_node(
230233
)
231234

232235

236+
# Will be updated in torch._dynamo.polyfilles.pytree
237+
_pytreespec_types: tuple[type, ...] = (PyTreeSpec,)
238+
239+
233240
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
234-
return isinstance(obj, TreeSpec)
241+
return isinstance(obj, _pytreespec_types)
235242

236243

237244
def tree_is_leaf(

0 commit comments

Comments
 (0)
0