6060 AttrSource ,
6161 CallFunctionNoArgsSource ,
6262 DataclassFieldsSource ,
63- DictGetItemSource ,
6463 GetItemSource ,
6564 RandomValueSource ,
66- TypeDictSource ,
67- TypeMROSource ,
6865 TypeSource ,
6966 UnspecializedParamBufferSource ,
7067)
@@ -1001,9 +998,11 @@ def call_method(
1001998
1002999 # check for methods implemented in C++
10031000 if isinstance (method , types .FunctionType ):
1004- source = None
1005- if self .source :
1006- source = self .get_source_by_walking_mro (name )
1001+ source = (
1002+ None
1003+ if self .source is None
1004+ else AttrSource (AttrSource (self .source , "__class__" ), name )
1005+ )
10071006 # TODO(jansel): add a guard to check for monkey patching?
10081007 from ..mutation_guard import unpatched_nn_module_init
10091008
@@ -1225,40 +1224,12 @@ def get_source_by_walking_mro(self, name):
12251224
12261225 for idx , klass in enumerate (type (self .value ).__mro__ ):
12271226 if name in klass .__dict__ :
1228- if idx != 0 :
1229- mro_source = TypeMROSource (self .cls_source )
1230- klass_source = GetItemSource (mro_source , idx )
1231- else :
1232- klass_source = self .cls_source
1233- dict_source = TypeDictSource (klass_source )
1234- out_source = DictGetItemSource (dict_source , name )
1235-
1236- for absent_idx in range (1 , idx ):
1237- # Insert a guard that the name is not present in the mro hierarchy
1238- mro_source = TypeMROSource (self .cls_source )
1239- klass_source = GetItemSource (mro_source , absent_idx )
1240- dict_source = TypeDictSource (klass_source )
1241- install_guard (
1242- dict_source .make_guard (
1243- functools .partial (
1244- GuardBuilder .DICT_CONTAINS , key = name , invert = True
1245- )
1246- )
1247- )
1248- # Insert a guard that the name is not present in the object __dict__
1249- if (
1250- self .source
1251- and hasattr (self .value , "__dict__" )
1252- and name not in self .value .__dict__
1253- ):
1254- install_guard (
1255- self .source .make_guard (
1256- functools .partial (
1257- GuardBuilder .NOT_PRESENT_IN_GENERIC_DICT , attr = name
1258- )
1259- )
1260- )
1261- return out_source
1227+ mro_source = AttrSource (self .cls_source , "__mro__" )
1228+ klass_source = GetItemSource (mro_source , idx )
1229+ dict_source = AttrSource (klass_source , "__dict__" )
1230+ # TODO(anijain2305) - This is a mapping proxy object. Ideally we
1231+ # should use DictGetItemSource here.
1232+ return GetItemSource (dict_source , name )
12621233
12631234 unimplemented_v2 (
12641235 gb_type = "could not find name in object's mro" ,
@@ -1368,17 +1339,10 @@ def var_getattr(self, tx: "InstructionTranslator", name):
13681339 if subobj is torch .nn .Module .__init__ :
13691340 subobj = unpatched_nn_module_init
13701341
1371- subobj_from_class = inspect .getattr_static (
1372- self .value .__class__ , name , NO_SUCH_SUBOBJ
1373- )
1374- is_accessible_from_type_mro = (
1375- subobj_from_class is subobj and self .cls_source is not None
1376- )
1377-
13781342 if isinstance (subobj , property ):
13791343 if self .source :
13801344 # Read the class attribute to reach the property
1381- source = self .get_source_by_walking_mro ( name )
1345+ source = AttrSource ( AttrSource ( self .source , "__class__" ), name )
13821346 # Get the getter function
13831347 source = AttrSource (source , "fget" )
13841348 return variables .UserMethodVariable (
@@ -1396,11 +1360,6 @@ def var_getattr(self, tx: "InstructionTranslator", name):
13961360 # Safe because `staticmethod.__get__` basically won't trigger user
13971361 # code and just returns the underlying `__func__`:
13981362 # https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100
1399- if is_accessible_from_type_mro :
1400- # Accessing from __dict__ does not resolve the descriptor, it
1401- # returns a staticmethod object, so access the __func__
1402- # attribute to get to the actual function.
1403- source = AttrSource (self .get_source_by_walking_mro (name ), "__func__" )
14041363 func = subobj .__get__ (self .value )
14051364 return VariableTracker .build (tx , func , source )
14061365 elif isinstance (subobj , classmethod ):
@@ -1526,15 +1485,10 @@ def var_getattr(self, tx: "InstructionTranslator", name):
15261485 source = self ._wrap_source (source )
15271486
15281487 if subobj is not NO_SUCH_SUBOBJ :
1529- if is_wrapper_or_member_descriptor (
1530- subobj
1531- ) or torch ._C ._dynamo .utils .is_instancemethod (subobj ):
1488+ if is_wrapper_or_member_descriptor (subobj ):
15321489 options = {"source" : source }
15331490 return variables .GetAttrVariable (self , name , ** options )
15341491 if source :
1535- if is_accessible_from_type_mro :
1536- source = self .get_source_by_walking_mro (name )
1537-
15381492 return variables .LazyVariableTracker .create (subobj , source )
15391493 else :
15401494 # Check if the subobj is accessible from the class itself. If the class source is known, we can create a
0 commit comments