8000 [dynamo] handle 3.13.0 __dict__ watcher bug (#138284) · rahulsingh-intel/pytorch@32ddf67 · GitHub
[go: up one dir, main page]

Skip to content

Commit 32ddf67

Browse files
williamwen42rahulsingh-intel
authored andcommitted
[dynamo] handle 3.13.0 __dict__ watcher bug (pytorch#138284)
python/cpython#116115 introduced a bug (python/cpython#125608) where changing the attributes of an object may not fire the dict watchers registered to the object's `__dict__`. It has been fixed by python/cpython#125611 but will only be in 3.13.1+. This PR disables the dict watcher guard shortcut for `__dict__`s on 3.13.0 and warns the user to try using 3.13.1+ instead. We also added a simple test to check for this functionality in the future. Pull Request resolved: pytorch#138284 Approved by: https://github.com/jansel ghstack dependencies: pytorch#138030
1 parent b6d0cdb commit 32ddf67

File tree

3 files changed

+50
-9
lines changed

3 files changed

+50
-9
lines changed

test/dynamo/test_modules.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,6 +1261,26 @@ def test_self_mutating1(self):
12611261
else:
12621262
self.assertExpectedInline(cnt.frame_count, """1""")
12631263

1264+
def test_nn_module_setattr(self):
1265+
class Mod(torch.nn.Module):
1266+
def __init__(self):
1267+
super().__init__()
1268+
self.var = 0
1269+
1270+
@torch.compile(backend="eager", dynamic=False)
1271+
def f(x, m):
1272+
return x + m.var
1273+
1274+
inp = torch.ones(3)
1275+
m = Mod()
1276+
1277+
self.assertEqual(f(inp, m), inp)
1278+
# In 3.13.0, setattr will not fire a __dict__'s watchers,
1279+
# so guards may not be invalidated.
1280+
m.var = 1
1281+
# should trigger a recompile
1282+
self.assertEqual(f(inp, m), inp + 1)
1283+
12641284
@patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
12651285
def test_generation_tag(self):
12661286
cnt = torch._dynamo.testing.CompileCounter()

torch/_dynamo/guards.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import textwrap
1919
import time
2020
import types
21+
import warnings
2122
import weakref
2223
from contextlib import contextmanager
2324
from copy import deepcopy
@@ -651,6 +652,20 @@ def guard_on_dict_keys_and_order(self, value, guard):
651652
key, get_verbose_code_parts(f"{key_source} == {key!r}", guard)
652653
)
653654

655+
@staticmethod
656+
def _get_generic_dict_manager_example_value(example_value):
657+
# due to a bug in 3.13.0 (introduced by https://github.com/python/cpython/pull/116115,
658+
# reported in https://github.com/python/cpython/issues/125608,
659+
# fixed by https://github.com/python/cpython/pull/125611), we cannot take
660+
# advantage of __dict__ versions to speed up guard checks.
661+
if sys.version_info >= (3, 13) and sys.version_info < (3, 13, 1):
662+
warnings.warn(
663+
"Guards may run slower on Python 3.13.0. Consider upgrading to Python 3.13.1+.",
664+
RuntimeWarning,
665+
)
666+
return None
667+
return example_value
668+
654669
def getattr_on_nn_module(
655670
self,
656671
source,
@@ -776,7 +791,7 @@ def getitem_on_dict_mgr(
776791
# Guard Manager
777792
mod_generic_dict_manager = base_guard_manager.get_generic_dict_manager(
778793
source=mod_dict_source,
779-
example_value=mod_dict,
794+
example_value=self._get_generic_dict_manager_example_value(mod_dict),
780795
guard_manager_enum=GuardManagerType.GUARD_MANAGER,
781796
)
782797

@@ -1271,7 +1286,7 @@ def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None:
12711286
mod_dict_source = f"{guard.name}.__dict__"
12721287
mod_generic_dict_manager = base_manager.get_generic_dict_manager(
12731288
source=mod_dict_source,
1274-
example_value=val.__dict__,
1289+
example_value=self._get_generic_dict_manager_example_value(val.__dict__),
12751290
guard_manager_enum=GuardManagerType.GUARD_MANAGER,
12761291
)
12771292

@@ -2261,12 +2276,16 @@ def add_code_part(code_part, guard, log_only=False):
22612276
structured_guard_fns.append(
22622277
lambda: {
22632278
"code": code_part,
2264-
"stack": structured.from_traceback(guard.stack.summary())
2265-
if guard.stack
2266-
else None,
2267-
"user_stack": structured.from_traceback(guard.user_stack)
2268-
if guard.user_stack
2269-
else None,
2279+
"stack": (
2280+
structured.from_traceback(guard.stack.summary())
2281+
if guard.stack
2282+
else None
2283+
),
2284+
"user_stack": (
2285+
structured.from_traceback(guard.user_stack)
2286+
if guard.user_stack
2287+
else None
2288+
),
22702289
}
22712290
)
22722291

torch/csrc/dynamo/guards.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include <torch/csrc/utils/pythoncapi_compat.h>
1818
#include <torch/extension.h>
1919

20+
#include <torch/csrc/dynamo/debug_macros.h>
21+
2022
#ifdef USE_CUDA
2123
#include <ATen/cuda/EmptyTensor.h>
2224
#endif
@@ -655,7 +657,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
655657

656658
static std::unordered_map<PyObject*, uint64_t> dict_version_map;
657659
static int dict_version_watcher_id;
658-
static uint64_t global_dict_version_id = 0;
660+
static uint64_t global_dict_version_id = 1;
659661
static int dict_version_watch_callback(
660662
PyDict_WatchEvent event,
661663
PyObject* dict,

0 commit comments

Comments
 (0)
0