Performance Regression nightly 03/11→03/12, on nanogpt speedrun #152823
Comments
Probably caused by the triton upgrade.
delta commits by
High priority for regression. Our oncall should try to bisect this (cc @masnesral)
So if I did the calculation correctly, this looks like a perf difference of around 1.2%, right? @zou3519 that would seem to be well within the threshold before we note any regression on the dashboard, for example. Do we consider a regression of that size to be actionable? (I can't even get within that threshold on back-to-back runs on my machine.)
The setup for the speedrun is fairly stable. This regression is measured over 1770+ training steps, and I was able to consistently repro runs within 0.3% when I tried it.
Yeah, sorry, got it. I didn't know this issue was referring to that thing. Thanks for sorting me out.
Update:
Part of the regression is from this first triton update: #148705. There's another one later that I haven't isolated yet. Before: tlparse (https://github.com/pytorch/pytorch/tree/xmfan/pre_3945954741e2d37023c5d6954f9483008e0892f9) Delta:
Yep, the regression is fairly reproducible. I actually reported the regression on a larger run with 5960 steps. Typical runtime std is within 0.1%. For example, these runs on 03/30 have a std <0.03%, so 1.2% is about 40σ. Note that the code is different from the small 1770-step run.
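For anyone checking the 40σ figure, here is a minimal sketch of the arithmetic, using only the percentages quoted in this thread. The helper name `sigma_multiple` is made up for illustration.

```python
def sigma_multiple(regression_pct, std_pct):
    """How many standard deviations a regression represents.

    Both arguments are percentages of the total runtime.
    """
    return regression_pct / std_pct


# Numbers quoted in the thread: a 1.2% regression, with run-to-run std
# below 0.03% on the 03/30 runs and within 0.1% in general.
tight = sigma_multiple(1.2, 0.03)
loose = sigma_multiple(1.2, 0.1)

print(f"std 0.03%: about {tight:.0f} sigma")
print(f"std 0.10%: about {loose:.0f} sigma")
```

Even with the looser 0.1% bound, the slowdown sits far outside run-to-run noise for this setup.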
Profiling the first 20 iterations, we see about a 2% regression per iteration from the bump. Half of it seems to come from
oh this matches my speculation:
(I planned to profile this on Monday to verify my speculation, but I was a bit busy and forgot to do so...)
@davidberard98 is digging further into the triton side.
https://gist.github.com/davidberard98/f10db5520c96111254e614b53db9f501 Here are some scripts for separating out the affected kernels. I'm getting pulled into some other tasks, but I'll hopefully return to this in a day or two.
At the TTIR level, the main difference I see is that some loop-invariant code has been put inside the loop in triton 3.3 (whereas in 3.2, the code was pulled outside of the loop). However, I would expect the LICM pass to recognize this and pull the loop-invariant code back out of the loop. Here's the TTIR (both files in the same gist): https://gist.github.com/davidberard98/5e77aa6e0206b20acee4a21535fa3ba3
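To make "loop-invariant code" concrete, here is a plain-Python analogy. It is not Triton or TTIR and says nothing about what the 3.3 compiler actually emits; the function names and the `a * b + 1` expression are invented for illustration. The first version recomputes a value that never changes inside the loop; the second is what a LICM (loop-invariant code motion) pass would ideally produce.

```python
def scale_rows_inside(rows, a, b):
    """Recompute the invariant value on every iteration."""
    out = []
    for row in rows:
        # Loop-invariant: depends only on a and b, never on row.
        factor = a * b + 1
        out.append([x * factor for x in row])
    return out


def scale_rows_hoisted(rows, a, b):
    """Same result, with the invariant value computed once, outside the loop."""
    factor = a * b + 1
    out = []
    for row in rows:
        out.append([x * factor for x in row])
    return out


rows = [[1, 2], [3, 4], [5, 6]]
same = scale_rows_inside(rows, 2, 3) == scale_rows_hoisted(rows, 2, 3)
print("results match:", same)
```

Both versions return the same values; the only difference is how often the invariant expression is evaluated.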
It turns out that the individual kernels I linked in https://gist.github.com/davidberard98/f10db5520c96111254e614b53db9f501 don't repro a difference when I test them.
🐛 Describe the bug
code: https://gist.github.com/YouJiacheng/687efdab59a3c3b4ad89864804bd918a
I manually applied changes from #152641
03/10: 1469.0-1470.4s (3 runs)
03/11: 1469.4-1470.5s
03/12: 1486.0-1487.4s (a few runs)
03/15: ≈1487.5s (a single run)
FWD diffs (03/10 vs. 03/15): https://www.diffchecker.com/bLNEBIii/
BWD diffs (03/10 vs. 03/15): https://www.diffchecker.com/bbiVBsPU/
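As a rough check on the size of the regression, the sketch below computes the relative slowdown from the runtime ranges listed above. Using the midpoint of each range as the representative value is an assumption made here, not something the measurements prescribe.

```python
# Wall-clock ranges in seconds, copied from the measurements above.
runs = {
    "03/10": (1469.0, 1470.4),
    "03/11": (1469.4, 1470.5),
    "03/12": (1486.0, 1487.4),
}


def midpoint(lo_hi):
    """Midpoint of a (low, high) range; an assumed summary statistic."""
    lo, hi = lo_hi
    return (lo + hi) / 2


def slowdown_pct(before, after):
    """Relative slowdown of `after` versus `before`, in percent."""
    return (after - before) / before * 100


pct = slowdown_pct(midpoint(runs["03/11"]), midpoint(runs["03/12"]))

# Bracket the estimate with the most and least favourable pairings.
pct_min = slowdown_pct(runs["03/11"][1], runs["03/12"][0])
pct_max = slowdown_pct(runs["03/11"][0], runs["03/12"][1])

print(f"03/11 -> 03/12: {pct:.2f}% (between {pct_min:.2f}% and {pct_max:.2f}%)")
```

The midpoint estimate comes out a little above 1.1%, and the range overlaps the roughly 1.2% discussed in the comments.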
Bisection 03/12
runtime 1486.0-1487.4s (a few runs)
Inductor output code is identical to 03/15
Bisection 03/11
runtime 1469.4-1470.5s
Inductor output code:
BWD is identical to 03/10
FWD diffs (~no diffs): https://www.diffchecker.com/wQxaVYL3/
Optimizer diffs (~no diffs): https://www.diffchecker.com/Og8kGihp/ https://www.diffchecker.com/N2qJ4DyA/
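The bisection over nightlies can be sketched as a binary search for the first slow build. This is only an illustration: the runtimes are the midpoints of the ranges reported here, the 1478 s cutoff is a hypothetical threshold picked between the two clusters, and in practice `is_slow` would install the nightly and run the benchmark rather than look up a table.

```python
def first_bad(candidates, is_bad):
    """Return the first candidate for which is_bad() is True.

    Assumes candidates are ordered, is_bad flips from False to True
    exactly once, and the last candidate is bad.
    """
    lo, hi = 0, len(candidates) - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if is_bad(candidates[mid]):
            hi = mid          # first bad build is at mid or earlier
        else:
            lo = mid + 1      # first bad build is after mid
    return candidates[lo]


# Midpoints (seconds) of the runtimes reported in this issue.
measured = {
    "03/10": 1469.7,
    "03/11": 1469.95,
    "03/12": 1486.7,
    "03/15": 1487.5,
}

THRESHOLD_S = 1478.0  # hypothetical cutoff between the fast and slow clusters


def is_slow(nightly):
    # Stand-in for "install this nightly and time the speedrun".
    return measured[nightly] > THRESHOLD_S


nightlies = ["03/10", "03/11", "03/12", "03/15"]
culprit = first_bad(nightlies, is_slow)
print("first slow nightly:", culprit)
```

The same search applies to the commits between two nightlies once the date range has been narrowed down.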
Versions 03/10
Collecting environment information...
PyTorch version: 2.7.0.dev20250310+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39
Python version: 3.12.9 (main, Feb 5 2025, 19:10:45) [Clang 19.1.6 ] (64-bit runtime)
Python platform: Linux-5.4.250-2-velinux1u1-amd64-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3
Nvidia driver version: 535.129.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 168
On-line CPU(s) list: 0-161
Off-line CPU(s) list: 162-167
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8457C
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 42
Socket(s): 2
Stepping: 8
BogoMIPS: 5199.79
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3.9 MiB (84 instances)
L1i cache: 2.6 MiB (84 instances)
L2 cache: 168 MiB (84 instances)
L3 cache: 195 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-83
NUMA node1 CPU(s): 84-167
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Unknown: No mitigations
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.24.3
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pytorch-triton==3.2.0+git4b3bb1f8
[pip3] torch==2.7.0.dev20250310+cu126
[conda] Could not collect
Versions 03/15
Collecting environment information...
PyTorch version: 2.8.0.dev20250315+cu126
[omitted]
Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.24.3
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pytorch-triton==3.3.0+git96316ce5
[pip3] torch==2.8.0.dev20250315+cu126
[conda] Could not collect
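To spot what changed between the two environments, the `[pip3]` lines of the two dumps can be diffed mechanically. The sketch below is a minimal example, not part of `collect_env`; the function names are made up, and only a few lines of each dump are inlined to keep it short.

```python
def parse_pins(text):
    """Map package name -> version from '[pip3] name==version' lines."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line.startswith("[pip3]"):
            continue
        name, sep, version = line[len("[pip3]"):].strip().partition("==")
        if sep:
            pins[name] = version
    return pins


def changed(before, after):
    """Packages whose version differs between the two environments."""
    names = sorted(set(before) | set(after))
    return {
        name: (before.get(name), after.get(name))
        for name in names
        if before.get(name) != after.get(name)
    }


# Excerpts copied from the two version dumps in this issue.
env_0310 = """
[pip3] numpy==2.2.5
[pip3] nvidia-nccl-cu12==2.24.3
[pip3] pytorch-triton==3.2.0+git4b3bb1f8
[pip3] torch==2.7.0.dev20250310+cu126
"""
env_0315 = """
[pip3] numpy==2.2.5
[pip3] nvidia-nccl-cu12==2.24.3
[pip3] pytorch-triton==3.3.0+git96316ce5
[pip3] torch==2.8.0.dev20250315+cu126
"""

diff = changed(parse_pins(env_0310), parse_pins(env_0315))
for name, (old, new) in diff.items():
    print(f"{name}: {old} -> {new}")
```

In the full lists above, `pytorch-triton` and `torch` are likewise the only pins that differ.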
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bertmaher @int3 @davidberard98 @nmacchioni @chenyang78 @embg @peterbell10 @aakhundov @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng