8000 [Dynamo][pytree] handle `isinstance(...)` check for polyfilled class by XuehaiPan · Pull Request #146921 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[Dynamo][pytree] handle isinstance(...) check for polyfilled class #146921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions test/dynamo/test_misc.py
8000
Original file line number Diff line number Diff line change
Expand Up @@ -10250,6 +10250,38 @@ def fn(x, y):

self.assertEqual(actual, expected)

@unittest.skipIf(cxx_pytree is None, "Test for C++ pytree polyfill infra")
def test_pytreespec_isinstance_check(self):
from torch._dynamo.polyfills import pytree as polyfilled_cxx_pytree

@torch.compile(fullgraph=True)
def fn(x, y):
leaves, treespec = cxx_pytree.tree_flatten(x)
return leaves, treespec, y.sin()

y = torch.randn(3)
x = [1, [2, [3, 4]]]
leaves, treespec, _ = fn(x, y)
# Compiled function returns an instance of the polyfilled class instead of the original class
self.assertIsInstance(treespec, polyfilled_cxx_pytree.PyTreeSpec)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is, it is wrong for the compiled function to return an instance of the polyfilled class instead of the original class. The compiled function needs to return an instance of the original class.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We either need Animesh's polyfill infra for classes, or we need to actually make a TreeSpecVariable in Dynamo that has a reconstruct method. I'd prefer hardening the polyfill infra for classes because that is more generically applicable.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The compiled function needs to return an instance of the original class.

That is the ideal solution for polyfilling a class. We need to find a way to batch register the Python version of the polyfill methods of the C++ class. cc @anijain2305 about the class polyfill infra design.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After a second thought, I think polyfilling the methods and using a variable tracker of instance original class during inlining will cause performance issues. Also, it is not easy to polyfill C++ descriptors and support the pybind11 property and read-only property.

The compiled function needs to return an instance of the original class.

We should use the polyfilled class object while inlining the graph and find a way to convert between the original/polyfilled class instances at the graph boundaries.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use the polyfilled class object while inlining the graph and find a way to convert between the original/polyfilled class instances at the graph boundaries.

Yes

# Must not raise exceptions that allow partially compiled programs to mix polyfilled classes
# with original classes in different parts of the program
reconstructed = cxx_pytree.tree_unflatten(leaves, treespec)
self.assertEqual(x, reconstructed)

def fn(x, y):
treespec = cxx_pytree.tree_structure(x)
if isinstance(treespec, cxx_pytree.PyTreeSpec):
return y.sin()
else:
return y.cos()

expected = fn(x, y)
fn_opt = torch.compile(fullgraph=True)(fn)
actual = fn_opt(x, y)

self.assertEqual(actual, expected)

def test_shape_env_no_recording(self):
main = ShapeEnv(should_record_events=False)

Expand Down
11 changes: 7 additions & 4 deletions torch/_dynamo/polyfills/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, TYPE_CHECKING
from typing_extensions import TypeIs

import torch.utils._pytree as python_pytree
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES

from ..decorators import substitute_in_graph
from ..variables.builtin import polyfill_class_mapping


if TYPE_CHECKING:
Expand Down Expand Up @@ -317,10 +317,13 @@ def unflatten(self, leaves: Iterable[Any]) -> PyTree:
assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees)

_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
_pytreespec_types = (PyTreeSpec, cxx_pytree.PyTreeSpec)
cxx_pytree._pytreespec_types = _pytreespec_types
polyfill_class_mapping[PyTreeSpec] = _pytreespec_types
polyfill_class_mapping[cxx_pytree.PyTreeSpec] = _pytreespec_types
_is_pytreespec_instance = cxx_pytree._is_pytreespec_instance

def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)

@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_flatten,
Expand Down
17 changes: 17 additions & 0 deletions torch/_dynamo/variables/builtin.py
8000
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@
operator.gt: polyfills.cmp_gt,
operator.ge: polyfills.cmp_ge,
}
# A mapping from polyfilled class and original class to both classes for `isinstance` check
# Insert two key-value pairs for each polyfilled class and original class
# {
# polyfilled_class: (polyfilled_class, original_class),
# original_class: (polyfilled_class, original_class),
# }
polyfill_class_mapping: dict[type, tuple[type, ...]] = {}


class BuiltinVariable(VariableTracker):
Expand Down Expand Up @@ -1706,6 +1713,16 @@ def check_type(ty):
],
)

if any(tp in polyfill_class_mapping for tp in isinstance_type_tuple):
isinstance_type_tuple = tuple(
dict.fromkeys(
itertools.chain.from_iterable(
polyfill_class_mapping.get(tp, (tp,))
for tp in isinstance_type_tuple
)
)
)

try:
# NB: `isinstance()` does not call `__subclasscheck__` but use `__instancecheck__`.
# But usually `isinstance(obj, type_info)` and `issubclass(type(obj), type_info)` gives
Expand Down
11 changes: 9 additions & 2 deletions torch/utils/_cxx_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from typing_extensions import deprecated, TypeIs

import optree
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
from optree import ( # direct import for type annotations
PyTreeSpec,
PyTreeSpec as TreeSpec,
)

import torch.utils._pytree as python_pytree
from torch.utils._pytree import KeyEntry as KeyEntry
Expand Down Expand Up @@ -234,8 +237,12 @@ def _private_register_pytree_node(
)


# Will be updated in torch._dynamo.polyfilles.pytree
_pytreespec_types: tuple[type, ...] = (PyTreeSpec,)


def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
return isinstance(obj, TreeSpec)
return isinstance(obj, _pytreespec_types)


def tree_is_leaf(
Expand Down
Loading
0