8000 [dynamo][user-defined] Unify standard and non-standard __new__ codebase · pytorch/pytorch@af0f583 · GitHub
[go: up one dir, main page]

Skip to content

Commit af0f583

Browse files
committed
[dynamo][user-defined] Unify standard and non-standard __new__ codebase
ghstack-source-id: 6338f8d Pull Request resolved: #146737
1 parent b6ff50c commit af0f583

File tree

1 file changed

+14
-23
lines changed

1 file changed

+14
-23
lines changed

torch/_dynamo/variables/user_defined.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -607,17 +607,6 @@ def call_function(
607607
)
608608
var.call_method(tx, "__init__", args, kwargs)
609609
return var
610-
elif (
611-
self.is_standard_new()
612-
and SideEffects.cls_supports_mutation_side_effects(self.value)
613-
and self.source
614-
):
615-
var = tx.output.side_effects.track_new_user_defined_object(
616-
variables.BuiltinVariable(object), self, args
617-
)
618-
with do_not_convert_to_tracable_parameter():
619-
var.call_method(tx, "__init__", args, kwargs)
620-
return var
621610
elif (
622611
variables.RestrictedListSubclassVariable.is_matching_cls(self.value)
623612
and self.source
@@ -678,18 +667,15 @@ def call_function(
678667
# types.MappingProxyType is a read-only proxy of the dict. If the
679668
# original dict changes, the changes are reflected in proxy as well.
680669
return variables.MappingProxyVariable(args[0])
681-
elif (
682-
not self.is_standard_new()
683-
and SideEffects.cls_supports_mutation_side_effects(self.value)
684-
and self.source
685-
):
686-
return tx.inline_user_function_return(
687-
VariableTracker.build(
688-
tx, polyfills.instantiate_user_defined_class_object
689-
),
690-
[self, *args],
691-
kwargs,
692-
)
670+
elif SideEffects.cls_supports_mutation_side_effects(self.value) and self.source:
671+
with do_not_convert_to_tracable_parameter():
672+
return tx.inline_user_function_return(
673+
VariableTracker.build(
674+
tx, polyfills.instantiate_user_defined_class_object
675+
),
676+
[self, *args],
677+
kwargs,
678+
)
693679
return super().call_function(tx, args, kwargs)
694680

695681
def is_standard_new(self):
@@ -1179,6 +1165,11 @@ def var_getattr(self, tx: "InstructionTranslator", name):
11791165
elif getattr_fn is not None:
11801166
unimplemented("UserDefined with non-function __getattr__")
11811167

1168+
from ..mutation_guard import unpatched_nn_module_init
1169+
1170+
if subobj is torch.nn.Module.__init__:
1171+
subobj = unpatched_nn_module_init
1172+
11821173
if isinstance(subobj, property):
11831174
if self.source:
11841175
# Read the class attribute to reach the property

0 commit comments

Comments
 (0)
0