[Dynamo] added warning message for tracing lru_cache wrapped functions by Sidharth123-cpu · Pull Request #153744 · pytorch/pytorch
[Dynamo] added warning message for tracing lru_cache wrapped functions #153744

Closed
22 changes: 22 additions & 0 deletions test/dynamo/test_functions.py
@@ -151,6 +151,28 @@ def test_inline_script_if_tracing_fn_with_default_args(a, b):
def test_inline_lru_cache_fn_with_default_args(a, b):
return inline_lru_cache_fn_with_default_args(a, 2, b)

def test_lru_cache_warning_issued_during_tracing(self):
import warnings
from functools import lru_cache

@lru_cache
def foo(x):
return x + 1

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
torch.compile(foo, backend="eager")(torch.randn(4))

for warning in w:
warning_message = str(warning.message)
if (
"Dynamo detected a call to a `functools.lru_cache` wrapped function"
in warning_message
):
break
else:
self.assertTrue(False, "Expected warning about lru_cache not found")

@make_test
def test_add(a, b):
return a + b
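For context, a minimal sketch of how the new warning surfaces for a user compiling an `lru_cache`-wrapped function; `bar` is a hypothetical example function (not from this diff), and the standard `warnings` filter shown is one way to silence the message if it is not wanted.

import warnings
from functools import lru_cache

import torch


@lru_cache
def bar(x):
    return x * 2


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.compile(bar, backend="eager")(torch.randn(4))

# The message is emitted as a plain UserWarning (no category is passed to warnings.warn).
for w in caught:
    print(w.category.__name__, w.message)

# Suppressing it with a standard filter keyed on the message prefix:
warnings.filterwarnings(
    "ignore",
    message="Dynamo detected a call to a `functools.lru_cache`",
)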
24 changes: 24 additions & 0 deletions test/dynamo/test_repros.py
@@ -4717,6 +4717,29 @@ def foo(a):
):
f_compiled(a)

# https://github.com/pytorch/pytorch/issues/146598
@unittest.expectedFailure
def test_lru_cache_tracing(self):
from functools import lru_cache

counter = 0

@lru_cache
def cached_fn(x):
nonlocal counter
counter += 1
return x + 1

compiled_fn = torch.compile(cached_fn, backend="eager")

t = torch.randn(2, 2)
result1 = compiled_fn(t)
self.assertEqual(counter, 1)

result2 = compiled_fn(t)
self.assertEqual(counter, 1)
self.assertEqual(result1, result2)

def test_dont_aggressively_write_assert(self):
record_graph = torch._dynamo.testing.EagerAndRecordGraphs()

@@ -5431,6 +5454,7 @@ def forward(self, x):
mod = Mod()
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
x = torch.randn(4)

self.assertEqual(mod(x), opt_mod(x))

def test_enum(self):
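For context on the expectedFailure test above: in eager mode `functools.lru_cache` keys on the tensor object itself, so repeated calls with the same tensor hit the cache and the wrapped body runs only once; under `torch.compile`, Dynamo currently ignores the wrapper and replays the traced body, so the counter keeps incrementing. A minimal eager-mode baseline, using a hypothetical module-level `calls` counter for illustration:

from functools import lru_cache

import torch

calls = 0


@lru_cache
def cached_fn(x):
    global calls
    calls += 1
    return x + 1


t = torch.randn(2, 2)
cached_fn(t)
cached_fn(t)  # same tensor object -> cache hit, body does not run again
print(calls)  # 1 in eager mode; the compiled path currently re-runs the body,
              # which is why the test above is marked expectedFailure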
8 changes: 8 additions & 0 deletions torch/_dynamo/variables/functions.py
@@ -29,6 +29,7 @@
import itertools
import sys
import types
import warnings
from collections.abc import Sequence
from types import FunctionType
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
@@ -445,6 +446,7 @@ def call_function(
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
# Handle patch_dynamo_config call

if self.fn is torch._dynamo.patch_dynamo_config:
try:
args_const = [arg.as_python_constant() for arg in args]
@@ -1534,6 +1536,12 @@ def call_function(
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if hasattr(self.wrapper_obj, "cache_info"):
warnings.warn(
"Dynamo detected a call to a `functools.lru_cache` wrapped function."
"Dynamo currently ignores `functools.lru_cache` and directly traces the wrapped function."
"`functools.lru_cache` wrapped functions that read outside state may not be traced soundly."
)
return variables.UserFunctionVariable(
polyfills.getattr_and_trace
).call_function(
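To illustrate the "read outside state" caveat in the new warning text, a minimal sketch (not taken from this PR) of how the cache and the traced body can disagree; `scale` and `scaled` are hypothetical names:

from functools import lru_cache

import torch

scale = 1.0


@lru_cache
def scaled(x):
    # Reads module-level state; the cached result can go stale.
    return x * scale


t = torch.ones(3)
scaled(t)    # eager: computes with scale == 1.0 and caches the result

scale = 2.0
scaled(t)    # eager: cache hit, still returns the scale == 1.0 result

compiled = torch.compile(scaled, backend="eager")
compiled(t)  # emits the new warning; Dynamo ignores the cache and may
             # trace the body with the current scale == 2.0 instead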