8000 [dynamo][user-defined] User class.__new__ instead of special casing · pytorch/pytorch@0c6a000 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0c6a000

Browse files
committed
[dynamo][user-defined] User class.__new__ instead of special casing
ghstack-source-id: e753e96 Pull Request resolved: #146677
1 parent 1b879fd commit 0c6a000

File tree

2 files changed

+23
-22
lines changed

2 files changed

+23
-22
lines changed

torch/_dynamo/side_effects.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from .codegen import PyCodegen
2222
from .exc import unimplemented
2323
from .source import GlobalSource, LocalCellSource, LocalSource, Source
24-
from .utils import dict_new, is_frozen_dataclass, nn_module_new, object_new, tuple_new
24+
from .utils import is_frozen_dataclass, nn_module_new, object_new
2525
from .variables.base import (
2626
AttributeMutation,
2727
AttributeMutationExisting,
@@ -264,18 +264,17 @@ def track_object_new(
264264
obj = torch.autograd.Function()
265265
elif issubclass(user_cls, torch.nn.Module):
266266
obj = nn_module_new(user_cls)
267-
elif issubclass(user_cls, (dict, collections.OrderedDict)):
268-
obj = dict_new(user_cls)
269-
elif issubclass(user_cls, tuple):
270-
obj = tuple_new(user_cls)
271267
else:
272-
try:
268+
tmp_var_cls = variable_cls
269+
if isinstance(variable_cls, functools.partial):
270+
tmp_var_cls = tmp_var_cls.func
271+
272+
if issubclass(type(tmp_var_cls), type) and issubclass(
273+
tmp_var_cls, variables.UserDefinedObjectVariable
274+
):
275+
obj = user_cls.__new__(user_cls)
276+
else:
273277
obj = object_new(user_cls)
274-
except TypeError:
275-
# TODO(anijain2305/jansel) - Even though object.__new__ is same
276-
# as user_cls.__new__, calling object.__new__(user_cls) fails
277-
# with TypeError.
278-
unimplemented(f"Unable to construct the object of type {user_cls}")
279278
variable = variable_cls(
280279
obj,
281280
mutation_type=AttributeMutationNew(cls_source),
@@ -436,13 +435,6 @@ def mutation(self, var):
436435
def _get_modified_vars(self):
437436
return [var for var in self.id_to_variable.values() if self.is_modified(var)]
438437

439-
def get_new_function(self, var):
440-
if isinstance(var, variables.UserDefinedDictVariable):
441-
return "dict_new"
442-
elif isinstance(var, variables.UserDefinedTupleVariable):
443-
return "tuple_new"
444-
return "object_new"
445-
446438
def codegen_save_tempvars(self, cg: PyCodegen):
447439
# Make sure we codegen these modified VT to their source by default, so
448440
# that mutation and aliasing are properly accounted for.
@@ -466,11 +458,17 @@ def codegen_save_tempvars(self, cg: PyCodegen):
466458
elif isinstance(var.mutation_type, AttributeMutationNew):
467459
if isinstance(var, variables.AutogradFunctionContextVariable):
468460
unimplemented("AutogradFunctionContextVariable escaped")
469-
cg.add_push_null(
470-
lambda: cg.load_import_from(
471-
utils.__name__, self.get_new_function(var)
461+
if isinstance(var, variables.UserDefinedObjectVariable):
462+
463+
def gen_fn():
464+
cg(var.mutation_type.cls_source) # type: ignore[attr-defined]
465+
cg.extend_output([cg.create_load_attr("__new__")])
466+
467+
cg.add_push_null(gen_fn)
468+
else:
469+
cg.add_push_null(
470+
lambda: cg.load_import_from(utils.__name__, "object_new")
472471
)
473-
)
474472
cg(var.mutation_type.cls_source)
475473
if isinstance(var, variables.UserDefinedTupleVariable) and var.new_args:
476474
cg(var.new_args)

torch/_dynamo/variables/user_defined.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
CallFunctionNoArgsSource,
3737
GetItemSource,
3838
RandomValueSource,
39+
TypeSource,
3940
UnspecializedParamBufferSource,
4041
)
4142
from ..utils import (
@@ -710,6 +711,8 @@ def __init__(self, value, value_type=None, cls_source=None, **kwargs) -> None:
710711
assert type(value) is self.value_type
711712
# This is used with __new__, when the new object is sourceless but the user class can be sourceful.
712713
self.cls_source = cls_source
714+
if cls_source is None and self.source is not None:
715+
self.cls_source = TypeSource(self.source)
713716

714717
def __str__(self) -> str:
715718
inner = self.value_type.__name__

0 commit comments

Comments
 (0)
0