8000 [dynamo][user-defined] Unify standard and non-standard __new__ codeba… · pytorch/pytorch@cbbb11d · GitHub
[go: up one dir, main page]

Skip to content

Commit cbbb11d

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][user-defined] Unify standard and non-standard __new__ codebase (#146737)
Pull Request resolved: #146737 Approved by: https://github.com/jansel ghstack dependencies: #146677
1 parent ee8a06f commit cbbb11d

File tree

1 file changed

+14
-23
lines changed

1 file changed

+14
-23
lines changed

torch/_dynamo/variables/user_defined.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -593,17 +593,6 @@ def call_function(
593593
)
594594
var.call_method(tx, "__init__", args, kwargs)
595595
return var
596-
elif (
597-
self.is_standard_new()
598-
and SideEffects.cls_supports_mutation_side_effects(self.value)
599-
and self.source
600-
):
601-
var = tx.output.side_effects.track_new_user_defined_object(
602-
variables.BuiltinVariable(object), self, args
603-
)
604-
with do_not_convert_to_tracable_parameter():
605-
var.call_method(tx, "__init__", args, kwargs)
606-
return var
607596
elif (
608597
variables.RestrictedListSubclassVariable.is_matching_cls(self.value)
609598
and self.source
@@ -664,18 +653,15 @@ def call_function(
664653
# types.MappingProxyType is a read-only proxy of the dict. If the
665654
# original dict changes, the changes are reflected in proxy as well.
666655
return variables.MappingProxyVariable(args[0])
667-
elif (
668-
not self.is_standard_new()
669-
and SideEffects.cls_supports_mutation_side_effects(self.value)
670-
and self.source
671-
):
672-
return tx.inline_user_function_return(
673-
VariableTracker.build(
674-
tx, polyfills.instantiate_user_defined_class_object
675-
),
676-
[self, *args],
677-
kwargs,
678-
)
656+
elif SideEffects.cls_supports_mutation_side_effects(self.value) and self.source:
657+
with do_not_convert_to_tracable_parameter():
658+
return tx.inline_user_function_return(
659+
VariableTracker.build(
660+
tx, polyfills.instantiate_user_defined_class_object
661+
),
662+
[self, *args],
663+
kwargs,
664+
)
679665
return super().call_function(tx, args, kwargs)
680666

681667
def is_standard_new(self):
@@ -1165,6 +1151,11 @@ def var_getattr(self, tx: "InstructionTranslator", name):
11651151
elif getattr_fn is not None:
11661152
unimplemented("UserDefined with non-function __getattr__")
11671153

1154+
from ..mutation_guard import unpatched_nn_module_init
1155+
1156+
if subobj is torch.nn.Module.__init__:
1157+
subobj = unpatched_nn_module_init
1158+
11681159
if isinstance(subobj, property):
11691160
if self.source:
11701161
# Read the class attribute to reach the property

0 commit comments

Comments
 (0)
0