[dynamo] No eager code wrapping on TORCHDYNAMO_DISABLE · pytorch/pytorch@5073493 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5073493

Browse files
committed
[dynamo] No eager code wrapping on TORCHDYNAMO_DISABLE
ghstack-source-id: e4ecac2 Pull Request resolved: #148618
1 parent 098494e commit 5073493

File tree

5 files changed

+173
-12
lines changed

5 files changed

+173
-12
lines changed

test/dynamo/test_decorators.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,6 +1459,67 @@ def fail_backend(gm, ex):
14591459
with torch.compiler.set_stance("default", force_backend=fail_backend):
14601460
f(torch.randn(3, 3))
14611461

1462+
@patch.dict(os.environ, {"TORCHDYNAMO_DISABLE": "1"})
1463+
def test_compiler_disabled_stack_trace(self):
1464+
class MyException(Exception):
1465+
pass
1466+
1467+
# 1. As decorators
1468+
with self.assertRaises(MyException):
1469+
1470+
@torch._dynamo.run
1471+
def fn():
1472+
raise MyException
1473+
1474+
fn()
1475+
1476+
with self.assertRaises(MyException):
1477+
1478+
@torch._dynamo.disable
1479+
def fn():
1480+
raise MyException
1481+
1482+
fn()
1483+
1484+
with self.assertRaises(MyException):
1485+
1486+
@torch.compile(backend="eager")
1487+
def fn():
1488+
raise MyException
1489+
1490+
fn()
1491+
1492+
# 2. As wrappers
1493+
def create_fn():
1494+
def fn():
1495+
raise MyException
1496+
1497+
return fn
1498+
1499+
with self.assertRaises(MyException):
1500+
torch._dynamo.run(create_fn())()
1501+
1502+
with self.assertRaises(MyException):
1503+
torch._dynamo.disable(create_fn())()
1504+
1505+
with self.assertRaises(MyException):
1506+
torch.compile(create_fn(), backend="eager")()
1507+
1508+
# 3. As deferred wrappers
1509+
with self.assertRaises(MyException):
1510+
torch._dynamo.run()(create_fn())()
1511+
1512+
with self.assertRaises(MyException):
1513+
torch._dynamo.disable()(create_fn())()
1514+
1515+
with self.assertRaises(MyException):
1516+
torch.compile(backend="eager")(create_fn())()
1517+
1518+
# Verify we didn't accidentally wrap user code
1519+
self.assertEqual(counters["eval_frame"]["RunOnlyContext"], 0)
1520+
self.assertEqual(counters["eval_frame"]["DisableContext"], 0)
1521+
self.assertEqual(counters["eval_frame"]["OptimizeContext"], 0)
1522+
14621523

14631524
if __name__ == "__main__":
14641525
from torch._dynamo.test_case import run_tests

test/dynamo/test_misc.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4009,6 +4009,38 @@ def fn3(x):
40094009
self.assertEqual(cnts3.frame_count, 1)
40104010
self.assertEqual(cnts3.op_count, 4)
40114011

4012+
@patch.dict(os.environ, {"TORCHDYNAMO_DISABLE": "1"})
4013+
def test_nested_optimize_decorator_disabled(self):
4014+
class MyException(Exception):
4015+
pass
4016+
4017+
cnts2 = torch._dynamo.testing.CompileCounter()
4018+
cnts3 = torch._dynamo.testing.CompileCounter()
4019+
4020+
@torch._dynamo.run()
4021+
def fn1(x):
4022+
raise MyException
4023+
4024+
@torch.compile(backend=cnts2, fullgraph=True)
4025+
def fn2(x):
4026+
return fn1(x) + 1
4027+
4028+
@torch.compile(backend=cnts3, fullgraph=True)
4029+
def fn3(x):
4030+
return torch.relu(fn2(x))
4031+
4032+
try:
4033+
fn3(torch.randn(4, 5))
4034+
except MyException as e:
4035+
self.assertNotIn("eval_frame", traceback.format_exc())
4036+
4037+
self.assertEqual(cnts2.frame_count, 0)
4038+
self.assertEqual(cnts3.frame_count, 0)
4039+
self.assertEqual(cnts3.op_count, 0)
4040+
self.assertEqual(counters["eval_frame"]["RunOnlyContext"], 0)
4041+
self.assertEqual(counters["eval_frame"]["DisableContext"], 0)
4042+
self.assertEqual(counters["eval_frame"]["OptimizeContext"], 0)
4043+
40124044
def test_nested_optimize_run(self):
40134045
cnts = torch._dynamo.testing.CompileCounter()
40144046

@@ -4027,6 +4059,29 @@ def fn(x):
40274059
fn(torch.randn(4, 4, 4))
40284060
self.assertEqual(cnts.frame_count, 2)
40294061

4062+
@patch.dict(os.environ, {"TORCHDYNAMO_DISABLE": "1"})
4063+
def test_nested_optimize_run_disabled(self):
4064+
cnts = torch._dynamo.testing.CompileCounter()
4065+
4066+
@torch.compile(backend=cnts, fullgraph=True)
4067+
def fn(x):
4068+
return torch.relu(torch.cos(x) + torch.sin(x))
4069+
4070+
fn(torch.randn(4))
4071+
self.assertEqual(cnts.frame_count, 0)
4072+
4073+
fn(torch.randn(4, 4))
4074+
self.assertEqual(cnts.frame_count, 0)
4075+
4076+
# Test that run works on a decorated fn
4077+
fn = torch._dynamo.run(fn)
4078+
fn(torch.randn(4, 4, 4))
4079+
self.assertEqual(cnts.frame_count, 0)
4080+
4081+
self.assertEqual(counters["eval_frame"]["RunOnlyContext"], 0)
4082+
self.assertEqual(counters["eval_frame"]["DisableContext"], 0)
4083+
self.assertEqual(counters["eval_frame"]["OptimizeContext"], 0)
4084+
40304085
def test_nested_optimize(self):
40314086
cnts1 = torch._dynamo.testing.CompileCounter()
40324087
cnts2 = torch._dynamo.testing.CompileCounter()
@@ -4058,6 +4113,42 @@ def fn(x):
40584113
torch._dynamo.run()(fn2)(torch.randn(4))
40594114
self.assertEqual(cnts2.frame_count, 0)
40604115

4116+
@patch.dict(os.environ, {"TORCHDYNAMO_DISABLE": "1"})
4117+
def test_nested_optimize_disabled(self):
4118+
cnts1 = torch._dynamo.testing.CompileCounter()
4119+
cnts2 = torch._dynamo.testing.CompileCounter()
4120+
4121+
def fn(x):
4122+
return torch.relu(torch.cos(x) + torch.sin(x))
4123+
4124+
fn1 = torch.compile(fn, backend=cnts1, fullgraph=True)
4125+
fn2 = torch.compile(fn1, backend=cnts2, fullgraph=True)
4126+
4127+
# The first optimize in the nesting should be ignored
4128+
fn2(torch.randn(4))
4129+
self.assertEqual(cnts2.frame_count, 0)
4130+
self.assertEqual(cnts1.frame_count, 0)
4131+
4132+
# Since the fn code object is already compiled, calling fn1 should
4133+
# directly call the compiled_fn callable.
4134+
torch._dynamo.run()(fn1)(torch.randn(4))
4135+
self.assertEqual(cnts1.frame_count, 0)
4136+
4137+
# Test same behavior by reversing the calls
4138+
torch._dynamo.reset()
4139+
cnts1 = torch._dynamo.testing.CompileCounter()
4140+
cnts2 = torch._dynamo.testing.CompileCounter()
4141+
fn1 = torch.compile(fn, backend=cnts1, fullgraph=True)
4142+
fn2 = torch.compile(fn1, backend=cnts2, fullgraph=True)
4143+
fn1(torch.randn(4))
4144+
self.assertEqual(cnts1.frame_count, 0)
4145+
torch._dynamo.run()(fn2)(torch.randn(4))
4146+
self.assertEqual(cnts2.frame_count, 0)
4147+
4148+
self.assertEqual(counters["eval_frame"]["RunOnlyContext"], 0)
4149+
self.assertEqual(counters["eval_frame"]["DisableContext"], 0)
4150+
self.assertEqual(counters["eval_frame"]["OptimizeContext"], 0)
4151+
40614152
def test_torch_size(self):
40624153
cnts = torch._dynamo.testing.CompileCounter()
40634154

torch/_dynamo/decorators.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from . import trace_rules, variables
2323
from .comptime import comptime
2424
from .eval_frame import (
25+
_NullDecorator,
2526
_set_stance,
2627
DisableContext,
2728
DynamoStance,
@@ -31,7 +32,7 @@
3132
)
3233
from .exc import IncorrectUsage
3334
from .external_utils import is_compiling
34-
from .utils import is_function
35+
from .utils import is_compiler_disabled, is_function
3536

3637

3738
if TYPE_CHECKING:
@@ -58,11 +59,12 @@
5859

5960
def run(fn=None):
6061
"""Don't do any dynamic compiles, just use prior optimizations"""
62+
ctx = _NullDecorator() if is_compiler_disabled() else RunOnlyContext()
6163
if fn is not None:
6264
fn = innermost_fn(fn)
6365
assert callable(fn)
64-
return RunOnlyContext()(fn)
65-
return RunOnlyContext()
66+
return ctx(fn)
67+
return ctx
6668

6769

6870
def disable(fn=None, recursive=True):
@@ -76,11 +78,12 @@ def disable(fn=None, recursive=True):
7678
but still process recursively invoked frames.
7779
"""
7880
if recursive:
81+
ctx = _NullDecorator() if is_compiler_disabled() else DisableContext()
7982
if fn is not None:
8083
fn = innermost_fn(fn)
8184
assert callable(fn)
82-
return DisableContext()(fn)
83-
return DisableContext()
85+
return ctx(fn)
86+
return ctx
8487
else:
8588
return skip(fn)
8689

torch/_dynamo/eval_frame.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import functools
3131
import inspect
3232
import logging
33-
import os
3433
import sys
3534
import sysconfig
3635
import textwrap
@@ -97,7 +96,7 @@
9796
)
9897
from .hooks import Hooks
9998
from .mutation_guard import install_generation_tagging_init
100-
from .utils import common_constant_types, compile_times
99+
from .utils import common_constant_types, compile_times, counters, is_compiler_disabled
101100

102101

103102
if TYPE_CHECKING:
@@ -742,6 +741,8 @@ def __init__(
742741
Callable[[], Union[OptimizeContext, _NullDecorator]]
743742
] = None,
744743
) -> None:
744+
counters["eval_frame"]["OptimizeContext"] += 1
745+
745746
def on_enter():
746747
install_generation_tagging_init()
747748

@@ -783,6 +784,8 @@ def __reduce__(self):
783784

784785
class RunOnlyContext(_TorchDynamoContext):
785786
def __init__(self) -> None:
787+
counters["eval_frame"]["RunOnlyContext"] += 1
788+
786789
# cudagraph trees relies on generation increment
787790
def on_enter():
788791
torch._dynamo.mutation_guard.GenerationTracker.generation += 1
@@ -795,6 +798,7 @@ def __reduce__(self):
795798

796799
class DisableContext(_TorchDynamoContext):
797800
def __init__(self) -> None:
801+
counters["eval_frame"]["DisableContext"] += 1
798802
super().__init__(callback=None)
799803

800804
def __call__(self, fn):
@@ -988,11 +992,7 @@ def toy_example(a, b): ...
988992
# easier to understand UX at the cost of a little more plumbing on our end.
989993
hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn)
990994
torch._C._log_api_usage_once("torch._dynamo.optimize")
991-
if (
992-
disable
993-
or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1"
994-
or (not justknobs_check("pytorch/compiler:enable_dynamo"))
995-
):
995+
if disable or is_compiler_disabled():
996996
return _NullDecorator()
997997

998998
backend = get_compiler_fn(backend)

torch/_dynamo/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4471,3 +4471,9 @@ def get_optimize_ddp_mode():
44714471
f"Invalid dynamo config optimize_ddp value {mode=}"
44724472
)
44734473
return mode
4474+
4475+
4476+
def is_compiler_disabled():
4477+
return os.environ.get("TORCHDYNAMO_DISABLE", "") == "1" or (
4478+
not justknobs_check("pytorch/compiler:enable_dynamo")
4479+
)

0 commit comments

Comments (0)