pytorch/pytorch · Commit 4d5cc1b
Revert "[dynamo][guards] Consider tensors as immutable for dict tag matches (#139560)"
This reverts commit e6ff07f. Reverted #139560 on behalf of https://github.com/ZainRizvi: "Sorry but this seems to be breaking internal tests. Please see D65430317 for more details." ([comment](#139560 (comment)))
1 parent a2bc2e3 commit 4d5cc1b
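For orientation before the diffs: the reverted change had introduced a Dynamo config flag, skip_tensor_guards_with_matching_dict_tags, which the removed test toggled off via unittest.mock.patch.object. A minimal sketch of that toggle pattern, purely illustrative, since the flag is deleted once this revert lands:

from unittest.mock import patch

import torch

# Hypothetical after this revert: the flag only exists while #139560 is in tree.
with patch.object(
    torch._dynamo.config, "skip_tensor_guards_with_matching_dict_tags", False
):
    pass  # compile and run code that mutates tensor state between calls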

File tree: 3 files changed (+1, −59 lines)

test/dynamo/test_modules.py

Lines changed: 0 additions & 48 deletions
@@ -3136,54 +3136,6 @@ def fn(x):
         res = opt_fn(x)
         self.assertEqual(ref, res)
 
-    @patch.object(
-        torch._dynamo.config, "skip_tensor_guards_with_matching_dict_tags", False
-    )
-    def test_param_requires_grad(self):
-        def adjust_model(model):
-            to_freeze = model.num_iter % 2 == 0
-            if to_freeze:
-                for param in model.layer2.parameters():
-                    param.requires_grad = False
-            else:
-                for param in model.layer2.parameters():
-                    param.requires_grad = True
-
-        class MyModule(torch.nn.Module):
-            def __init__(self, input_size, hidden_size, output_size):
-                super().__init__()
-
-                self.layer1 = torch.nn.Linear(hidden_size, hidden_size)
-                self.layer2 = torch.nn.Linear(hidden_size, hidden_size)
-
-                self.num_iter = 0
-
-            def forward(self, x):
-                x = self.layer2(x + self.layer1.bias)
-
-                self.num_iter += 1
-                return x
-
-        input_size = 1024
-        hidden_size = 1024
-        output_size = 1
-        num_samples = 2048
-        features = torch.randn(num_samples, input_size)
-
-        model = MyModule(input_size, hidden_size, output_size)
-
-        cnt = torch._dynamo.testing.CompileCounter()
-        opt_model = torch.compile(model, backend=cnt, fullgraph=True)
-
-        for _ in range(3):
-            model.zero_grad(True)
-            adjust_model(model)
-            res = opt_model(features)
-            res.sum().backward()
-
-        # Check that we have recompiled twice, which leads to 3 frames
-        self.assertEqual(cnt.frame_count, 3)
-
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
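The deleted test exercised recompilation when parameter requires_grad flags flip between calls, counting compiled frames with torch._dynamo.testing.CompileCounter. A condensed, self-contained sketch of that counting pattern (independent of the removed flag; the tiny model and shapes are illustrative, not from the original test):

import torch
import torch._dynamo.testing

model = torch.nn.Linear(8, 8)
cnt = torch._dynamo.testing.CompileCounter()
opt_model = torch.compile(model, backend=cnt, fullgraph=True)

x = torch.randn(4, 8)
for i in range(3):
    # Flipping requires_grad in place changes guard-relevant tensor state.
    for p in model.parameters():
        p.requires_grad = i % 2 == 0
    opt_model(x)

# The removed test asserted cnt.frame_count == 3 after two such flips.
print(cnt.frame_count)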

torch/_dynamo/config.py

Lines changed: 0 additions & 4 deletions
@@ -331,10 +331,6 @@ def _get_optimize_ddp_mode():
 # notice and lead to incorrect result.
 skip_no_tensor_aliasing_guards_on_parameters = True
 
-# Considers a tensor immutable if it is one of the values of a dictionary, and
-# the dictionary tag is same across invocation calls.
-skip_tensor_guards_with_matching_dict_tags = True
-
 # If True, raises exception if TorchDynamo is called with a context manager
 raise_on_ctx_manager_usage = True
 
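Dynamo config entries such as the surviving raise_on_ctx_manager_usage above are attributes on torch._dynamo.config, and tests conventionally override them with patch.object, which is exactly how the deleted test toggled the removed flag. A small sketch of that override mechanic, using a flag that remains after this revert:

from unittest.mock import patch

import torch

# Read the current value of a flag that survives this revert.
print(torch._dynamo.config.raise_on_ctx_manager_usage)

# Temporarily override it, mirroring the deleted test's patch.object pattern.
with patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False):
    print(torch._dynamo.config.raise_on_ctx_manager_usage)  # False inside the patch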

torch/csrc/dynamo/guards.cpp

Lines changed: 1 addition & 7 deletions
@@ -886,11 +886,6 @@ std::string get_exception_message() {
 }
 
 bool is_immutable_object(py::handle example_value) {
-  static py::object config_module = py::module_::import("torch._dynamo.config");
-  bool is_tensor_immutable =
-      config_module.attr("skip_tensor_guards_with_matching_dict_tags")
-          .cast<bool>();
-
   if (PyTuple_Check(example_value.ptr())) {
     // Check that each element is immutable
     for (Py_ssize_t i = 0; i < PyTuple_Size(example_value.ptr()); ++i) {
@@ -901,11 +896,10 @@ bool is_immutable_object(py::handle example_value) {
     }
     return true;
   }
-
   return PyLong_Check(example_value.ptr()) ||
       PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) ||
       PyUnicode_Check(example_value.ptr()) ||
-      (is_tensor_immutable && THPVariable_Check(example_value.ptr()));
+      THPVariable_Check(example_value.ptr());
 }
 
 bool is_parameter(py::handle tensor) {
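For readability, here is a rough Python rendering of the post-revert predicate. This is a sketch of the logic only, not the actual binding: the real check is the C++ above, where THPVariable_Check corresponds to torch.Tensor below.

import torch

def is_immutable_object(value) -> bool:
    # A tuple is immutable only if every element is (mirrors the recursive loop).
    if isinstance(value, tuple):
        return all(is_immutable_object(v) for v in value)
    # Post-revert: tensors are treated as immutable unconditionally, with no
    # torch._dynamo.config lookup gating the tensor case.
    return isinstance(value, (int, float, bool, str, torch.Tensor))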

0 commit comments