8000 [Dynamo] added warning message for tracing lru_cache wrapped function… · pytorch/pytorch@aac30ef · GitHub
[go: up one dir, main page]

Skip to content

Commit aac30ef

Browse files
Sidharth123-cpupytorchmergebot
authored andcommitted
[Dynamo] added warning message for tracing lru_cache wrapped functions (#153744)
Pull Request resolved: #153744 Approved by: https://github.com/williamwen42
1 parent e88c4db commit aac30ef

File tree

3 files changed

+54
-0
lines changed

3 files changed

+54
-0
lines changed

test/dynamo/test_functions.py

+22
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,28 @@ def test_inline_script_if_tracing_fn_with_default_args(a, b):
151151
def test_inline_lru_cache_fn_with_default_args(a, b):
152152
return inline_lru_cache_fn_with_default_args(a, 2, b)
153153

154+
def test_lru_cache_warning_issued_during_tracing(self):
155+
import warnings
156+
from functools import lru_cache
157+
158+
@lru_cache
159+
def foo(x):
160+
return x + 1
161+
162+
with warnings.catch_warnings(record=True) as w:
163+
warnings.simplefilter("always")
164+
torch.compile(foo, backend="eager")(torch.randn(4))
165+
166+
for warning in w:
167+
warning_message = str(warning.message)
168+
if (
169+
"Dynamo detected 10000 a call to a `functools.lru_cache` wrapped function"
170+
in warning_message
171+
):
172+
break
173+
else:
174+
self.assertTrue(False, "Expected warning about lru_cache not found")
175+
154176
@make_test
155177
def test_add(a, b):
156178
return a + b

test/dynamo/test_repros.py

+24
Original file line numberDiff line numberDiff line change
@@ -4717,6 +4717,29 @@ def foo(a):
47174717
):
47184718
f_compiled(a)
47194719

4720+
# https://github.com/pytorch/pytorch/issues/146598
4721+
@unittest.expectedFailure
4722+
def test_lru_cache_tracing(self):
4723+
from functools import lru_cache
4724+
4725+
counter = 0
4726+
4727+
@lru_cache
4728+
def cached_fn(x):
4729+
nonlocal counter
4730+
counter += 1
4731+
return x + 1
4732+
4733+
compiled_fn = torch.compile(cached_fn, backend="eager")
4734+
4735+
t = torch.randn(2, 2)
4736+
result1 = compiled_fn(t)
4737+
self.assertEqual(counter, 1)
4738+
4739+
result2 = compiled_fn(t)
4740+
self.assertEqual(counter, 1)
4741+
self.assertEqual(result1, result2)
4742+
47204743
def test_dont_aggressively_write_assert(self):
47214744
record_graph = torch._dynamo.testing.EagerAndRecordGraphs()
47224745

@@ -5431,6 +5454,7 @@ def forward(self, x):
54315454
mod = Mod()
54325455
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
54335456
x = torch.randn(4)
5457+
54345458
self.assertEqual(mod(x), opt_mod(x))
54355459

54365460
def test_enum(self):

torch/_dynamo/variables/functions.py

+8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import itertools
3030
import sys
3131
import types
32+
import warnings
3233
from collections.abc import Sequence
3334
from types import FunctionType
3435
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
@@ -445,6 +446,7 @@ def call_function(
445446
kwargs: "dict[str, VariableTracker]",
446447
) -> "VariableTracker":
447448
# Handle patch_dynamo_config call
449+
448450
if self.fn is torch._dynamo.patch_dynamo_config:
449451
try:
450452
args_const = [arg.as_python_constant() for arg in args]
@@ -1534,6 +1536,12 @@ def call_function(
15341536
args: "list[VariableTracker]",
15351537
kwargs: "dict[str, VariableTracker]",
15361538
) -> "VariableTracker":
1539+
if hasattr(self.wrapper_obj, "cache_info"):
1540+
warnings.warn(
1541+
"Dynamo detected a call to a `functools.lru_cache` wrapped function."
1542+
"Dynamo currently ignores `functools.lru_cache` and directly traces the wrapped function."
1543+
"`functools.lru_cache` wrapped functions that read outside state may not be traced soundly."
1544+
)
15371545
return variables.UserFunctionVariable(
15381546
polyfills.getattr_and_trace
15391547
).call_function(

0 commit comments

Comments
 (0)
0