[reland][dynamo][guards] Consider tensors as immutable for dict tag matches · AmdSampsa/pytorch@5bf1495 · GitHub

Commit 5bf1495

anijain2305 and williamwen42 authored and committed
[reland][dynamo][guards] Consider tensors as immutable for dict tag matches (pytorch#141085)
Reland of pytorch#139560.

As mentioned in pytorch#130341, using `static py::object` can lead to segfaults. I suspect this is the reason for the import system error seen internally (https://www.internalfb.com/sevmanager/view/469592). In this PR, I am removing the `static` part. This is fine, and also the right thing to do, because it catches the case where the user changes the flag in the same process while compiling two different functions. Unfortunately, there is no easy way to trigger this segfault, so I can't write a test.

Pull Request resolved: pytorch#141085
Approved by: https://github.com/jansel

Co-authored-by: William Wen <williamwen@meta.com>
1 parent 52a9ac3 commit 5bf1495
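For illustration, the scenario the reland is meant to handle looks roughly like the sketch below. Only the config flag name comes from this commit; the functions `f` and `g`, their dict argument, and the tensor shapes are made-up examples. With the `static py::object` cache removed, guard construction re-reads the flag, so the second compilation sees the new value instead of a stale one.

import torch

def f(x, params):
    # Reads a tensor out of a dict: the place where the dict-tag based
    # "tensor is immutable" guard optimization can apply.
    return x + params["w"]

def g(x, params):
    return x * params["w"]

params = {"w": torch.randn(4)}

# First compilation in the process, with the flag at its default value (True).
torch._dynamo.config.skip_tensor_guards_with_matching_dict_tags = True
torch.compile(f)(torch.randn(4), params)

# Later in the same process, the user flips the flag and compiles a second
# function. Without the static cache, guard construction picks up the new value.
torch._dynamo.config.skip_tensor_guards_with_matching_dict_tags = False
torch.compile(g)(torch.randn(4), params)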

File tree

3 files changed: +60 -1 lines


test/dynamo/test_modules.py

Lines changed: 48 additions & 0 deletions
@@ -3166,6 +3166,54 @@ def fn(x):
         res = opt_fn(x)
         self.assertEqual(ref, res)
 
+    @patch.object(
+        torch._dynamo.config, "skip_tensor_guards_with_matching_dict_tags", False
+    )
+    def test_param_requires_grad(self):
+        def adjust_model(model):
+            to_freeze = model.num_iter % 2 == 0
+            if to_freeze:
+                for param in model.layer2.parameters():
+                    param.requires_grad = False
+            else:
+                for param in model.layer2.parameters():
+                    param.requires_grad = True
+
+        class MyModule(torch.nn.Module):
+            def __init__(self, input_size, hidden_size, output_size):
+                super().__init__()
+
+                self.layer1 = torch.nn.Linear(hidden_size, hidden_size)
+                self.layer2 = torch.nn.Linear(hidden_size, hidden_size)
+
+                self.num_iter = 0
+
+            def forward(self, x):
+                x = self.layer2(x + self.layer1.bias)
+
+                self.num_iter += 1
+                return x
+
+        input_size = 1024
+        hidden_size = 1024
+        output_size = 1
+        num_samples = 2048
+        features = torch.randn(num_samples, input_size)
+
+        model = MyModule(input_size, hidden_size, output_size)
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        opt_model = torch.compile(model, backend=cnt, fullgraph=True)
+
+        for _ in range(3):
+            model.zero_grad(True)
+            adjust_model(model)
+            res = opt_model(features)
+            res.sum().backward()
+
+        # Check that we have recompiled twice, which leads to 3 frames
+        self.assertEqual(cnt.frame_count, 3)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

torch/_dynamo/config.py

Lines changed: 4 additions & 0 deletions
@@ -331,6 +331,10 @@ def _get_optimize_ddp_mode():
 # notice and lead to incorrect result.
 skip_no_tensor_aliasing_guards_on_parameters = True
 
+# Considers a tensor immutable if it is one of the values of a dictionary, and
+# the dictionary tag is same across invocation calls.
+skip_tensor_guards_with_matching_dict_tags = True
+
 # If True, raises exception if TorchDynamo is called with a context manager
 raise_on_ctx_manager_usage = True
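As the new test above does with `patch.object`, the flag can also be scoped to a single block rather than flipped globally. A minimal sketch, assuming the standard torch._dynamo.config.patch helper (not part of this commit):

import torch

# Temporarily disable the dict-tag based tensor immutability optimization for
# one compilation; the previous value is restored when the block exits.
with torch._dynamo.config.patch(skip_tensor_guards_with_matching_dict_tags=False):
    fn = torch.compile(lambda x: x * 2)
    fn(torch.ones(3))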

torch/csrc/dynamo/guards.cpp

Lines changed: 8 additions & 1 deletion
@@ -903,6 +903,12 @@ std::string get_exception_message() {
 }
 
 bool is_immutable_object(py::handle example_value) {
+  py::object config_module = py::module_::import("torch._dynamo.config");
+
+  bool is_tensor_immutable =
+      config_module.attr("skip_tensor_guards_with_matching_dict_tags")
+          .cast<bool>();
+
   if (PyTuple_Check(example_value.ptr())) {
     // Check that each element is immutable
     for (Py_ssize_t i = 0; i < PyTuple_Size(example_value.ptr()); ++i) {
@@ -913,10 +919,11 @@ bool is_immutable_object(py::handle example_value) {
     }
     return true;
   }
+
   return PyLong_Check(example_value.ptr()) ||
       PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) ||
       PyUnicode_Check(example_value.ptr()) ||
-      THPVariable_Check(example_value.ptr());
+      (is_tensor_immutable && THPVariable_Check(example_value.ptr()));
 }
 
 bool is_parameter(py::handle tensor) {
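To make the C++ change easier to follow, here is a rough Python rendering of the updated `is_immutable_object` logic (an illustrative mirror only, not code from the commit; the function name is hypothetical): the flag is now read from `torch._dynamo.config` on every call rather than cached in a `static py::object`, and tensors only count as immutable while the flag is set.

import torch

def is_immutable_object_sketch(example_value) -> bool:
    # Re-read the flag each time, mirroring the removal of the `static` cache.
    flag = torch._dynamo.config.skip_tensor_guards_with_matching_dict_tags

    # Tuples are immutable only if every element is.
    if isinstance(example_value, tuple):
        return all(is_immutable_object_sketch(v) for v in example_value)

    # Plain scalars and strings are always immutable.
    if isinstance(example_value, (int, float, bool, str)):
        return True

    # Tensors are treated as immutable only when the config flag is enabled.
    return flag and isinstance(example_value, torch.Tensor)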
