@@ -957,6 +957,370 @@ def hook(guard_wrapper, f_locals, builder):
957957 opt_fn (torch .randn (4 , 4 ))
958958
959959
960+ class RecursiveDictTagTests (torch ._dynamo .test_case .TestCase ):
961+ def setUp (self ):
962+ self ._prev = torch ._dynamo .config .use_recursive_dict_tags_for_guards
963+ torch ._dynamo .config .use_recursive_dict_tags_for_guards = True
964+
965+ def tearDown (self ):
966+ torch ._dynamo .config .use_recursive_dict_tags_for_guards = self ._prev
967+
968+
969+ class TagSafetyChecks (RecursiveDictTagTests ):
970+ def setUp (self ):
971+ self ._prev = torch ._dynamo .config .use_recursive_dict_tags_for_guards
972+ torch ._dynamo .config .use_recursive_dict_tags_for_guards = True
973+
974+ def tearDown (self ):
975+ torch ._dynamo .config .use_recursive_dict_tags_for_guards = self ._prev
976+
977+ def test_immutable_tag_safe (self ):
978+ class Bar :
979+ pass
980+
981+ class Foo :
982+ def __init__ (self ):
983+ self .a = Bar ()
984+ self .b = torch .randn (4 )
985+ self .c = 3
986+ self .d = (3 , 4 )
987+ self .e = (3 , Bar ())
988+
989+ foo = Foo ()
990+
991+ def fn (x ):
992+ if foo .a :
993+ x = torch .sin (x )
994+ x = x * foo .b + foo .c + foo .d [0 ] + foo .d [1 ] + foo .e [0 ]
995+ if foo .e [1 ]:
996+ x = torch .sin (x )
997+ return x
998+
999+ try :
1000+ from .utils import install_guard_manager_testing_hook
1001+ except ImportError :
1002+ from utils import install_guard_manager_testing_hook
1003+
1004+ def hook (guard_wrapper , f_locals , builder ):
1005+ from torch ._dynamo .source import AttrSource , LocalSource
1006+
1007+ foo_source = LocalSource ("foo" )
1008+
1009+ # Check types of foo.a
1010+ foo_a_source = AttrSource (foo_source , "a" )
1011+ foo_a_mgr = builder .get_guard_manager_from_source (foo_a_source )
1012+ self .assertFalse (foo_a_mgr .is_tag_safe ())
1013+ self .assertFalse (foo_a_mgr .is_tag_safe_root ())
1014+
1015+ # Check types of foo.b
1016+ foo_b_source = AttrSource (foo_source , "b" )
1017+ foo_b_mgr = builder .get_guard_manager_from_source (foo_b_source )
1018+ if torch ._dynamo .config .skip_tensor_guards_with_matching_dict_tags :
1019+ self .assertTrue (foo_b_mgr .is_tag_safe ())
1020+ else :
1021+ self .assertFalse (foo_b_mgr .is_tag_safe ())
1022+
1023+ self .assertFalse (foo_b_mgr .is_tag_safe_root ())
1024+
1025+ # Check types of foo.c
1026+ foo_c_source = AttrSource (foo_source , "c" )
1027+ foo_c_mgr = builder .get_guard_manager_from_source (foo_c_source )
1028+ self .assertTrue (foo_c_mgr .is_tag_safe ())
1029+ self .assertFalse (foo_c_mgr .is_tag_safe_root ())
1030+
1031+ # Check types of foo.d
1032+ foo_d_source = AttrSource (foo_source , "d" )
1033+ foo_d_mgr = builder .get_guard_manager_from_source (foo_d_source )
1034+ self .assertTrue (foo_d_mgr .is_tag_safe ())
1035+ self .assertFalse (foo_d_mgr .is_tag_safe_root ())
1036+
1037+ # Check types of foo.e
1038+ foo_e_source = AttrSource (foo_source , "e" )
1039+ foo_e_mgr = builder .get_guard_manager_from_source (foo_e_source )
1040+ self .assertFalse (foo_e_mgr .is_tag_safe ())
1041+ self .assertFalse (foo_e_mgr .is_tag_safe_root ())
1042+
1043+ opt_fn = torch .compile (fn , backend = "eager" , fullgraph = True )
1044+ with install_guard_manager_testing_hook (hook ):
1045+ opt_fn (torch .randn (4 , 4 ))
1046+
1047+ def test_dict_tag_safe (self ):
1048+ class Foo :
1049+ def __init__ (self ):
1050+ self .a = 4
1051+
1052+ foo = Foo ()
1053+ terminal_dict = {
1054+ "a" : 1 ,
1055+ }
1056+
1057+ tag_safe_dict = {
1058+ "const" : 1 ,
1059+ "tup" : (2 , 3 ),
1060+ "nested_dict" : terminal_dict ,
1061+ }
1062+
1063+ tag_unsafe_dict = {
1064+ "const" : 1 ,
1065+ "foo" : foo ,
1066+ }
1067+
1068+ outer_dict = {
1069+ "safe" : tag_safe_dict ,
1070+ "unsafe" : tag_unsafe_dict ,
1071+ "terminal_dict" : {"a" : 1 },
1072+ }
1073+
1074+ def fn (x ):
1075+ x = x + outer_dict ["safe" ]["const" ]
1076+
1077+ x = x + outer_dict ["safe" ]["tup" ][0 ]
1078+ x = x + outer_dict ["safe" ]["tup" ][1 ]
1079+
1080+ x = x + outer_dict ["safe" ]["nested_dict" ]["a" ]
1081+
1082+ x = x + outer_dict ["unsafe" ]["const" ]
1083+
1084+ x = x + outer_dict ["unsafe" ]["foo" ].a
1085+
1086+ if outer_dict ["terminal_dict" ]:
1087+ x = torch .sin (x )
1088+ return x
1089+
1090+ try :
1091+ from .utils import install_guard_manager_testing_hook
1092+ except ImportError :
1093+ from utils import install_guard_manager_testing_hook
1094+
1095+ def hook (guard_wrapper , f_locals , builder ):
1096+ from torch ._dynamo .source import DictGetItemSource , LocalSource
1097+
1098+ outer_source = LocalSource ("outer_dict" )
1099+
1100+ # Check tagness of outer dict
1101+ outer_mgr = builder .get_guard_manager_from_source (outer_source )
1102+ self .assertFalse (outer_mgr .is_tag_safe ())
1103+ self .assertFalse (outer_mgr .is_tag_safe_root ())
1104+
1105+ # Check tagness of outer["safe"]
1106+ outer_safe_source = DictGetItemSource (outer_source , "safe" )
1107+ outer_safe_mgr = builder .get_guard_manager_from_source (outer_safe_source )
1108+ self .assertTrue (outer_safe_mgr .is_tag_safe ())
1109+ self .assertFalse (outer_safe_mgr .is_tag_safe_root ())
1110+
1111+ # Check tagness of outer["unsafe"]
1112+ outer_unsafe_source = DictGetItemSource (outer_source , "unsafe" )
1113+ outer_unsafe_mgr = builder .get_guard_manager_from_source (
1114+ outer_unsafe_source
1115+ )
1116+ self .assertFalse (outer_unsafe_mgr .is_tag_safe ())
1117+ self .assertFalse (outer_unsafe_mgr .is_tag_safe_root ())
1118+
1119+ # Check tagness of outer["terminal_dict"]
1120+ outer_terminal_source = DictGetItemSource (outer_source , "terminal_dict" )
1121+ outer_terminal_mgr = builder .get_guard_manager_from_source (
1122+ outer_terminal_source
1123+ )
1124+ self .assertTrue (outer_terminal_mgr .is_tag_safe ())
1125+ self .assertFalse (outer_terminal_mgr .is_tag_safe_root ())
1126+
1127+ # Check tagness of outer["safe"]["nested_dict"]
1128+ outer_safe_nested_source = DictGetItemSource (
1129+ outer_safe_source , "nested_dict"
1130+ )
1131+ outer_safe_nested_mgr = builder .get_guard_manager_from_source (
1132+ outer_safe_nested_source
1133+ )
1134+ self .assertTrue (outer_safe_nested_mgr .is_tag_safe ())
1135+ # This should not be marked as a root
1136+ self .assertFalse (outer_safe_nested_mgr .is_tag_safe_root ())
1137+
1138+ opt_fn = torch .compile (fn , backend = "eager" , fullgraph = True )
1139+ with install_guard_manager_testing_hook (hook ):
1140+ opt_fn (torch .randn (4 , 4 ))
1141+
1142+ def test_nn_module_tag_safe (self ):
1143+ class Foo (torch .nn .Module ):
1144+ def __init__ (self ):
1145+ super ().__init__ ()
1146+ self .a = 4
1147+
1148+ def forward (self , x ):
1149+ return x + self .a
1150+
1151+ foo = Foo ()
1152+
1153+ class Baz (torch .nn .Module ):
1154+ def __init__ (self ):
1155+ super ().__init__ ()
1156+ self .foo = foo
1157+
1158+ def forward (self , x ):
1159+ return self .foo (x )
1160+
1161+ baz = Baz ()
1162+
1163+ def fn (x ):
1164+ x = x + baz (x )
1165+ return x
1166+
1167+ try :
1168+ from .utils import install_guard_manager_testing_hook
1169+ except ImportError :
1170+ from utils import install_guard_manager_testing_hook
1171+
1172+ def hook (guard_wrapper , f_locals , builder ):
1173+ from torch ._C ._dynamo .guards import GetGenericDictGuardAccessor
1174+ from torch ._dynamo .source import LocalSource
1175+
1176+ baz_source = LocalSource ("baz" )
1177+
1178+ # Check tagness of baz
1179+ baz_mgr = builder .get_guard_manager_from_source (baz_source )
1180+ self .assertTrue (baz_mgr .is_tag_safe ())
1181+ self .assertTrue (baz_mgr .is_tag_safe_root ())
1182+
1183+ # Check tagness of baz.__dict__
1184+ self .assertTrue (len (baz_mgr .get_accessors ()) == 1 )
1185+ dunder_dict_accessor = baz_mgr .get_accessors ()[0 ]
1186+ self .assertTrue (
1187+ isinstance (dunder_dict_accessor , GetGenericDictGuardAccessor )
1188+ )
1189+
1190+ dunder_dict_mgr = baz_mgr .get_child_managers ()[0 ]
1191+ self .assertTrue (dunder_dict_mgr .is_tag_safe ())
1192+ self .assertFalse (dunder_dict_mgr .is_tag_safe_root ())
1193+
1194+ # Check tagness of baz.__dict__["_modules"]
1195+ modules_mgr = dunder_dict_mgr .get_child_managers ()[0 ]
1196+ self .assertTrue (modules_mgr .is_tag_safe ())
1197+ self .assertFalse (modules_mgr .is_tag_safe_root ())
1198+
1199+ # Check tagness of baz.__dict__["_modules"]["foo"]
1200+ modules_foo_mgr = modules_mgr .get_child_managers ()[0 ]
1201+ self .assertTrue (modules_foo_mgr .is_tag_safe ())
1202+ self .assertFalse (modules_foo_mgr .is_tag_safe_root ())
1203+
1204+ opt_fn = torch .compile (fn , backend = "eager" , fullgraph = True )
1205+ with install_guard_manager_testing_hook (hook ):
1206+ opt_fn (torch .randn (4 , 4 ))
1207+
1208+
1209+ class RecursiveDictGuardTests (RecursiveDictTagTests ):
1210+ def test_disabling (self ):
1211+ class Mod (torch .nn .Module ):
1212+ def __init__ (self ):
1213+ super ().__init__ ()
1214+ self .a = 4
1215+
1216+ def forward (self , x ):
1217+ return x + self .a
1218+
1219+ mod = Mod ()
1220+ mod_to_fail = Mod ()
1221+
1222+ def fn (x ):
1223+ return mod (x )
1224+
1225+ x = torch .randn (4 , 4 )
1226+
1227+ try :
1228+ from .utils import install_guard_manager_testing_hook
1229+ except ImportError :
1230+ from utils import install_guard_manager_testing_hook
1231+
1232+ def basic_hook_test (guard_wrapper , f_locals , builder ):
1233+ from torch ._dynamo .source import LocalSource
1234+
1235+ mod_source = LocalSource ("mod" )
1236+
1237+ # Check tagness of mod
1238+ mod_mgr = builder .get_guard_manager_from_source (mod_source )
1239+ self .assertTrue (mod_mgr .is_tag_safe ())
1240+ self .assertTrue (mod_mgr .is_tag_safe_root ())
1241+ self .assertFalse (mod_mgr .is_recursive_dict_tag_matching_disabled ())
1242+
1243+ for _ in range (10 ):
1244+ self .assertTrue (guard_wrapper .check ({"mod" : mod , "x" : x }))
1245+ self .assertFalse (mod_mgr .is_recursive_dict_tag_matching_disabled ())
1246+
1247+ # Let the guard pass but dict matching fail, this should add new cached entry
1248+ self .assertTrue (guard_wrapper .check ({"mod" : mod_to_fail , "x" : x }))
1249+ self .assertFalse (mod_mgr .is_recursive_dict_tag_matching_disabled ())
1250+
1251+ # Let the guard fail, this should disable dict tag optimization as well
1252+ mod_to_fail .a = 5
1253+ self .assertFalse (guard_wrapper .check ({"mod" : mod_to_fail , "x" : x }))
1254+ self .assertTrue (mod_mgr .is_recursive_dict_tag_matching_disabled ())
1255+
1256+ opt_fn = torch .compile (fn , backend = "eager" , fullgraph = True )
1257+ with install_guard_manager_testing_hook (basic_hook_test ):
1258+ opt_fn (x )
1259+
1260+ # Test that dict tag matching failure leads to disable of dict tag optimization
1261+ torch .compiler .reset ()
1262+ mod = Mod ()
1263+ mod_to_fail = Mod ()
1264+
1265+ def disable_on_dict_tag_match_failure (guard_wrapper , f_locals , builder ):
1266+ from torch ._dynamo .source import LocalSource
1267+
1268+ mod_source = LocalSource ("mod" )
1269+
1270+ # Check tagness of mod
1271+ mod_mgr = builder .get_guard_manager_from_source (mod_source )
1272+ self .assertTrue (mod_mgr .is_tag_safe ())
1273+ self .assertTrue (mod_mgr .is_tag_safe_root ())
1274+ self .assertFalse (mod_mgr .is_recursive_dict_tag_matching_disabled ())
1275+
1276+ for _ in range (10 ):
1277+ self .assertTrue (guard_wrapper .check ({"mod" : mod , "x" : x }))
1278+ self .assertFalse (mod_mgr .is_recursive_dict_tag_matching_disabled ())
1279+
1280+ # Change the mod attr to cause dict tag matching to fail, this still
1281+ # get the guard pass. This should disable the dict tag optimization.
1282+ mod .a = 5
1283+ mod .a = 4
1284+ self .assertTrue (guard_wrapper.check ({"mod" : mod , "x" : x }))
1285+ self .assertTrue (mod_mgr .is_recursive_dict_tag_matching_disabled ())
1286+
1287+ opt_fn = torch .compile (fn , backend = "eager" , fullgraph = True )
1288+ with install_guard_manager_testing_hook (disable_on_dict_tag_match_failure ):
1289+ opt_fn (x )
1290+
1291+ # Test that max size limit breach disables the dict tag optimization
1292+ torch .compiler .reset ()
1293+ mod = Mod ()
1294+ mod_to_fail = Mod ()
1295+
1296+ def max_size_test (guard_wrapper , f_locals , builder ):
1297+ from torch ._dynamo .source import LocalSource
1298+
1299+ mod_source = LocalSource ("mod" )
1300+
1301+ # Check tagness of mod
1302+ mod_mgr = builder .get_guard_manager_from_source (mod_source )
1303+ self .assertTrue (mod_mgr .is_tag_safe ())
1304+ self .assertTrue (mod_mgr .is_tag_safe_root ())
1305+ self .assertFalse (mod_mgr .is_recursive_dict_tag_matching_disabled ())
1306+
1307+ for _ in range (10 ):
1308+ self .assertTrue (guard_wrapper .check ({"mod" : mod , "x" : x }))
1309+ self .assertFalse (mod_mgr .is_recursive_dict_tag_matching_disabled ())
1310+
1311+ # Let the guard pass but dict matching fail, since cache size is set
1312+ # to 1, this would cause dict tag optimization to be disabled.
1313+ self .assertTrue (guard_wrapper .check ({"mod" : mod_to_fail , "x" : x }))
1314+ self .assertTrue (mod_mgr .is_recursive_dict_tag_matching_disabled ())
1315+
1316+ opt_fn = torch .compile (fn , backend = "eager" , fullgraph = True )
1317+ with torch ._dynamo .config .patch (
1318+ max_saved_pointers_for_recursive_dict_tags_check = 1
1319+ ):
1320+ with install_guard_manager_testing_hook (max_size_test ):
1321+ opt_fn (x )
1322+
1323+
9601324if __name__ == "__main__" :
9611325 from torch ._dynamo .test_case import run_tests
9621326
0 commit comments