8000 [dynamo][not ready] polyfill infra for classes by anijain2305 · Pull Request #146678 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[dynamo][not ready] polyfill infra for classes #146678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions torch/_dynamo/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,33 @@ def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable:
return wrapper


def substitute_class(original_class, supports_reconstruction=True):
    """
    Register a polyfill handler for a class, usually a C++ class from the C extension, to be
    used in place of the original class when inlining the original class in the graph.

    Args:
        original_class: the class (typically implemented in C/C++) to be
            substituted while tracing.
        supports_reconstruction: when True, the traceable class must also
            provide a ``convert_to_original`` staticmethod so instances can be
            converted back to the original class at graph reconstruction time.

    Returns:
        A class decorator that registers its argument as the traceable
        substitute for ``original_class`` and returns it unchanged.

    .. note::

        The polyfill handler is only used when inlining the original class. It is not used when
        the original class is called directly. In the eager mode, the decorated class calls
        the performant C++ class rather than the polyfill handler.
    """

    def inner(traceable_class):
        # The traceable class must know how to build itself from an instance
        # of the original class.
        assert hasattr(traceable_class, "convert_to_traceable")
        if supports_reconstruction:
            # Reconstruction additionally requires the inverse conversion.
            assert hasattr(traceable_class, "convert_to_original")
        # Unique global name under which the conversion helper is installed
        # into the output graph's globals.
        traceable_class.__global_name__ = (
            f"___{traceable_class.__module__}_{traceable_class.__name__}___"
        )

        # Imported lazily to avoid a circular import with trace_rules.
        from torch._dynamo.trace_rules import _polyfilled_class_mapping

        _polyfilled_class_mapping[original_class] = traceable_class
        return traceable_class

    return inner


# Helper function to flatten a tensor subclass and apply a function to
# all inner tensors that match the outer dim. Used to reduce duplication
# across the various marking APIs.
Expand Down
45 changes: 44 additions & 1 deletion torch/_dynamo/polyfills/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Iterable
from typing import Callable, TypeVar

from ..decorators import substitute_in_graph
from ..decorators import substitute_class, substitute_in_graph


__all__ = ["reduce"]
Expand Down Expand Up @@ -45,3 +45,46 @@
value = function(value, element)

return value


@substitute_class(functools.partial, supports_reconstruction=True)
class partial:
Comment on lines +50 to +51
Copy link
Collaborator
@XuehaiPan XuehaiPan Feb 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we use closure to implement class-like objects?

@substitute_in_graph(functools.partial, is_embedded_type=True)
def partial(func, /, *args, **keywords):
    def newfunc(*fargs, **fkeywords):
        newkeywords = {**keywords, **fkeywords}
        return func(*args, *fargs, **newkeywords)

    newfunc.func = func
    newfunc.args = args
    newfunc.keywords = keywords
    return newfunc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case of partial, it could be OK. But I am thinking in terms of more complicated C classes that have many methods. We could still use a closure and add method names as attributes to the returned function object, but I think it's more intuitive for the user to have a one-to-one mapping from their C++ class to a Python class.

Another discussion point (which is not a class vs closure discussion) is how to reconstruct the value. For example, in your above example, if we want to reconstruct the partial object, it will be of type function (and not functools.partial). This causes silent correctness issues because the following code could depend on the returned obj being of type functools.partial. This is not a class vs closure issue though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The C++ pytree functions can return a C++ object PyTreeSpec. The pytree polyfills return a Python object with analog methods while setting can_constant_fold_through=False.

See also:

@dataclass(frozen=True)
class PyTreeSpec:
"""Analog for :class:`optree.PyTreeSpec` in Python."""

def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_flatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_flatten(
tree: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[Any], PyTreeSpec]:

Copy link
Collaborator
@XuehaiPan XuehaiPan Feb 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For functools.partial:

class partial:  # copy the pure-Python version from stdlib
    ...


@substitute_in_graph(
    functools.partial.__new__, 
    # We need to disable constant folding here because we want the function to reference the 
    # `partial` class defined above, not the one in the C++ module. 
    can_constant_fold_through=False, 
) 
def partial_new(*args, **kwargs):
    return partial(*args, **kwargs)

def is_partial_instance(obj):
    return isinstance(obj, partial)  # the `partial` class defined above

One drawback for polyfills is the isinstance(...) and issubclass(...) checks for polyfilled instances. Neither function-polyfill-with-closure nor class-polyfill can resolve this issue elegantly. We also need to provide polyfilled is_partial_instance functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are really good pointers. I think it makes sense to just polyfill functools.partial.__new__ and provide a helper to convert the already created functools.partial objects to polyfill'd partial objects.

Let me think more about the reconstruction part.

"""New function with partial application of the given arguments
and keywords.
"""

__slots__ = "func", "args", "keywords", "__dict__", "__weakref__"

def __new__(cls, func, /, *args, **keywords):

Check failure on line 58 in torch/_dynamo/polyfills/functools.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [no-untyped-def]

Function is missing a type annotation
if not callable(func):
raise TypeError("the first argument must be callable")

if isinstance(func, partial):
args = func.args + args

Check failure on line 63 in torch/_dynamo/polyfills/functools.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [attr-defined]

"partial" has no attribute "args"
keywords = {**func.keywords, **keywords}

Check failure on line 64 in torch/_dynamo/polyfills/functools.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [attr-defined]

"partial" has no attribute "keywords"
func = func.func

Check failure on line 65 in torch/_dynamo/polyfills/functools.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [attr-defined]

"partial" has no attribute "func"

self = super(partial, cls).__new__(cls)

self.func = func

Check failure on line 69 in torch/_dynamo/polyfills/functools.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [attr-defined]

"partial" has no attribute "func"
self.args = args

Check failure on line 70 in torch/_dynamo/polyfills/functools.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [attr-defined]

"partial" has no attribute "args"
self.keywords = keywords

Check failure on line 71 in torch/_dynamo/polyfills/functools.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [attr-defined]

"partial" has no attribute "keywords"
return self

def __call__(self, /, *args, **keywords):

Check failure on line 74 in torch/_dynamo/polyfills/functools.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [no-untyped-def]

Function is missing a type annotation
keywords = {**self.keywords, **keywords}

Check failure on line 75 in torch/_dynamo/polyfills/functools.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [attr-defined]

"partial" has no attribute "keywords"
return self.func(*self.args, *args, **keywords)

@staticmethod
def convert_to_traceable(original_value):
    """Build a polyfilled ``partial`` from a real ``functools.partial``."""
    assert isinstance(original_value, functools.partial)
    func = original_value.func
    pos_args = original_value.args
    kw_args = original_value.keywords
    return partial(func, *pos_args, **kw_args)

@staticmethod
def convert_to_original(value):
    """Rebuild a real ``functools.partial`` from a polyfilled ``partial``."""
    assert isinstance(value, partial)
    func = value.func
    return functools.partial(func, *value.args, **value.keywords)
3 changes: 3 additions & 0 deletions torch/_dynamo/trace_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3025,6 +3025,9 @@
return set()


_polyfilled_class_mapping = {}

Check failure on line 3028 in torch/_dynamo/trace_rules.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [var-annotated]

Need type annotation for "_polyfilled_class_mapping" (hint: "_polyfilled_class_mapping: dict[<type>, <type>] = ...")


@FunctionIdSet
def _numpy_function_ids() -> dict[int, str]:
unsupported_funcs = {
Expand Down
1 change: 1 addition & 0 deletions torch/_dynamo/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .user_defined import (
MutableMappingVariable,
PolyFilledUserDefinedClassVariable,
RemovableHandleVariable,
UserDefinedClassVariable,
UserDefinedDictVariable,
Expand Down
67 changes: 40 additions & 27 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@
CollectionsNamedTupleFunction,
CollectiveFunctionRewriteVariable,
CreateTMADescriptorVariable,
FunctoolsPartialVariable,
FunctoolsWrapsVariable,
TritonKernelVariable,
UserFunctionVariable,
Expand Down Expand Up @@ -222,6 +221,7 @@
FrozenDataClassVariable,
KeyedJaggedTensorVariable,
MutableMappingVariable,
PolyFilledUserDefinedClassVariable,
SourcelessGraphModuleVariable,
UserDefinedClassVariable,
UserDefinedDictVariable,
Expand Down Expand Up @@ -670,32 +670,36 @@ def build_key_value(i, k, v):
return build_checkpoint_variable(source=self.source)
elif is_invoke_subgraph(value):
return build_invoke_subgraph_variable(source=self.source)
elif isinstance(value, functools.partial):
func_src = AttrSource(self.get_source(), "func")
func_obj = VariableBuilder(self.tx, func_src)(value.func)

args = []
args_source = AttrSource(self.get_source(), "args")
for i, arg in enumerate(value.args):
args.append(
VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
)

keywords = {}
keywords_source = AttrSource(self.get_source(), "keywords")
for k, v in value.keywords.items():
if not ConstantVariable.is_literal(k):
unimplemented("functools.partial with non-literal keyword")
keywords[k] = VariableBuilder(
self.tx, DictGetItemSource(keywords_source, k)
)(v)

install_guard(
self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH),
args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
)
return FunctoolsPartialVariable(func_obj, args, keywords)
# elif isinstance(value, functools.partial):
# self.install_guards(GuardBuilder.TYPE_MATCH)
# new_value = polyfills.functools.partial(value.func, value.args, value.keywords)
# return UserDefinedObjectVariable(new_value, source=self.source)
# # func_src = AttrSource(self.get_source(), "func")
# # func_obj = VariableBuilder(self.tx, func_src)(value.func)

# # args = []
# # args_source = AttrSource(self.get_source(), "args")
# # for i, arg in enumerate(value.args):
# # args.append(
# # VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
# # )

# # keywords = {}
# # keywords_source = AttrSource(self.get_source(), "keywords")
# # for k, v in value.keywords.items():
# # if not ConstantVariable.is_literal(k):
# # unimplemented("functools.partial with non-literal keyword")
# # keywords[k] = VariableBuilder(
# # self.tx, DictGetItemSource(keywords_source, k)
# # )(v)

# # install_guard(
# # self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
# # keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH),
# # args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
# # )
# # breakpoint()
# # return FunctoolsPartialVariable(func_obj, args, keywords)
elif is_typing(value):
# typing.List, typing.Mapping, etc.
self.install_guards(GuardBuilder.ID_MATCH)
Expand Down Expand Up @@ -1097,6 +1101,15 @@ def build_key_value(i, k, v):
# unlikely to change, so its ok to skip the guard here.
return MethodWrapperVariable(value)
elif issubclass(type(value), type):
if trace_class := trace_rules._polyfilled_class_mapping.get(value):
return PolyFilledUserDefinedClassVariable.create(
tx=self.tx,
orig_class=value,
orig_source=self.source,
trace_class=trace_class
)


if value in (
torch.utils.hooks.BackwardHook,
torch.nn.Parameter,
Expand Down
7 changes: 5 additions & 2 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,15 +1603,15 @@ def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
def call_getitem(self, tx: "InstructionTranslator", *args, **kwargs):
return args[0].call_method(tx, "__getitem__", args[1:], kwargs)

def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type):
def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type_vt):
try:
arg_type = arg.python_type()
except NotImplementedError:
unimplemented(
f"isinstance({arg}, {isinstance_type}): can't determine type of {arg}"
)

isinstance_type = isinstance_type.as_python_constant()
isinstance_type = isinstance_type_vt.as_python_constant()

if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:

Expand Down Expand Up @@ -1652,6 +1652,9 @@ def check_type(ty):
# handle __instancecheck__ defined in user class
if (
isinstance(arg, variables.UserDefinedObjectVariable)
and not isinstance(
isinstance_type_vt, variables.PolyFilledUserDefinedClassVariable
)
and "__instancecheck__" in isinstance_type.__class__.__dict__
):
return variables.ConstantVariable.create(
Expand Down
138 changes: 127 additions & 11 deletions torch/_dynamo/variables/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import types
import warnings
import weakref
from typing import TYPE_CHECKING
from typing import Generic, TYPE_CHECKING, Callable, Any
from typing_extensions import is_typeddict

import torch._dynamo.config
Expand Down Expand Up @@ -440,16 +440,24 @@ def call_function(
elif self.value is weakref.ref:
return variables.WeakRefVariable(args[0])
elif self.value is functools.partial:
if not args:
unimplemented("functools.partial malformed")
# The first arg, a callable (the ctor below will assert on types)
fn = args[0]
rest_args = args[1:]
# guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the
# args and keywords
return variables.functions.FunctoolsPartialVariable(
fn, args=rest_args, keywords=kwargs
new_cls_vt = variables.UserDefinedClassVariable(polyfills.functools.partial)
var = tx.output.side_effects.track_object_new_from_user_defined_class(
new_cls_vt
)
var.call_method(tx, "__init__", args, kwargs)
return var
# new_value = functools.partial(identity)
# return UserDefinedObjectVariable(new_value, )
# if not args:
# unimplemented("functools.partial malformed")
# # The first arg, a callable (the ctor below will assert on types)
# fn = args[0]
# rest_args = args[1:]
# # guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the
# # args and keywords
# return variables.functions.FunctoolsPartialVariable(
# fn, args=rest_args, keywords=kwargs
# )
elif self.value is warnings.catch_warnings and not args:
return variables.CatchWarningsCtxManagerVariable.create(tx, kwargs)
elif self.value is torch.cuda.device and not kwargs and len(args) == 1:
Expand Down Expand Up @@ -724,12 +732,16 @@ def __init__(
) -> None:
super().__init__(**kwargs)
self.value = value
if is_polyfilled:
assert value_type is not None, "polyfill must provide the original type"
self.value_type = value_type or type(value)
assert type(value) is self.value_type
if not is_polyfilled:
assert type(value) is self.value_type
# This is used with __new__, when the new object is sourceless but the user class can be sourceful.
self.cls_source = cls_source
if cls_source is None and self.source is not None:
self.cls_source = TypeSource(self.source)
self.is_polyfilled = is_polyfilled

# These attributes are used to reconstruct the user defined object. The
# pseudo code looks like this. Builtin C __new__ do not support kwargs,
Expand Down Expand Up @@ -1576,3 +1588,107 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke

class RandomVariable(UserDefinedObjectVariable):
    # Marker subclass: behaves exactly like UserDefinedObjectVariable but has
    # a distinct variable type so handlers can special-case it (presumably for
    # random.Random instances -- TODO confirm against callers).
    pass


class PolyFilledUserDefinedClassVariable(VariableTracker):
    """Tracks a class that has a registered polyfill (see
    ``torch._dynamo.decorators.substitute_class``).

    Tracing (construction, attribute access, method calls) is delegated to the
    traceable pure-Python class, while ``as_python_constant``/``as_proxy``
    still report the original class so guards and isinstance checks observe
    the original type.
    """

    @staticmethod
    def create(tx, orig_class, orig_source, trace_class):
        # Build a sourceful VT for the traceable class so attribute access and
        # guards on it resolve through a real import source.
        trace_source = AttrSource(
            tx.import_source(trace_class.__module__), trace_class.__name__
        )
        trace_vt = UserDefinedClassVariable(trace_class, source=trace_source)

        return PolyFilledUserDefinedClassVariable(
            orig_class, trace_class, trace_vt, source=orig_source
        )

    def __init__(self, original_class, traceable_class, traceable_class_vt, **kwargs) -> None:
        # BUG FIX: the base-class constructor was never invoked, so kwargs
        # such as ``source`` (passed by ``create`` above) were silently
        # dropped instead of being recorded on the VariableTracker.
        super().__init__(**kwargs)
        self.original_class = original_class
        self.traceable_class = traceable_class
        self.traceable_class_vt = traceable_class_vt

    def as_python_constant(self):
        # Intentionally the original class, not the polyfill: downstream code
        # (e.g. isinstance folding) must see the real type.
        return self.original_class

    def as_proxy(self):
        return self.original_class

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Instantiate via the traceable class, then wrap the result so that
        # reconstruction can convert it back to the original class.
        obj = self.traceable_class_vt.call_function(tx, args, kwargs)
        assert isinstance(obj, UserDefinedObjectVariable)
        # Install convert_to_original into the output graph's globals under
        # the polyfill's unique __global_name__; reconstruct() loads it.
        global_name = self.traceable_class.__global_name__
        installed_global_name = tx.output.install_global_by_id(
            global_name, self.traceable_class.convert_to_original
        )
        return PolyFilledUserDefinedObjectVariable(
            obj,
            self.original_class,
            self.traceable_class,
            installed_global_name,
            mutation_type=obj.mutation_type,
        )

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        # Attribute lookup is answered by the traceable class.
        return self.traceable_class_vt.var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Method calls are answered by the traceable class.
        return self.traceable_class_vt.call_method(tx, name, args, kwargs)


class PolyFilledUserDefinedObjectVariable(VariableTracker):
    """Tracks an instance of a traceable (polyfill) class while reporting the
    original class as its Python type.

    ``udf_vt`` is the underlying UserDefinedObjectVariable for the traceable
    instance; most behavior is forwarded to it via the module-level
    ``_populate`` machinery in this file.
    """

    def __init__(self, udf_vt, original_class, traceable_class, installed_global_name, **kwargs) -> None:
        super().__init__(**kwargs)
        # VT tracking the instance of the traceable (polyfill) class.
        self.udf_vt = udf_vt
        # The class the user actually referenced (usually C/C++-implemented).
        self.original_class = original_class
        # The pure-Python substitute registered via @substitute_class.
        self.traceable_class = traceable_class
        # Global name under which convert_to_original was installed in the
        # output graph; loaded by reconstruct() below.
        self.installed_global_name = installed_global_name

    def reconstruct(self, codegen):
        # Reconstruction requires that the traceable instance has already been
        # codegen'd into a temporary; otherwise we cannot reference it.
        if self.udf_vt not in codegen.tempvars:
            unimplemented("Incorrect reconstruction for polyfilled object")

        # We have the tempvar for the instance of traceable class. For
        # reconstructing to the original class, call traceable_class
        # convert_to_original method.

        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_global(self.installed_global_name, add=True),
                    codegen.create_load(codegen.tempvars[self.udf_vt]),
                ]
            )
        )
        # Emit the call: convert_to_original(traceable_instance).
        codegen.extend_output(create_call_function(1, False))


    def python_type(self):
        # NB - This is intentional. For tracing purpose, we want to ensure that
        # the class is considered original class. If not, we will have wrong
        # conditionals on isinstance(value, class_type)
        return self.original_class


def _forward_to_udf_vt(
    name: str,
) -> Callable[[PolyFilledUserDefinedObjectVariable, Any, Any], Any]:
    """Build a method that delegates ``name`` to the wrapped ``udf_vt``."""
    base_method = getattr(UserDefinedObjectVariable, name)

    @functools.wraps(base_method)
    def _delegate(
        self: PolyFilledUserDefinedObjectVariable, *args: Any, **kwargs: Any
    ) -> Any:
        # Look up the bound attribute on the wrapped VT at call time.
        target = getattr(self.udf_vt, name)
        return target(*args, **kwargs)

    return _delegate


def _populate() -> None:
    """Install forwarding methods on PolyFilledUserDefinedObjectVariable for
    every callable attribute of UserDefinedObjectVariable that the polyfilled
    variant does not define itself."""
    existing = PolyFilledUserDefinedObjectVariable.__dict__
    for attr_name, attr_value in UserDefinedObjectVariable.__dict__.items():
        if attr_name in existing or not callable(attr_value):
            continue
        setattr(
            PolyFilledUserDefinedObjectVariable,
            attr_name,
            _forward_to_udf_vt(attr_name),
        )


_populate()
Loading
0