[dynamo][user-defined] Use class.__new__ instead of special casing by anijain2305 · Pull Request #146677 · pytorch/pytorch · GitHub
[dynamo][user-defined] Use class.__new__ instead of special casing #146677

Closed · wants to merge 10 commits
3 changes: 3 additions & 0 deletions test/dynamo/test_dicts.py
@@ -643,6 +643,9 @@ def test_dict_subclass_initialization_in_graph(self):
):

class CustomDict(super_class):
def __new__(self, *args, **kwargs):
return super().__new__(self, *args, **kwargs)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

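For context, the pattern this test now covers can be reproduced standalone along the following lines (a sketch only; the actual test parametrizes the base class via super_class and differs in detail):

import torch

class CustomDict(dict):
    def __new__(cls, *args, **kwargs):
        # With this change, Dynamo traces this call down to dict.__new__
        # instead of special-casing dict construction in side_effects.
        return super().__new__(cls, *args, **kwargs)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

@torch.compile(backend="eager")
def fn(x):
    d = CustomDict()
    d["out"] = x * 2
    return d["out"] + 1

print(fn(torch.ones(3)))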
144 changes: 101 additions & 43 deletions torch/_dynamo/side_effects.py
@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import collections
import contextlib
import functools
import inspect
import warnings
import weakref
@@ -21,7 +20,7 @@
from .codegen import PyCodegen
from .exc import SideEffectsError, unimplemented
from .source import GlobalSource, LocalCellSource, LocalSource, Source
from .utils import dict_new, is_frozen_dataclass, nn_module_new, object_new, tuple_new
from .utils import is_frozen_dataclass, nn_module_new, object_new
from .variables.base import (
AttributeMutation,
AttributeMutationExisting,
@@ -282,20 +281,8 @@ def track_object_new(
if user_cls is torch.autograd.function.FunctionCtx:
with warnings.catch_warnings(record=True):
obj = torch.autograd.Function()
elif issubclass(user_cls, torch.nn.Module):
obj = nn_module_new(user_cls)
elif issubclass(user_cls, (dict, collections.OrderedDict)):
obj = dict_new(user_cls)
elif issubclass(user_cls, tuple):
obj = tuple_new(user_cls)
else:
try:
obj = object_new(user_cls)
except TypeError:
# TODO(anijain2305/jansel) - Even though object.__new__ is same
# as user_cls.__new__, calling object.__new__(user_cls) fails
# with TypeError.
unimplemented(f"Unable to construct the object of type {user_cls}")
obj = object_new(user_cls)
variable = variable_cls(
obj,
mutation_type=AttributeMutationNew(cls_source),
@@ -305,18 +292,27 @@ def track_object_new_from_user_defined_class(
self.keepalive.append(obj)
return variable

def track_object_new_from_user_defined_class(
self,
cls_variable: "variables.UserDefinedClassVariable",
):
cls_source = cls_variable.source
user_cls = cls_variable.value
def get_variable_cls(self, user_cls):
from torch.overrides import TorchFunctionMode

from .variables.ctx_manager import GenericContextWrappingVariable
from .variables.torch_function import TorchFunctionModeVariable
from .variables.user_defined import is_forbidden_context_manager

# Find the variable class
variable_cls: type[
variables.UserDefinedObjectVariable
] = variables.UserDefinedObjectVariable
if issubclass(user_cls, torch.nn.Module):
if issubclass(
user_cls, TorchFunctionMode
) and TorchFunctionModeVariable.is_supported_torch_function_mode(user_cls):
variable_cls = TorchFunctionModeVariable
elif (
hasattr(user_cls, "__enter__")
and hasattr(user_cls, "__exit__")
and not is_forbidden_context_manager(user_cls)
):
variable_cls = GenericContextWrappingVariable
elif issubclass(user_cls, torch.nn.Module):
variable_cls = variables.UnspecializedNNModuleVariable
elif issubclass(user_cls, (dict, collections.OrderedDict)):
variable_cls = variables.UserDefinedDictVariable
@@ -326,14 +322,69 @@ def track_object_new_from_user_defined_class(
variable_cls = variables.MutableMappingVariable
elif is_frozen_dataclass(user_cls):
variable_cls = FrozenDataClassVariable
assert issubclass(variable_cls, variables.UserDefinedObjectVariable)
return variable_cls

def get_example_value(
self,
base_cls_vt,
cls_vt,
init_args,
):
user_cls = cls_vt.value
if issubclass(user_cls, torch.nn.Module):
# TODO(anijain2305) - Is it possible to remove this specialization?
obj = nn_module_new(user_cls)
else:
variable_cls = variables.UserDefinedObjectVariable
if isinstance(base_cls_vt, variables.BuiltinVariable):
base_cls = base_cls_vt.fn
elif isinstance(base_cls_vt, variables.UserDefinedClassVariable):
base_cls = base_cls_vt.value
else:
raise RuntimeError(f"Unexpected base_cls_vt {base_cls_vt}")

assert issubclass(variable_cls, variables.UserDefinedObjectVariable)
assert variables.UserDefinedClassVariable.is_supported_new_method(
base_cls.__new__
)
# TODO(anijain2305) - Consider adding get_example_value method to
# each VT to get an example value for all args. As we expand the
# scope to other __new__ methods, we might need to call __new__ with
# init_args (like functools.partial)
# init_args = [arg.get_example_value() for arg in init_args]
# obj = base_cls.__new__(user_cls, *init_args)

obj = base_cls.__new__(user_cls)
return obj

variable_cls = functools.partial(variable_cls, cls_source=cls_source)
def track_new_user_defined_object(
self,
base_cls_vt,
cls_vt,
init_args,
):
"""
Creates a UserDefinedObjectVariable (or one of its subclasses) as the
variable tracker and marks it for attribute mutation tracking.

Also records the variable trackers needed to call the __new__ method on
reconstruction. Roughly, the reconstruction looks like this:
base_cls_vt.__new__(user_cls, *init_args)
"""
cls_source = cls_vt.source
user_cls = cls_vt.value
variable_cls = self.get_variable_cls(user_cls)
obj = self.get_example_value(base_cls_vt, cls_vt, init_args)

return self.track_object_new(cls_source, user_cls, variable_cls, {})
variable = variable_cls(
obj,
cls_source=cls_vt.source,
base_cls_vt=base_cls_vt,
init_args=init_args,
mutation_type=AttributeMutationNew(cls_source),
)
self.id_to_variable[id(obj)] = variable
self.keepalive.append(obj)
return variable

def track_cell_new(
self,
@@ -456,13 +507,6 @@ def mutation(self, var):
def _get_modified_vars(self):
return [var for var in self.id_to_variable.values() if self.is_modified(var)]

def get_new_function(self, var):
if isinstance(var, variables.UserDefinedDictVariable):
return "dict_new"
elif isinstance(var, variables.UserDefinedTupleVariable):
return "tuple_new"
return "object_new"

def codegen_save_tempvars(self, cg: PyCodegen):
# Make sure we codegen these modified VT to their source by default, so
# that mutation and aliasing are properly accounted for.
@@ -486,17 +530,31 @@ def codegen_save_tempvars(self, cg: PyCodegen):
elif isinstance(var.mutation_type, AttributeMutationNew):
if isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented("AutogradFunctionContextVariable escaped")
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, self.get_new_function(var)

# Reconstruct the bytecode for
# base_cls.__new__(user_cls, *args)

if isinstance(var, variables.UserDefinedObjectVariable):

def load_new_method():
assert var.base_cls_vt is not None
cg(var.base_cls_vt) # type: ignore[attr-defined]
cg.extend_output([cg.create_load_attr("__new__")])

cg.add_push_null(load_new_method)
else:
cg.add_push_null(
lambda: cg.load_import_from(utils.__name__, "object_new")
)
)
cg(var.mutation_type.cls_source)
if isinstance(var, variables.UserDefinedTupleVariable) and var.new_args:
cg(var.new_args)
cg.extend_output(create_call_function(2, False))
else:
cg.extend_output(create_call_function(1, False))

# Generate the args to the __new__ method
for arg in var.init_args:
cg(arg)

# Call the __new__ method
cg.extend_output(create_call_function(1 + len(var.init_args), False))

cg.add_cache(var)
var.source = LocalSource(cg.tempvars[var])
else:
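As a rough illustration of the reconstruction described in track_new_user_defined_object above, the generated epilogue is conceptually equivalent to the following plain Python (a sketch with placeholder names, not the actual emitted bytecode):

class Config:  # stand-in for an arbitrary user-defined class
    pass

# codegen_save_tempvars loads base_cls.__new__, pushes the user class and the
# captured init args, and calls it, i.e. base_cls_vt.__new__(user_cls, *init_args):
obj = object.__new__(Config)
# Attribute mutations recorded during tracing are then replayed onto obj.
obj.value = 42
assert type(obj) is Config and obj.value == 42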
83 changes: 40 additions & 43 deletions torch/_dynamo/variables/builtin.py
@@ -1093,42 +1093,30 @@ def call_method(
and name_var.is_python_constant()
):
return obj.method_setattr_standard(tx, name_var, val)
if self.fn is object and name == "__new__":
assert len(args) == 1
assert len(kwargs) == 0
return tx.output.side_effects.track_object_new_from_user_defined_class(
args[0]
)
if self.fn is object and name == "__init__":
# object.__init__ is a no-op
return variables.ConstantVariable(None)
if self.fn is dict and name == "__new__":
assert len(args) == 1
assert len(kwargs) == 0
dict_vt = ConstDictVariable({}, dict, mutation_type=ValueMutationNew())
if isinstance(args[0], BuiltinVariable) and args[0].fn is dict:
return dict_vt
# We don't have to set the underlying dict_vt in
# UserDefinedDictVariable because it will be set to empty
# ConstDictVariableTracker in the constructor.
return tx.output.side_effects.track_object_new_from_user_defined_class(
args[0]
)
if self.fn is dict and name == "fromkeys":
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)

if self.fn is dict:
resolved_fn = getattr(self.fn, name)
if resolved_fn in dict_methods:
if isinstance(args[0], variables.UserDefinedDictVariable):
return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.ConstDictVariable):
return args[0].call_method(tx, name, args[1:], kwargs)
if name == "__new__":
# Supported __new__ methods
if self.fn is object and len(args) == 1:
assert len(kwargs) == 0
return tx.output.side_effects.track_new_user_defined_object(
self, args[0], args[1:]
)

if self.fn is dict and len(args) == 1 and not kwargs:
dict_vt = ConstDictVariable({}, dict, mutation_type=ValueMutationNew())
if isinstance(args[0], BuiltinVariable) and args[0].fn is dict:
return dict_vt
# We don't have to set the underlying dict_vt in
# UserDefinedDictVariable because it will be set to empty
# ConstDictVariableTracker in the constructor.
return tx.output.side_effects.track_new_user_defined_object(
self,
args[0],
args[1:],
)

if self.fn is tuple:
resolved_fn = getattr(self.fn, name)
if (
resolved_fn is tuple.__new__
self.fn is tuple
and len(args) == 2
and args[1].has_unpack_var_sequence(tx)
and not kwargs
@@ -1140,20 +1128,29 @@ def call_method(
if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple:
return tuple_vt

result = (
tx.output.side_effects.track_object_new_from_user_defined_class(
args[0]
)
result = tx.output.side_effects.track_new_user_defined_object(
self,
args[0],
args[1:],
)
# The side_effects data structure does not support methods like
# tuple.__new__ that use *args parameters for the __new__
# method. Therefore, we manage the *args-related functionality
# here. For other data structures, this is done in the __init__
# method.
result.set_new_args(args[1])
result.set_underlying_tuple_vt(tuple_vt)
return result

if self.fn is object and name == "__init__":
# object.__init__ is a no-op
return variables.ConstantVariable(None)

if self.fn is dict and name == "fromkeys":
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)

if self.fn is dict:
resolved_fn = getattr(self.fn, name)
if resolved_fn in dict_methods:
if isinstance(args[0], variables.UserDefinedDictVariable):
return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.ConstDictVariable):
return args[0].call_method(tx, name, args[1:], kwargs)

return super().call_method(tx, name, args, kwargs)

def _call_int_float(self, tx: "InstructionTranslator", arg):
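For reference, the cases distinguished in the __new__ branches above correspond to the following plain-Python behavior (a sketch; CustomDict and TupleSubclass are placeholder names, not from the PR):

class CustomDict(dict):
    pass

class TupleSubclass(tuple):
    pass

# dict.__new__(dict) is modeled directly as a fresh ConstDictVariable.
plain = dict.__new__(dict)
# A dict subclass instead goes through side_effects.track_new_user_defined_object.
custom = dict.__new__(CustomDict)
# tuple.__new__(subclass, iterable) is also tracked as a user-defined object,
# with the iterable recorded via set_new_args so it can be replayed on
# reconstruction.
sub = tuple.__new__(TupleSubclass, [1, 2])
assert type(plain) is dict and type(custom) is CustomDict
assert type(sub) is TupleSubclass and sub == (1, 2)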
19 changes: 15 additions & 4 deletions torch/_dynamo/variables/misc.py
@@ -157,10 +157,21 @@ def call_method(
).call_function(tx, [self.objvar] + args, kwargs)
else:
unimplemented("super() nn.Module.__init__")
elif self.objvar.source and inner_fn is object.__new__:
return tx.output.side_effects.track_object_new_from_user_defined_class(
self.objvar
)
elif (
self.objvar.source
and hasattr(inner_fn, "__name__")
and inner_fn.__name__ == "__new__"
and variables.UserDefinedClassVariable.is_supported_new_method(inner_fn)
):
user_cls = inner_fn.__self__
if hasattr(user_cls, "__module__") and user_cls.__module__ == "builtins":
user_cls_vt = variables.BuiltinVariable(user_cls)
else:
user_cls_source = source.member
user_cls_vt = variables.UserDefinedClassVariable(
user_cls, source=user_cls_source
)
return user_cls_vt.call_method(tx, "__new__", args, kwargs)
elif isinstance(inner_fn, staticmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
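The super() path changed above handles patterns like the one below, where super().__new__ resolves to a builtin class's __new__ (a plain-Python sketch, not taken from the PR):

class MyTuple(tuple):
    def __new__(cls, data):
        # Inside Dynamo, inner_fn here is tuple.__new__, whose __self__ is the
        # builtin class tuple, so it is wrapped in a BuiltinVariable and the
        # "__new__" call is re-dispatched on that variable.
        return super().__new__(cls, data)

t = MyTuple([1, 2])
assert type(t) is MyTuple and t == (1, 2)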