10000 defer to aot eager instead of skip frame · pytorch/pytorch@fb3956c · GitHub
[go: up one dir, main page]

Skip to content

Commit fb3956c

Browse files
committed
defer to aot eager instead of skip frame
ghstack-source-id: ce98f2d Pull Request resolved: #153409
1 parent 99acb5e commit fb3956c

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

torch/__init__.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -2290,7 +2290,7 @@ def compiled_with_cxx11_abi() -> builtins.bool:
22902290
from torch.utils.dlpack import from_dlpack, to_dlpack
22912291

22922292

2293-
def skip_frame_if_max_graphs() -> None:
2293+
def has_hit_max_graphs() -> bool:
22942294
"""
22952295
If we have hit a user specified max number of graphs, skip this frame.
22962296
"""
@@ -2301,8 +2301,8 @@ def skip_frame_if_max_graphs() -> None:
23012301
return
23022302

23032303
GraphsCompiledState.increment()
2304-
if GraphsCompiledState.get_num_graphs() > builtins.int(max_graphs):
2305-
raise torch._dynamo.exc.SkipFrame(f"Hit max graph limit: {max_graphs}")
2304+
return GraphsCompiledState.get_num_graphs() > builtins.int(max_graphs)
2305+
# raise torch._dynamo.exc.SkipFrame(f"Hit max graph limit: {max_graphs}")
23062306

23072307

23082308
class _TorchCompileInductorWrapper:
@@ -2377,7 +2377,9 @@ def apply_options(self, options: _Optional[dict[str, _Any]]):
23772377
def __call__(self, model_, inputs_):
23782378
from torch._inductor.compile_fx import compile_fx
23792379

2380-
skip_frame_if_max_graphs()
2380+
if has_hit_max_graphs():
2381+
return _TorchCompileWrapper("aot_eager", "default", {}, self.dynamic).__call__(model_, inputs_)
2382+
23812383
return compile_fx(model_, inputs_, config_patches=self.config)
23822384

23832385
def get_compiler_config(self):
@@ -2423,7 +2425,6 @@ def __eq__(self, other):
24232425
)
24242426

24252427
def __call__(self, model_, inputs_):
2426-
skip_frame_if_max_graphs()
24272428
return self.compiler_fn(model_, inputs_, **self.kwargs)
24282429

24292430
def reset(self):

0 commit comments

Comments
 (0)
0