From 9cc382e06b3285af62b6085b4d5ab862326121ef Mon Sep 17 00:00:00 2001 From: eellison Date: Mon, 12 May 2025 13:37:52 -0700 Subject: [PATCH] defer to aot eager instead of skip frame [ghstack-poisoned] --- torch/__init__.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index 475dd80d7ef424..501aff7ceaa565 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2290,7 +2290,7 @@ def compiled_with_cxx11_abi() -> builtins.bool: from torch.utils.dlpack import from_dlpack, to_dlpack -def skip_frame_if_max_graphs() -> None: +def has_hit_max_graphs() -> bool: """ If we have hit a user specified max number of graphs, skip this frame. """ @@ -2301,8 +2301,8 @@ def skip_frame_if_max_graphs() -> None: return GraphsCompiledState.increment() - if GraphsCompiledState.get_num_graphs() > builtins.int(max_graphs): - raise torch._dynamo.exc.SkipFrame(f"Hit max graph limit: {max_graphs}") + return GraphsCompiledState.get_num_graphs() > builtins.int(max_graphs) + # raise torch._dynamo.exc.SkipFrame(f"Hit max graph limit: {max_graphs}") class _TorchCompileInductorWrapper: @@ -2377,7 +2377,9 @@ def apply_options(self, options: _Optional[dict[str, _Any]]): def __call__(self, model_, inputs_): from torch._inductor.compile_fx import compile_fx - skip_frame_if_max_graphs() + if has_hit_max_graphs(): + return _TorchCompileWrapper("aot_eager", "default", {}, self.dynamic).__call__(model_, inputs_) + return compile_fx(model_, inputs_, config_patches=self.config) def get_compiler_config(self): @@ -2423,7 +2425,6 @@ def __eq__(self, other): ) def __call__(self, model_, inputs_): - skip_frame_if_max_graphs() return self.compiler_fn(model_, inputs_, **self.kwargs) def reset(self):