29
29
from torch ._dynamo .device_interface import get_interface_for_device
30
30
from torch ._dynamo .testing import rand_strided
31
31
from torch ._dynamo .utils import counters , dynamo_timed , identity , preserve_rng_state
32
+ from torch ._inductor .codegen .cuda .cuda_kernel import CUDATemplateCaller
32
33
from torch ._inductor .utils import clear_on_fresh_inductor_cache
33
34
from torch .utils ._filelock import FileLock
34
35
from torch .utils ._ordered_set import OrderedSet
@@ -1824,8 +1825,6 @@ def __call__(
1824
1825
precompilation_timeout_seconds : int = 60 * 60 ,
1825
1826
return_multi_template = False ,
1826
1827
):
1827
- from .codegen .cuda .cuda_kernel import CUDATemplateCaller
1828
-
1829
1828
# Templates selected with input_gen_fns require specific input data to avoid IMA
1830
1829
# Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection
1831
1830
# TODO(jgong5): support multi-template on CPU
@@ -2131,10 +2130,6 @@ def wait_on_futures():
2131
2130
timeout = precompilation_timeout_seconds ,
2132
2131
):
2133
2132
if e := future .exception ():
2134
- from torch ._inductor .codegen .cuda .cuda_kernel import (
2135
- CUDATemplateCaller ,
2136
- )
2137
-
2138
2133
if isinstance (e , CUDACompileError ) and isinstance (
2139
2134
futures [future ], CUDATemplateCaller
2140
2135
):
@@ -2253,8 +2248,6 @@ def benchmark_choices(
2253
2248
try :
2254
2249
timing = cls .benchmark_choice (choice , autotune_args )
2255
2250
except CUDACompileError as e :
2256
- from torch ._inductor .codegen .cuda .cuda_kernel import CUDATemplateCaller
2257
-
2258
2251
if not isinstance (choice , CUDATemplateCaller ):
2259
2252
log .error (
2260
2253
"CUDA compilation error during autotuning: \n %s. \n Ignoring this choice." ,
@@ -2265,8 +2258,6 @@ def benchmark_choices(
2265
2258
log .warning ("Not yet implemented: %s" , e )
2266
2259
timing = float ("inf" )
2267
2260
except RuntimeError as e :
2268
- from torch ._inductor .codegen .cuda .cuda_kernel import CUDATemplateCaller
2269
-
2270
2261
if not isinstance (choice , CUDATemplateCaller ):
2271
2262
log .error (
2272
2263
"CUDA runtime error during autotuning: \n %s. \n Ignoring this choice." ,
0 commit comments