8000 [dynamo][guards] Recursive dict tag optimization (#159183) · pytorch/pytorch@7eb5fdb · GitHub
[go: up one dir, main page]

Skip to content
10BC0

Commit 7eb5fdb

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][guards] Recursive dict tag optimization (#159183)
Design doc here - https://docs.google.com/document/d/1W29DrWID5miGWlZXspsQVN5U0zydE3kjZpziOXrhuaY/edit?tab=t.0#bookmark=id.sba04iw9sp68 Pull Request resolved: #159183 Approved by: https://github.com/jansel
1 parent f1fb57d commit 7eb5fdb

File tree

5 files changed

+950
-4
lines changed

5 files changed

+950
-4
lines changed

test/dynamo/test_guard_manager.py

Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,370 @@ def hook(guard_wrapper, f_locals, builder):
957957
opt_fn(torch.randn(4, 4))
958958

959959

960+
class RecursiveDictTagTests(torch._dynamo.test_case.TestCase):
961+
def setUp(self):
962+
self._prev = torch._dynamo.config.use_recursive_dict_tags_for_guards
963+
torch._dynamo.config.use_recursive_dict_tags_for_guards = True
964+
965+
def tearDown(self):
966+
torch._dynamo.config.use_recursive_dict_tags_for_guards = self._prev
967+
968+
969+
class TagSafetyChecks(RecursiveDictTagTests):
970+
def setUp(self):
971+
self._prev = torch._dynamo.config.use_recursive_dict_tags_for_guards
972+
torch._dynamo.config.use_recursive_dict_tags_for_guards = True
973+
974+
def tearDown(self):
975+
torch._dynamo.config.use_recursive_dict_tags_for_guards = self._prev
976+
977+
def test_immutable_tag_safe(self):
978+
class Bar:
979+
pass
980+
981+
class Foo:
982+
def __init__(self):
983+
self.a = Bar()
984+
self.b = torch.randn(4)
985+
self.c = 3
986+
self.d = (3, 4)
987+
self.e = (3, Bar())
988+
989+
foo = Foo()
990+
991+
def fn(x):
992+
if foo.a:
993+
x = torch.sin(x)
994+
x = x * foo.b + foo.c + foo.d[0] + foo.d[1] + foo.e[0]
995+
if foo.e[1]:
996+
x = torch.sin(x)
997+
return x
998+
999+
try:
1000+
from .utils import install_guard_manager_testing_hook
1001+
except ImportError:
1002+
from utils import install_guard_manager_testing_hook
1003+
1004+
def hook(guard_wrapper, f_locals, builder):
1005+
from torch._dynamo.source import AttrSource, LocalSource
1006+
1007+
foo_source = LocalSource("foo")
1008+
1009+
# Check types of foo.a
1010+
foo_a_source = AttrSource(foo_source, "a")
1011+
foo_a_mgr = builder.get_guard_manager_from_source(foo_a_source)
1012+
self.assertFalse(foo_a_mgr.is_tag_safe())
1013+
self.assertFalse(foo_a_mgr.is_tag_safe_root())
1014+
1015+
# Check types of foo.b
1016+
foo_b_source = AttrSource(foo_source, "b")
1017+
foo_b_mgr = builder.get_guard_manager_from_source(foo_b_source)
1018+
if torch._dynamo.config.skip_tensor_guards_with_matching_dict_tags:
1019+
self.assertTrue(foo_b_mgr.is_tag_safe())
1020+
else:
1021+
self.assertFalse(foo_b_mgr.is_tag_safe())
1022+
1023+
self.assertFalse(foo_b_mgr.is_tag_safe_root())
1024+
1025+
# Check types of foo.c
1026+
foo_c_source = AttrSource(foo_source, "c")
1027+
foo_c_mgr = builder.get_guard_manager_from_source(foo_c_source)
1028+
self.assertTrue(foo_c_mgr.is_tag_safe())
1029+
self.assertFalse(foo_c_mgr.is_tag_safe_root())
1030+
1031+
# Check types of foo.d
1032+
foo_d_source = AttrSource(foo_source, "d")
1033+
foo_d_mgr = builder.get_guard_manager_from_source(foo_d_source)
1034+
self.assertTrue(foo_d_mgr.is_tag_safe())
1035+
self.assertFalse(foo_d_mgr.is_tag_safe_root())
1036+
1037+
# Check types of foo.e
1038+
foo_e_source = AttrSource(foo_source, "e")
1039+
foo_e_mgr = builder.get_guard_manager_from_source(foo_e_source)
1040+
self.assertFalse(foo_e_mgr.is_tag_safe())
1041+
self.assertFalse(foo_e_mgr.is_tag_safe_root())
1042+
1043+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1044+
with install_guard_manager_testing_hook(hook):
1045+
opt_fn(torch.randn(4, 4))
1046+
1047+
def test_dict_tag_safe(self):
1048+
class Foo:
1049+
def __init__(self):
1050+
self.a = 4
1051+
1052+
foo = Foo()
1053+
terminal_dict = {
1054+
"a": 1,
1055+
}
1056+
1057+
tag_safe_dict = {
1058+
"const": 1,
1059+
"tup": (2, 3),
1060+
"nested_dict": terminal_dict,
1061+
}
1062+
1063+
tag_unsafe_dict = {
1064+
"const": 1,
1065+
"foo": foo,
1066+
}
1067+
1068+
outer_dict = {
1069+
"safe": tag_safe_dict,
1070+
"unsafe": tag_unsafe_dict,
1071+
"terminal_dict": {"a": 1},
1072+
}
1073+
1074+
def fn(x):
1075+
x = x + outer_dict["safe"]["const"]
1076+
1077+
x = x + outer_dict["safe"]["tup"][0]
1078+
x = x + outer_dict["safe"]["tup"][1]
1079+
1080+
x = x + outer_dict["safe"]["nested_dict"]["a"]
1081+
1082+
x = x + outer_dict["unsafe"]["const"]
1083+
1084+
x = x + outer_dict["unsafe"]["foo"].a
1085+
1086+
if outer_dict["terminal_dict"]:
1087+
x = torch.sin(x)
1088+
return x
1089+
1090+
try:
1091+
from .utils import install_guard_manager_testing_hook
1092+
except ImportError:
1093+
from utils import install_guard_manager_testing_hook
1094+
1095+
def hook(guard_wrapper, f_locals, builder):
1096+
from torch._dynamo.source import DictGetItemSource, LocalSource
1097+
1098+
outer_source = LocalSource("outer_dict")
1099+
1100+
# Check tagness of outer dict
1101+
outer_mgr = builder.get_guard_manager_from_source(outer_source)
1102+
self.assertFalse(outer_mgr.is_tag_safe())
1103+
self.assertFalse(outer_mgr.is_tag_safe_root())
1104+
1105+
# Check tagness of outer["safe"]
1106+
outer_safe_source = DictGetItemSource(outer_source, "safe")
1107+
outer_safe_mgr = builder.get_guard_manager_from_source(outer_safe_source)
1108+
self.assertTrue(outer_safe_mgr.is_tag_safe())
1109+
self.assertFalse(outer_safe_mgr.is_tag_safe_root())
1110+
1111+
# Check tagness of outer["unsafe"]
1112+
outer_unsafe_source = DictGetItemSource(outer_source, "unsafe")
1113+
outer_unsafe_mgr = builder.get_guard_manager_from_source(
1114+
outer_unsafe_source
1115+
)
1116+
self.assertFalse(outer_unsafe_mgr.is_tag_safe())
1117+
self.assertFalse(outer_unsafe_mgr.is_tag_safe_root())
1118+
1119+
# Check tagness of outer["terminal_dict"]
1120+
outer_terminal_source = DictGetItemSource(outer_source, "terminal_dict")
1121+
outer_terminal_mgr = builder.get_guard_manager_from_source(
1122+
outer_terminal_source
1123+
)
1124+
self.assertTrue(outer_terminal_mgr.is_tag_safe())
1125+
self.assertFalse(outer_terminal_mgr.is_tag_safe_root())
1126+
1127+
# Check tagness of outer["safe"]["nested_dict"]
1128+
outer_safe_nested_source = DictGetItemSource(
1129+
outer_safe_source, "nested_dict"
1130+
)
1131+
outer_safe_nested_mgr = builder.get_guard_manager_from_source(
1132+
outer_safe_nested_source
1133+
)
1134+
self.assertTrue(outer_safe_nested_mgr.is_tag_safe())
1135+
# This should not be marked as a root
1136+
self.assertFalse(outer_safe_nested_mgr.is_tag_safe_root())
1137+
1138+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1139+
with install_guard_manager_testing_hook(hook):
1140+
opt_fn(torch.randn(4, 4))
1141+
1142+
def test_nn_module_tag_safe(self):
1143+
class Foo(torch.nn.Module):
1144+
def __init__(self):
1145+
super().__init__()
1146+
self.a = 4
1147+
1148+
def forward(self, x):
1149+
return x + self.a
1150+
1151+
foo = Foo()
1152+
1153+
class Baz(torch.nn.Module):
1154+
def __init__(self):
1155+
super().__init__()
1156+
self.foo = foo
1157+
1158+
def forward(self, x):
1159+
return self.foo(x)
1160+
1161+
baz = Baz()
1162+
1163+
def fn(x):
1164+
x = x + baz(x)
1165+
return x
1166+
1167+
try:
1168+
from .utils import install_guard_manager_testing_hook
1169+
except ImportError:
1170+
from utils import install_guard_manager_testing_hook
1171+
1172+
def hook(guard_wrapper, f_locals, builder):
1173+
from torch._C._dynamo.guards import GetGenericDictGuardAccessor
1174+
from torch._dynamo.source import LocalSource
1175+
1176+
baz_source = LocalSource("baz")
1177+
1178+
# Check tagness of baz
1179+
baz_mgr = builder.get_guard_manager_from_source(baz_source)
1180+
self.assertTrue(baz_mgr.is_tag_safe())
1181+
self.assertTrue(baz_mgr.is_tag_safe_root())
1182+
1183+
# Check tagness of baz.__dict__
1184+
self.assertTrue(len(baz_mgr.get_accessors()) == 1)
1185+
dunder_dict_accessor = baz_mgr.get_accessors()[0]
1186+
self.assertTrue(
1187+
isinstance(dunder_dict_accessor, GetGenericDictGuardAccessor)
1188+
)
1189+
1190+
dunder_dict_mgr = baz_mgr.get_child_managers()[0]
1191+
self.assertTrue(dunder_dict_mgr.is_tag_safe())
1192+
self.assertFalse(dunder_dict_mgr.is_tag_safe_root())
1193+
1194+
# Check tagness of baz.__dict__["_modules"]
1195+
modules_mgr = dunder_dict_mgr.get_child_managers()[0]
1196+
self.assertTrue(modules_mgr.is_tag_safe())
1197+
self.assertFalse(modules_mgr.is_tag_safe_root())
1198+
1199+
# Check tagness of baz.__dict__["_modules"]["foo"]
1200+
modules_foo_mgr = modules_mgr.get_child_managers()[0]
1201+
self.assertTrue(modules_foo_mgr.is_tag_safe())
1202+
self.assertFalse(modules_foo_mgr.is_tag_safe_root())
1203+
1204+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1205+
with install_guard_manager_testing_hook(hook):
1206+
opt_fn(torch.randn(4, 4))
1207+
1208+
1209+
class RecursiveDictGuardTests(RecursiveDictTagTests):
1210+
def test_disabling(self):
1211+
class Mod(torch.nn.Module):
1212+
def __init__(self):
1213+
super().__init__()
1214+
self.a = 4
1215+
1216+
def forward(self, x):
1217+
return x + self.a
1218+
1219+
mod = Mod()
1220+
mod_to_fail = Mod()
1221+
1222+
def fn(x):
1223+
return mod(x)
1224+
1225+
x = torch.randn(4, 4)
1226+
1227+
try:
1228+
from .utils import install_guard_manager_testing_hook
1229+
except ImportError:
1230+
from utils import install_guard_manager_testing_hook
1231+
1232+
def basic_hook_test(guard_wrapper, f_locals, builder):
1233+
from torch._dynamo.source import LocalSource
1234+
1235+
mod_source = LocalSource("mod")
1236+
1237+
# Check tagness of mod
1238+
mod_mgr = builder.get_guard_manager_from_source(mod_source)
1239+
self.assertTrue(mod_mgr.is_tag_safe())
1240+
self.assertTrue(mod_mgr.is_tag_safe_root())
1241+
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
1242+
1243+
for _ in range(10):
1244+
self.assertTrue(guard_wrapper.check({"mod": mod, "x": x}))
1245+
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
1246+
1247+
# Let the guard pass but dict matching fail, this should add new cached entry
1248+
self.assertTrue(guard_wrapper.check({"mod": mod_to_fail, "x": x}))
1249+
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
1250+
1251+
# Let the guard fail, this should disable dict tag optimization as well
1252+
mod_to_fail.a = 5
1253+
self.assertFalse(guard_wrapper.check({"mod": mod_to_fail, "x": x}))
1254+
self.assertTrue(mod_mgr.is_recursive_dict_tag_matching_disabled())
1255+
1256+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1257+
with install_guard_manager_testing_hook(basic_hook_test):
1258+
opt_fn(x)
1259+
1260+
# Test that dict tag matching failure leads to disable of dict tag optimization
1261+
torch.compiler.reset()
1262+
mod = Mod()
1263+
mod_to_fail = Mod()
1264+
1265+
def disable_on_dict_tag_match_failure(guard_wrapper, f_locals, builder):
1266+
from torch._dynamo.source import LocalSource
1267+
1268+
mod_source = LocalSource("mod")
1269+
1270+
# Check tagness of mod
1271+
mod_mgr = builder.get_guard_manager_from_source(mod_source)
1272+
self.assertTrue(mod_mgr.is_tag_safe())
1273+
self.assertTrue(mod_mgr.is_tag_safe_root())
1274+
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
1275+
1276+
for _ in range(10):
1277+
self.assertTrue(guard_wrapper.check({"mod": mod, "x": x}))
1278+
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
1279+
1280+
# Change the mod attr to cause dict tag matching to fail, this still
1281+
# get the guard pass. This should disable the dict tag optimization.
1282+
mod.a = 5
1283+
mod.a = 4
1284+
self.assertTrue(guard_wrapper.check({"mod": mod, "x": x}))
1285+
self.assertTrue(mod_mgr.is_recursive_dict_tag_matching_disabled())
1286+
1287+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1288+
with install_guard_manager_testing_hook(disable_on_dict_tag_match_failure):
1289+
opt_fn(x)
1290+
1291+
# Test that max size limit breach disables the dict tag optimization
1292+
torch.compiler.reset()
1293+
mod = Mod()
1294+
mod_to_fail = Mod()
1295+
1296+
def max_size_test(guard_wrapper, f_locals, builder):
1297+
from torch._dynamo.source import LocalSource
1298+
1299+
mod_source = LocalSource("mod")
1300+
1301+
# Check tagness of mod
1302+
mod_mgr = builder.get_guard_manager_from_source(mod_source)
1303+
self.assertTrue(mod_mgr.is_tag_safe())
1304+
self.assertTrue(mod_mgr.is_tag_safe_root())
1305+
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
1306+
1307+
for _ in range(10):
1308+
self.assertTrue(guard_wrapper.check({"mod": mod, "x": x}))
1309+
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
1310+
1311+
# Let the guard pass but dict matching fail, since cache size is set
1312+
# to 1, this would cause dict tag optimization to be disabled.
1313+
self.assertTrue(guard_wrapper.check({"mod": mod_to_fail, "x": x}))
1314+
self.assertTrue(mod_mgr.is_recursive_dict_tag_matching_disabled())
1315+
1316+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1317+
with torch._dynamo.config.patch(
1318+
max_saved_pointers_for_recursive_dict_tags_check=1
1319+
):
1320+
with install_guard_manager_testing_hook(max_size_test):
1321+
opt_fn(x)
1322+
1323+
9601324
if __name__ == "__main__":
9611325
from torch._dynamo.test_case import run_tests
9621326

torch/_C/_dynamo/guards.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,11 @@ class DictGuardManager(GuardManager):
135135
guard_manager_enum,
136136
) -> GuardManager: ...
137137

138+
# Guard accessor stubs
139+
class GuardAccessor: ...
140+
class DictGetItemGuardAccessor(GuardAccessor): ...
141+
class GetGenericDictGuardAccessor(GuardAccessor): ...
142+
138143
def install_object_aliasing_guard(
139144
guard_managers: list[GuardManager],
140145
tensor_names: list[str],

torch/_dynamo/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,12 @@
354354
# Skips guards on func.__defaults__ if the element to be guarded is a constant
355355
skip_guards_on_constant_func_defaults = True
356356

357+
# Speedup guard execution of nested nn modules by recursively checking for dict
358+
# tags to avoid full guard execution.
359+
use_recursive_dict_tags_for_guards = False
360+
361+
max_saved_pointers_for_recursive_dict_tags_check = 256
362+
357363
# If True, raises exception if TorchDynamo is called with a context manager
358364
raise_on_ctx_manager_usage = True
359365

0 commit comments

Comments
 (0)
0