diff --git a/torch/__init__.py b/torch/__init__.py
index 475dd80d7ef42..501aff7ceaa56 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -2290,7 +2290,7 @@ def compiled_with_cxx11_abi() -> builtins.bool:
 from torch.utils.dlpack import from_dlpack, to_dlpack
 
 
-def skip_frame_if_max_graphs() -> None:
+def has_hit_max_graphs() -> bool:
     """
     If we have hit a user specified max number of graphs, skip this frame.
     """
@@ -2301,8 +2301,8 @@ def skip_frame_if_max_graphs() -> None:
         return
 
     GraphsCompiledState.increment()
-    if GraphsCompiledState.get_num_graphs() > builtins.int(max_graphs):
-        raise torch._dynamo.exc.SkipFrame(f"Hit max graph limit: {max_graphs}")
+    return GraphsCompiledState.get_num_graphs() > builtins.int(max_graphs)
+    # raise torch._dynamo.exc.SkipFrame(f"Hit max graph limit: {max_graphs}")
 
 
 class _TorchCompileInductorWrapper:
@@ -2377,7 +2377,9 @@ def apply_options(self, options: _Optional[dict[str, _Any]]):
     def __call__(self, model_, inputs_):
         from torch._inductor.compile_fx import compile_fx
 
-        skip_frame_if_max_graphs()
+        if has_hit_max_graphs():
+            return _TorchCompileWrapper("aot_eager", "default", {}, self.dynamic).__call__(model_, inputs_)
+
         return compile_fx(model_, inputs_, config_patches=self.config)
 
     def get_compiler_config(self):
@@ -2423,7 +2425,6 @@ def __eq__(self, other):
         )
 
     def __call__(self, model_, inputs_):
-        skip_frame_if_max_graphs()
         return self.compiler_fn(model_, inputs_, **self.kwargs)
 
     def reset(self):
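
The diff replaces the old "raise `SkipFrame` when the graph budget is exhausted" behaviour with a boolean check, so the Inductor wrapper can fall back to a cheaper backend instead of abandoning compilation of the frame. Below is a minimal, self-contained sketch of that count-then-fall-back pattern, written under stated assumptions: `GraphCounter`, `make_backend`, the toy backends, and the explicit `max_graphs` parameter are illustrative stand-ins and not torch internals; only `has_hit_max_graphs`, `GraphsCompiledState`, and the `aot_eager` fallback come from the diff itself.

```python
# Sketch of the "fall back after N graphs" pattern introduced by the diff.
# All names defined here are hypothetical; this is not PyTorch's implementation.
from typing import Any, Callable, Optional

Backend = Callable[[Any, Any], Any]


class GraphCounter:
    """Counts graphs compiled in this process (simplified stand-in for the
    role GraphsCompiledState plays in the diff)."""

    def __init__(self) -> None:
        self._num_graphs = 0

    def increment(self) -> None:
        self._num_graphs += 1

    def get_num_graphs(self) -> int:
        return self._num_graphs


def has_hit_max_graphs(counter: GraphCounter, max_graphs: Optional[int]) -> bool:
    """Return True once the per-process graph budget is exhausted.

    Unlike the old skip_frame_if_max_graphs(), nothing is raised here;
    the caller decides how to react to the answer.
    """
    if max_graphs is None:
        return False
    counter.increment()
    return counter.get_num_graphs() > max_graphs


def make_backend(
    expensive_backend: Backend,
    cheap_backend: Backend,
    counter: GraphCounter,
    max_graphs: Optional[int],
) -> Backend:
    """Use the expensive backend until the budget is hit, then route every
    further graph to the cheap fallback (mirroring the Inductor wrapper's
    new fallback to an aot_eager-style backend)."""

    def backend(gm: Any, example_inputs: Any) -> Any:
        if has_hit_max_graphs(counter, max_graphs):
            return cheap_backend(gm, example_inputs)
        return expensive_backend(gm, example_inputs)

    return backend


if __name__ == "__main__":
    # Toy backends that only record which path each "graph" took.
    log: list[str] = []

    def expensive(gm: Any, example_inputs: Any) -> str:
        log.append("inductor-like")
        return "expensive"

    def cheap(gm: Any, example_inputs: Any) -> str:
        log.append("aot_eager-like")
        return "cheap"

    compile_graph = make_backend(expensive, cheap, GraphCounter(), max_graphs=2)
    for _ in range(4):
        compile_graph(None, None)

    print(log)  # ['inductor-like', 'inductor-like', 'aot_eager-like', 'aot_eager-like']
```

The trade-off the diff makes is visible in the sketch: raising `SkipFrame` would have left frames past the limit to run eagerly and uncompiled, whereas returning a bool lets the Inductor wrapper keep compiling every subsequent graph with the cheaper backend. Note also that the plain `_TorchCompileWrapper.__call__` path no longer consults the limit at all.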