8000 [Dynamo] Fix guards for script_if_tracing or lru_cache fn with defaul… · pytorch/pytorch@5a0a964 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5a0a964

Browse files
yanboliangpytorchmergebot
authored andcommitted
[Dynamo] Fix guards for script_if_tracing or lru_cache fn with default args (#120390)
Fixes #120387 Pull Request resolved: #120390 Approved by: https://github.com/anijain2305
1 parent 55b5908 commit 5a0a964

File tree

3 files changed

+40
-5
lines changed

3 files changed

+40
-5
lines changed

test/dynamo/test_functions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,16 @@ def inline_unused(x):
9191
return x + 5.6
9292

9393

94+
@functools.lru_cache
95+
def inline_lru_cache_fn_with_default_args(x, y, _=None):
96+
return torch.sin(x * y)
97+
98+
99+
@torch.jit.script_if_tracing
100+
def inline_script_if_tracing_fn_with_default_args(x, y, _=None):
101+
return torch.cos(x * y)
102+
103+
94104
class FunctionTests(torch._dynamo.test_case.TestCase):
95105
@make_test
96106
def test_inline_jit_annotations(x):
@@ -99,6 +109,14 @@ def test_inline_jit_annotations(x):
99109
x = inline_unused(x)
100110
return
101111

112+
@make_test
113+
def test_inline_script_if_tracing_fn_with_default_args(a, b):
114+
return inline_script_if_tracing_fn_with_default_args(a, 2, b)
115+
116+
@make_test
117+
def test_inline_lru_cache_fn_with_default_args(a, b):
118+
return inline_lru_cache_fn_with_default_args(a, 2, b)
119+
102120
@make_test
103121
def test_add(a, b):
104122
return a + b

torch/_dynamo/utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,14 +555,27 @@ def is_function(value):
555555

556556

557557
def unwrap_if_wrapper(fn):
558+
return unwrap_with_attr_name_if_wrapper(fn)[0]
559+
560+
561+
def unwrap_with_attr_name_if_wrapper(fn):
562+
# unpack @functools.lru_cache wrapped function
558563
if isinstance(fn, functools._lru_cache_wrapper):
559564
fn = inspect.getattr_static(fn, "__wrapped__")
565+
attr_name = "__wrapped__"
560566
# unpack @torch._dynamo.optimize()(fn) wrapped function
561-
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
567+
elif is_function(fn) and inspect.getattr_static(fn, "_torchdynamo_inline", False):
568+
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
569+
attr_name = "_torchdynamo_inline"
562570
# unpack torch.jit.script_if_tracing
563-
if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
571+
elif is_function(fn) and inspect.getattr_static(
572+
fn, "__script_if_tracing_wrapper", False
573+
):
564574
fn = inspect.getattr_static(fn, "__original_fn", fn)
565-
return fn
575+
attr_name = "__original_fn"
576+
else:
577+
attr_name = None
578+
return fn, attr_name
566579

567580

568581
def is_numpy_ndarray(value):

torch/_dynamo/variables/builder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
tuple_iterator,
7878
tuple_iterator_getitem,
7979
tuple_iterator_len,
80-
unwrap_if_wrapper,
80+
unwrap_with_attr_name_if_wrapper,
8181
wrap_fake_exception,
8282
)
8383

@@ -709,7 +709,11 @@ def build_key_value(i, k, v):
709709
self.install_guards(GuardBuilder.FUNCTION_MATCH)
710710
return TorchCtxManagerClassVariable(value, source=self.source)
711711
elif is_function_or_wrapper(value):
712-
value = unwrap_if_wrapper(value)
712+
value, attr_name = unwrap_with_attr_name_if_wrapper(value)
713+
# For these wrappers, Dynamo points to the wrapped function,
714+
# so source needs to be updated as well.
715+
if attr_name is not None:
716+
self.source = AttrSource(self.source, attr_name)
713717
return trace_rules.lookup(value).create_with_source(
714718
value, source=self.source
715719
)

0 commit comments

Comments
 (0)
0