[dynamo][guards] Consider tensors as immutable for dict tag matches (#139560) · pytorch/pytorch@e6ff07f · GitHub

Commit e6ff07f

anijain2305 authored and pytorchmergebot committed
[dynamo][guards] Consider tensors as immutable for dict tag matches (#139560)
This addresses a bug on main exposed by #139476. We have a dict tag optimization: if the dict tag does not change, we skip guards on all items of the dict that are "immutable". We considered tensors immutable in such scenarios. This is critical for guard eval performance, because users generally don't change their parameters. If I try to remove this optimization, we see slowdowns, e.g., 3.03x to 2.95x on the conv_mixer TIMM benchmark. So, I am adding a flag that keeps the current behavior but allows users to turn the optimization off. Not ideal, but given how important guard eval perf is, we are in the gray area of the unsoundness vs. performance tradeoff.

Pull Request resolved: #139560
Approved by: https://github.com/jansel
1 parent 7f387fa commit e6ff07f
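For users who do mutate tensor state that guards would normally catch (the new test below flips requires_grad on parameters), the flag added in this commit can be turned off to restore full re-guarding of tensor dict values. A minimal sketch of the opt-out, using only the flag name introduced here:

import torch

# Opt out of the dict-tag shortcut for tensors: tensor values held in dicts
# (e.g. module parameters) are then re-guarded even when the dict tag matches,
# trading some guard-eval speed for soundness.
torch._dynamo.config.skip_tensor_guards_with_matching_dict_tags = False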

File tree

3 files changed: +59 -1 lines


test/dynamo/test_modules.py

Lines changed: 48 additions & 0 deletions
@@ -3130,6 +3130,54 @@ def fn(x):
         res = opt_fn(x)
         self.assertEqual(ref, res)
 
+    @patch.object(
+        torch._dynamo.config, "skip_tensor_guards_with_matching_dict_tags", False
+    )
+    def test_param_requires_grad(self):
+        def adjust_model(model):
+            to_freeze = model.num_iter % 2 == 0
+            if to_freeze:
+                for param in model.layer2.parameters():
+                    param.requires_grad = False
+            else:
+                for param in model.layer2.parameters():
+                    param.requires_grad = True
+
+        class MyModule(torch.nn.Module):
+            def __init__(self, input_size, hidden_size, output_size):
+                super().__init__()
+
+                self.layer1 = torch.nn.Linear(hidden_size, hidden_size)
+                self.layer2 = torch.nn.Linear(hidden_size, hidden_size)
+
+                self.num_iter = 0
+
+            def forward(self, x):
+                x = self.layer2(x + self.layer1.bias)
+
+                self.num_iter += 1
+                return x
+
+        input_size = 1024
+        hidden_size = 1024
+        output_size = 1
+        num_samples = 2048
+        features = torch.randn(num_samples, input_size)
+
+        model = MyModule(input_size, hidden_size, output_size)
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        opt_model = torch.compile(model, backend=cnt, fullgraph=True)
+
+        for _ in range(3):
+            model.zero_grad(True)
+            adjust_model(model)
+            res = opt_model(features)
+            res.sum().backward()
+
+        # Check that we have recompiled twice, which leads to 3 frames
+        self.assertEqual(cnt.frame_count, 3)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

torch/_dynamo/config.py

Lines changed: 4 additions & 0 deletions
@@ -331,6 +331,10 @@ def _get_optimize_ddp_mode():
 # notice and lead to incorrect result.
 skip_no_tensor_aliasing_guards_on_parameters = True
 
+# Considers a tensor immutable if it is one of the values of a dictionary, and
+# the dictionary tag is same across invocation calls.
+skip_tensor_guards_with_matching_dict_tags = True
+
 # If True, raises exception if TorchDynamo is called with a context manager
 raise_on_ctx_manager_usage = True
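To see why the flag exists: mutating an attribute of a tensor that lives in a dict (for example a parameter in a module's _parameters dict) does not change the dict's tag, so with the default of True the per-tensor guards may be skipped on later calls. A hypothetical illustration of the pattern (not taken from this commit's test):

import torch

model = torch.nn.Linear(8, 8)
opt_model = torch.compile(model)

opt_model(torch.randn(2, 8))          # compile and install guards
model.weight.requires_grad = False    # dict tag of model._parameters is unchanged
opt_model(torch.randn(2, 8))          # with the flag True, this change may go unnoticed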

torch/csrc/dynamo/guards.cpp

Lines changed: 7 additions & 1 deletion
@@ -886,6 +886,11 @@ std::string get_exception_message() {
 }
 
 bool is_immutable_object(py::handle example_value) {
+  static py::object config_module = py::module_::import("torch._dynamo.config");
+  bool is_tensor_immutable =
+      config_module.attr("skip_tensor_guards_with_matching_dict_tags")
+          .cast<bool>();
+
   if (PyTuple_Check(example_value.ptr())) {
     // Check that each element is immutable
     for (Py_ssize_t i = 0; i < PyTuple_Size(example_value.ptr()); ++i) {
@@ -896,10 +901,11 @@ bool is_immutable_object(py::handle example_value) {
     }
     return true;
   }
+
   return PyLong_Check(example_value.ptr()) ||
       PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) ||
       PyUnicode_Check(example_value.ptr()) ||
-      THPVariable_Check(example_value.ptr());
+      (is_tensor_immutable && THPVariable_Check(example_value.ptr()));
 }
 
 bool is_parameter(py::handle tensor) {
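For readers not following the C++ above, here is a rough Python sketch (an illustration, not the actual implementation) of what is_immutable_object treats as immutable after this change:

import torch

def is_immutable_object(value):
    # A tuple is immutable only if every element is immutable.
    if isinstance(value, tuple):
        return all(is_immutable_object(v) for v in value)
    # ints, floats, bools and strings are always treated as immutable; tensors
    # count as immutable only while the new config flag is True.
    return isinstance(value, (int, float, bool, str)) or (
        torch._dynamo.config.skip_tensor_guards_with_matching_dict_tags
        and isinstance(value, torch.Tensor)
    )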
