
Commit 76f182f

henrylhtsang authored and pytorchmergebot committed
[cutlass backend] Reduce log level for cutlass compilation error (#153397)
Differential Revision: [D74596410](https://our.internmc.facebook.com/intern/diff/D74596410/)

This change should only affect the cutlass backend. CUDA compilation errors are expected with that backend, and they are already handled and cached well, so the logging level for them is reduced.

Pull Request resolved: #153397

Approved by: https://github.com/ColinPeppler, https://github.com/Skylion007
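As a minimal sketch (not the PyTorch source) of the logging policy this commit applies: failures that are expected and already handled downstream are logged at DEBUG with the traceback preserved, while anything unexpected still surfaces at ERROR. `ExpectedCompileError` and `try_compile` below are hypothetical stand-ins.

```python
import logging

log = logging.getLogger(__name__)


class ExpectedCompileError(Exception):
    """Hypothetical stand-in for an anticipated failure like CUDACompileError."""


def try_compile(candidate, compile_fn):
    # Sketch of the policy: expected failures stay quiet, unexpected ones stay loud.
    try:
        return compile_fn(candidate)
    except ExpectedCompileError:
        # Expected and handled/cached elsewhere: keep the traceback, but at debug level.
        log.debug("Compilation failed for %s", candidate, exc_info=True)
        return None
    except Exception as e:
        # Anything else is genuinely unexpected and is logged as an error.
        log.error("Unexpected error for %s: %s", candidate, e)
        raise
```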
1 parent 3bde364 commit 76f182f

File tree

1 file changed: +32 −14 lines changed

torch/_inductor/select_algorithm.py

```diff
@@ -2076,14 +2076,15 @@ def precompile_with_captured_stdout(choice) -> tuple[None, int]:
                 return None, elapsed_ns // 1000
 
             def on_complete(future):
-                _, precompile_elapsed_us = future.result()
-                elapsed_seconds = precompile_elapsed_us / 1e6
-                elapsed_times[future] = elapsed_seconds
-                log.debug(
-                    "Precompilation complete for future: %s, elapsed time: %.02fs",
-                    future,
-                    elapsed_seconds,
-                )
+                if not future.exception():
+                    _, precompile_elapsed_us = future.result()
+                    elapsed_seconds = precompile_elapsed_us / 1e6
+                    elapsed_times[future] = elapsed_seconds
+                    log.debug(
+                        "Precompilation complete for future: %s, elapsed time: %.02fs",
+                        future,
+                        elapsed_seconds,
+                    )
 
             executor = ThreadPoolExecutor(max_workers=num_workers)
             async_compile = torch._inductor.async_compile.AsyncCompile()
```
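For context on the guard added above: `Future.result()` re-raises the task's exception, even inside an `add_done_callback` hook, so checking `future.exception()` first lets a failed precompile skip the timing bookkeeping without raising in the callback. A self-contained sketch of that behavior:

```python
from concurrent.futures import ThreadPoolExecutor


def task(ok: bool) -> int:
    if not ok:
        raise RuntimeError("compile failed")
    return 42


def on_complete(future):
    # future.exception() returns None when the task succeeded.
    if not future.exception():
        print("result:", future.result())
    # Failed futures fall through; their exceptions are reported elsewhere.


with ThreadPoolExecutor(max_workers=2) as pool:
    for ok in (True, False):
        pool.submit(task, ok).add_done_callback(on_complete)
```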
```diff
@@ -2130,9 +2131,23 @@ def wait_on_futures():
                     timeout=precompilation_timeout_seconds,
                 ):
                     if e := future.exception():
-                        log.error(
-                            "Exception %s for benchmark choice %s", e, futures[future]
+                        from torch._inductor.codegen.cuda.cuda_kernel import (
+                            CUDATemplateCaller,
                         )
+
+                        if isinstance(e, CUDACompileError) and isinstance(
+                            futures[future], CUDATemplateCaller
+                        ):
+                            log.debug(
+                                "Exception %s for benchmark choice %s",
+                                e,
+                                futures[future],
+                                exc_info=True,
+                            )
+                        else:
+                            log.error(
+                                "Exception %s for benchmark choice %s", e, futures[future]
+                            )
                     else:
                         counters["inductor"]["select_algorithm_num_precompiles"] += 1
                         log.info(
```
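Note that the demoted branch passes `exc_info=True`, so the traceback is not lost: it is attached to the log record and emitted whenever debug logging is enabled. A small illustration (the choice name is made up):

```python
import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)

try:
    raise ValueError("nvcc error: ...")  # stand-in for a CUDACompileError
except ValueError as e:
    # The full traceback rides along with the record, but only at DEBUG level.
    log.debug("Exception %s for benchmark choice %s", e, "choice_0", exc_info=True)
```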
```diff
@@ -2238,10 +2253,13 @@ def benchmark_choices(
             try:
                 timing = cls.benchmark_choice(choice, autotune_args)
             except CUDACompileError as e:
-                log.error(
-                    "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.",
-                    str(e),
-                )
+                from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
+
+                if not isinstance(choice, CUDATemplateCaller):
+                    log.error(
+                        "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.",
+                        e,
+                    )
                 timing = float("inf")
             except NotImplementedError as e:
                 log.warning("Not yet implemented: %s", e)
```
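Even when the error log is skipped for a `CUDATemplateCaller`, the handler still assigns `timing = float("inf")`, which is enough to drop the choice from autotuning: an infinite timing can never win the final minimum. A sketch of that convention:

```python
# Failed candidates get an infinite timing so they can never be selected.
timings = {"choice_a": 0.12, "choice_b": float("inf"), "choice_c": 0.09}
best = min(timings, key=timings.__getitem__)
assert best == "choice_c"
```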

0 commit comments