File tree 2 files changed +33
-7
lines changed
2 files changed +33
-7
lines changed Original file line number Diff line number Diff line change @@ -721,7 +721,7 @@ def make_run_fn(
721
721
workspace_ptr = c_void_p (self .workspace .data_ptr ())
722
722
723
723
# Generate partial function.
724
- return functools .partial (
724
+ ret = functools .partial (
725
725
run_method ,
726
726
* args ,
727
727
* self .extra_args ,
@@ -730,6 +730,18 @@ def make_run_fn(
730
730
stream_ptr ,
731
731
)
732
732
733
+ # sanity check to make sure we cleanup run fn properly
734
+ try :
735
+ ret ()
736
+ except RuntimeError as e :
737
+ err_msg = str (e )
738
+ def dummy_function ():
739
+ raise RuntimeError (err_msg )
740
+ self .cleanup_run_fn ()
741
+ return dummy_function
742
+
743
+ return ret
744
+
733
745
def update_workspace_size (self ) -> None :
734
746
if self ._workspace_size_updated :
735
747
return
Original file line number Diff line number Diff line change @@ -2265,16 +2265,30 @@ def benchmark_choices(
2265
2265
log .warning ("Not yet implemented: %s" , e )
2266
2266
timing = float ("inf" )
2267
2267
except RuntimeError as e :
2268
+ from torch ._inductor .codegen .cuda .cuda_kernel import CUDATemplateCaller
2269
+
2270
+ if not isinstance (choice , CUDATemplateCaller ):
2271
+ log .error (
2272
+ "CUDA compilation error during autotuning: \n %s. \n Ignoring this choice." ,
2273
+ e ,
2274
+ )
2268
2275
msg = str (e )
2269
2276
if "invalid argument" in msg :
2270
2277
msg += "\n \n This may mean this GPU is too small for max_autotune mode.\n \n "
2278
+ elif "illegal memory access" in msg :
2279
+ msg += "\n \n Either error in template or triton bug.\n "
2280
+
2281
+ if isinstance (choice , CUDATemplateCaller ):
2282
+ log .debug (
2283
+ "Runtime error during autotuning: \n %s. \n Ignoring this choice." ,
2284
+ msg ,
2285
+ exc_info = True ,
2286
+ )
2271
2287
else :
2272
- if "illegal memory access" in msg :
2273
- msg += "\n \n Either error in template or triton bug.\n "
2274
- log .error (
2275
- "Runtime error during autotuning: \n %s. \n Ignoring this choice." ,
2276
- msg ,
2277
- )
2288
+ log .error (
2289
+ "Runtime error during autotuning: \n %s. \n Ignoring this choice." ,
2290
+ msg ,
2291
+ )
2278
2292
timing = float ("inf" )
2279
2293
except AssertionError as e :
2280
2294
raise AssertionError ( # noqa: B904
You can’t perform that action at this time.
0 commit comments