@@ -2290,7 +2290,7 @@ def compiled_with_cxx11_abi() -> builtins.bool:
 from torch.utils.dlpack import from_dlpack, to_dlpack


-def skip_frame_if_max_graphs() -> None:
+def has_hit_max_graphs() -> bool:
     """
     If we have hit a user specified max number of graphs, skip this frame.
     """
@@ -2301,8 +2301,8 @@ def skip_frame_if_max_graphs() -> None:
         return

     GraphsCompiledState.increment()
-    if GraphsCompiledState.get_num_graphs() > builtins.int(max_graphs):
-        raise torch._dynamo.exc.SkipFrame(f"Hit max graph limit: {max_graphs}")
+    return GraphsCompiledState.get_num_graphs() > builtins.int(max_graphs)
+    # raise torch._dynamo.exc.SkipFrame(f"Hit max graph limit: {max_graphs}")


 class _TorchCompileInductorWrapper:
@@ -2377,7 +2377,9 @@ def apply_options(self, options: _Optional[dict[str, _Any]]):
     def __call__(self, model_, inputs_):
         from torch._inductor.compile_fx import compile_fx

-        skip_frame_if_max_graphs()
+        if has_hit_max_graphs():
+            return _TorchCompileWrapper("aot_eager", "default", {}, self.dynamic).__call__(model_, inputs_)
+
         return compile_fx(model_, inputs_, config_patches=self.config)

     def get_compiler_config(self):
@@ -2423,7 +2425,6 @@ def __eq__(self, other):
         )

     def __call__(self, model_, inputs_):
-        skip_frame_if_max_graphs()
         return self.compiler_fn(model_, inputs_, **self.kwargs)

     def reset(self):
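
Below is a minimal, self-contained sketch (not part of the commit) of the control-flow pattern this diff introduces: the graph-limit check now returns a bool, and the inductor wrapper falls back to the "aot_eager" path instead of raising SkipFrame. The names GraphLimitState, MAX_GRAPHS, expensive_backend, fallback_backend, and compile_graph are hypothetical stand-ins; only has_hit_max_graphs mirrors the function in the diff.

# Hypothetical standalone illustration of the fallback-instead-of-skip pattern.
MAX_GRAPHS = 2  # hypothetical user-specified limit


class GraphLimitState:
    """Hypothetical counter standing in for GraphsCompiledState in the diff."""
    _num_graphs = 0

    @classmethod
    def increment(cls) -> None:
        cls._num_graphs += 1

    @classmethod
    def get_num_graphs(cls) -> int:
        return cls._num_graphs


def has_hit_max_graphs() -> bool:
    # Count this graph, then report whether the budget is exhausted
    # (instead of raising torch._dynamo.exc.SkipFrame as before).
    GraphLimitState.increment()
    return GraphLimitState.get_num_graphs() > MAX_GRAPHS


def expensive_backend(graph_id: int) -> str:
    return f"inductor({graph_id})"


def fallback_backend(graph_id: int) -> str:
    return f"aot_eager({graph_id})"


def compile_graph(graph_id: int) -> str:
    # Mirrors _TorchCompileInductorWrapper.__call__: over budget -> cheaper backend.
    if has_hit_max_graphs():
        return fallback_backend(graph_id)
    return expensive_backend(graph_id)


if __name__ == "__main__":
    # The first MAX_GRAPHS graphs take the expensive path; later ones fall back.
    print([compile_graph(i) for i in range(4)])
    # -> ['inductor(0)', 'inductor(1)', 'aot_eager(2)', 'aot_eager(3)']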