[inductor] fix compile time regression by caching get_gpu_type (#128363) · pytorch/pytorch@8a09940 · GitHub
Commit 8a09940

wanchaol authored and pytorchmergebot committed
[inductor] fix compile time regression by caching get_gpu_type (#128363)
We observed a significant compile time regression in torchtitan when turning on 2D parallel + torch.compile recently, so I decided to get a deeper understanding of why. It turns out this affects **all the trainings** that have functional collectives captured in the graph, not only 2D parallel (2D parallel was just the job that happened to have collectives captured in the TP region).

The root cause is that during inductor lowering we call the comm analysis pass to get an estimated collective time for each collective node in the graph, and for each such check we call `get_gpu_type()`, which under the hood calls `torch.utils.collect_env.run` to get the GPU info. However, this call is super expensive: it effectively spawns a new process and runs `nvidia-smi` to get the GPU info, so the cost is **linear** in the number of collective nodes in the graph. See https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py#L75

The fix is to add an lru cache to the function, so that we only call it once and reuse the cached result afterwards.

torchtitan benchmarks show:

* before this fix: 2D parallel + fp8 compile time: 6min+
* after this fix: 2D parallel + fp8 compile time: 2min 48s (more than 100% improvement)

There is more room to improve the compile time, but this PR fixes the biggest regression I have found so far.

Pull Request resolved: #128363
Approved by: https://github.com/yf225
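The effect of the fix can be illustrated with a minimal sketch. Here `expensive_probe` is a hypothetical stand-in for the `torch.utils.collect_env.run` call that spawns an `nvidia-smi` subprocess; an invocation counter is used instead of a real subprocess to show that the expensive body runs only once no matter how many collective nodes query it:

```python
import functools

calls = 0

@functools.lru_cache(maxsize=None)
def expensive_probe() -> str:
    """Hypothetical stand-in for the expensive GPU probe: each real call
    would spawn an nvidia-smi subprocess, so we count invocations."""
    global calls
    calls += 1
    return "H100"

# Simulate one query per collective node in the graph: 100 lookups,
# but the expensive body executes exactly once.
for _ in range(100):
    expensive_probe()
```

Without the decorator, the cost would scale linearly with the number of collective nodes; with it, all lookups after the first are dictionary hits.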
1 parent 1d233b8 commit 8a09940

File tree

1 file changed (+2, -0)

torch/_inductor/comm_analysis.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+import functools
 import math
 from enum import IntEnum
 
@@ -22,6 +23,7 @@ class NVIDIA_GPU_TYPE(IntEnum):
     HOPPER = 2
 
 
+@functools.lru_cache
 def get_gpu_type() -> NVIDIA_GPU_TYPE:
     gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
     if "V100" in gpu_info:
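One practical detail of this pattern: `functools.lru_cache` on a zero-argument function caches a single result for the life of the process, and the decorated function gains `cache_info()` and `cache_clear()` methods (useful, e.g., for resetting state between tests). A small sketch with a hypothetical `get_gpu_type_sketch` standing in for the decorated function:

```python
import functools

@functools.lru_cache
def get_gpu_type_sketch() -> str:
    # Hypothetical stand-in for the decorated get_gpu_type: the first
    # call computes the value; later calls return the cached result.
    return "HOPPER"

get_gpu_type_sketch()                 # first call populates the cache
get_gpu_type_sketch.cache_clear()     # reset the cache, e.g. between tests
info = get_gpu_type_sketch.cache_info()  # stats: hits/misses/currsize
```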

0 commit comments