Performance Regression nightly 03/11→03/12, on nanogpt speedrun #152823

Open
YouJiacheng opened this issue May 5, 2025 · 14 comments
Assignees: xmfan
Labels: high priority, module: flex attention, module: higher order operators, module: pt2-dispatcher, oncall: pt2, triaged, upstream triton

Comments

YouJiacheng (Contributor) commented May 5, 2025

🐛 Describe the bug

code: https://gist.github.com/YouJiacheng/687efdab59a3c3b4ad89864804bd918a

I manually applied the changes from #152641.
03/10: 1469.0-1470.4s (3 runs)
03/11: 1469.4-1470.5s
03/12: 1486.0-1487.4s (a few runs)
03/15: ≈1487.5s (a single run)

FWD diffs (03/10 vs. 03/15): https://www.diffchecker.com/bLNEBIii/
BWD diffs (03/10 vs. 03/15): https://www.diffchecker.com/bbiVBsPU/

Bisection 03/12

runtime 1486.0-1487.4s (a few runs)
Inductor output code is identical to 03/15

Bisection 03/11

runtime 1469.4-1470.5s
Inductor output code:
BWD is identical to 03/10
FWD diffs (~no diffs): https://www.diffchecker.com/wQxaVYL3/
Optimizer diffs (~no diffs): https://www.diffchecker.com/Og8kGihp/ https://www.diffchecker.com/N2qJ4DyA/
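
For reference, a minimal sketch (not the speedrun setup itself) of one way to capture the Inductor output code used in the diffs above; the TORCH_LOGS artifact is standard, while the toy function is only a stand-in:

```python
# Hedged sketch: dump Inductor's generated output code so it can be diffed
# between two nightlies. TORCH_LOGS must be set before torch is imported.
import os
os.environ["TORCH_LOGS"] = "output_code"  # log generated code to stderr

import torch

def toy_step(x, w):
    # stand-in for one compiled region; the real diffs come from the
    # speedrun's FWD/BWD/optimizer graphs
    return torch.nn.functional.gelu(x @ w).sum()

compiled = torch.compile(toy_step)
x = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
w = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
compiled(x, w)  # redirect stderr to a file per nightly, then diff the files
```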

Versions 03/10

Collecting environment information...
PyTorch version: 2.7.0.dev20250310+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.12.9 (main, Feb 5 2025, 19:10:45) [Clang 19.1.6 ] (64-bit runtime)
Python platform: Linux-5.4.250-2-velinux1u1-amd64-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 535.129.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 168
On-line CPU(s) list: 0-161
Off-line CPU(s) list: 162-167
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8457C
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 42
Socket(s): 2
Stepping: 8
BogoMIPS: 5199.79
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3.9 MiB (84 instances)
L1i cache: 2.6 MiB (84 instances)
L2 cache: 168 MiB (84 instances)
L3 cache: 195 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-83
NUMA node1 CPU(s): 84-167
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Unknown: No mitigations
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.24.3
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pytorch-triton==3.2.0+git4b3bb1f8
[pip3] torch==2.7.0.dev20250310+cu126
[conda] Could not collect

Versions 03/15

Collecting environment information...
PyTorch version: 2.8.0.dev20250315+cu126

[omitted]

Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.24.3
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pytorch-triton==3.3.0+git96316ce5
[pip3] torch==2.8.0.dev20250315+cu126
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bertmaher @int3 @davidberard98 @nmacchioni @chenyang78 @embg @peterbell10 @aakhundov @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

YouJiacheng (Contributor, Author) commented:
Probably caused by the Triton upgrade.
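
A quick, minimal check of which Triton build each nightly bundles (the environment info above shows pytorch-triton 3.2.0+git4b3bb1f8 on the 03/10 nightly vs. 3.3.0+git96316ce5 on 03/15):

```python
# Minimal check of the Triton build bundled with a given PyTorch nightly.
import torch
import triton

print(torch.__version__)   # e.g. 2.7.0.dev20250310+cu126 or 2.8.0.dev20250315+cu126
print(triton.__version__)  # 3.2.0 on the 03/10 nightly, 3.3.0 on the 03/15 nightly
```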

YouJiacheng changed the title from "Performance Regression nightly 03/10→03/15, on nanogpt speedrun" to "Performance Regression nightly 03/10→03/11, on nanogpt speedrun" May 5, 2025
YouJiacheng changed the title from "Performance Regression nightly 03/10→03/11, on nanogpt speedrun" to "Performance Regression nightly 03/11→03/12, on nanogpt speedrun" May 5, 2025
YouJiacheng (Contributor, Author) commented:
Delta commits, obtained via:
git log --oneline 295f2ed4d103017f7e19a7b8263ece606cd629db^..2a7e997b3f1805b077810d7fef87cabc4411eeea

2a7e997b3f test/dynamo/test_utils: Fix one broken test on different python versions (#148987)
e40a9e602b Add the max_autotune tests in the periodic jobs. (#143560)
60576419a2 Make dynamism code robust to NotImplementedException (#148823)
46f096bba6 Explicitly set use-ephemeral runners for windows nightly cpu test jobs (#149001)
5b60749e9e [cudagraph] add log for skip reasons (#148797)
98a2d905bf [MPSInductor] Fix large prod and sum reductions (#148975)
2dcdb4ba78 [ez] include config as part of __all__ in torch.compiler (#148978)
a6459afb0e [dynamic shapes] add backed_size_oblivious option (#148696)
53a1a022a9 [WIP] Initial implementation of Grouped Gemm API (#148531)
b98af95401 Fix DCP link (#148974)
6119ffc711 [ROCm][TunableOp] Fix TunableOp BLAS logging for online tuning case. (#148979)
e5fef8a08e [CI] Don't clean workspace when fetching repo (#147994)
72d9f88ef2 [release] Move triton pin to latest triton release/3.3.x (#148971)
e6ef0620cc Add shim.h C API to call dispatcher on our own aten ops (#148832)
cf19efd3d9 Support basic TorchBind in aot_compile and aoti_compile_and_package (#148506)
f69e58e8e8 [CI] Update crossvit_9_240 as pass (#148989)
b54cf1a281 Revert "[logging] Set compile_id in the CachingAutotuner during compilation so we have it for dynamo_timed logging (#148693)"
c18858d633 [MPS] Make `torch.mps.compile_shader` public (#148972)
abcec55532 gracefully handle `tokenize.TokenError` in funcname parser. Adds support for non-Python source (#148737)
73c8068cf8 [logging] Set compile_id in the CachingAutotuner during compilation so we have it for dynamo_timed logging (#148693)
5b8da17681 [cutlass backend] Add addmm and bmm tests for AOTI (#148929)
7b2ecb80eb [Codemod][AddExplicitStrictExportArg] caffe2/test/inductor (#148928)
61f9b50e09 [ROCm] Fix TORCH_CHECK for hdim 512 support added in AOTriton 0.9b (#148967)
971606befa Add a stable TORCH_LIBRARY to C shim (#148124)
4d10da731b [ROCm] CK Memory-Efficient Attention (attention bias support) (#147778)
a1cb67b69e [ROCm] Improve backwards indexing when stride is not one (#147630)
daff65d671 Correctly propagate exception to parent tx (#146502)
fb53e9e514 Add `__context/cause/suppress_context/traceback__` to Exception (#146499)
4e7d264cf8 Introduce `UserDefinedExceptionClassVariable` (#146504)
8d08b49015 Reland: [inductor] Simplify grid handling (#148305)
c916a8efc5 Revert "Use the device interface for detecting Triton availability (#139171)"
57ee821a41 fix dynamo ide (#148849)
883fb78c7e Update jinja2 version in requirements-gha-cache.txt
5ee9dbc0a1 Bump jinja2 from 3.1.5 to 3.1.6 in /.ci/docker (#148812)
a5f6b24d87 Remove outdated skipIfRocmVersionLessThan decorations (#148941)
ef6296e7f2 [PGNCCL] Launch kernel on current stream & remove `record_stream` entirely (#148590)
b366f33606 [MPSInductor] Prep for mutlistage reductions (#148969)
dcc502f376 [ROCm][TunableOp] Add bias data type to params signature. (#146227)
52acc1f955 [DSD] Update the document to mention the limitation of set_optimizer_state_dict (#148918)
e0d4c43ad1 Add env for disabling meta reference on functionalization. (#148822)
09029010e5 [inductor] Fix create_specialize_impl error in latest Triton (#148933)
16560d4e8f Revert "Refactor `test/test_torch.py` by moving testcase to `test_indexing.py` (#148875)"
3945954741 Bump triton pin. Add aarch64 triton build (#148705)
c983e1124c Revert "[WIP] Initial implementation of Grouped Gemm API (#148531)"
f1787ee0f7 [dynamo] Remove L scoping for recompilation messages (#148917)
992838e702 [dynamo][guards] Do not ID_MATCH on numpy tensors (#148923)
ee21ccc816 Skip ao_sparsity TestComposability for missing FBGEMM (#144146)
da4bb72a71 Backout D70075331 (#148824)
9ad64ce795 [triton 3.3] Forward-fix mm template selection logic (#148924)
2bcc3acb90 Update low prec codegen for div/mod (#142350)
41e4728f74 update types on dynamo configs (#146873)
1fcc4bc109 Don't look at TESTING_ONLY in fuzzer (#146870)
bed92a8523 [Window][Inductor UT] Fix for tempfile.NamedTemporaryFile(delete=True) not work on Windows. (#148632)
ecfbfe1603 [AOTI] Remove aoti_torch_cpu__weight_int4pack_mm_cpu_tensor (#148907)
940b60db97 Use the device interface for detecting Triton availability (#139171)
ff29791ed8 [WIP] Initial implementation of Grouped Gemm API (#148531)
621dadd4ca partitioner: when materializing unbacked tensor intermediates, apply hint to symbol, not expr (#144097)
8c45d44abb Skip distributed subprocess test internally as they don't work (#148909)
457ff9b7ae [reland][ca] side-effect free inital trace: compiled_args (#148376)
9fddbf3417 Update the comment (#148726)
0fa0a74095 Refactor `test/test_torch.py` by moving testcase to `test_indexing.py` (#148875)
c297c09a37 Fix invalid nested int guarding in broadcast_shapes() (#145957)
295f2ed4d1 Fix "invalid application of 'sizeof' to an incomplete type" (#148854)

zou3519 added the high priority and triaged labels May 5, 2025
zou3519 removed the triaged label May 5, 2025
zou3519 (Contributor) commented May 5, 2025

High priority for a regression. Our on-call should try to bisect this (cc @masnesral).

masnesral (Contributor) commented:
So if I did the calculation correctly, this looks like a perf difference of around 1.2%, right? @zou3519, that would seem to be well within the threshold before we note any regression on the dashboard, for example. Do we consider a regression of that size to be actionable? (I can't even get within that threshold on back-to-back runs on my machine.)

xmfan self-assigned this May 5, 2025
xmfan (Member) commented May 5, 2025

The setup for the speedrun is fairly stable: this regression is over 1770+ training steps, and I was able to consistently reproduce runs within 0.3% when I tried it.

masnesral (Contributor) commented:
Yeah, sorry, got it. I didn't know this issue was referring to that thing. Thanks for sorting me out.

masnesral added the triaged label and removed the triage review label May 5, 2025
xmfan (Member) commented May 6, 2025

Update:

Part of the regression is from this first Triton update: #148705. There's another one later that I haven't isolated yet.

Before: tlparse (https://github.com/pytorch/pytorch/tree/xmfan/pre_3945954741e2d37023c5d6954f9483008e0892f9)
After: tlparse (https://github.com/pytorch/pytorch/tree/xmfan/post_3945954741e2d37023c5d6954f9483008e0892f9)

Delta: [image omitted]
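
For reproducibility, a hedged sketch of how such tlparse reports can be generated; TORCH_TRACE is the structured-trace output directory that the tlparse CLI consumes, and the training loop itself is elided here:

```python
# Hedged sketch: produce the structured trace that `tlparse` renders into the
# reports linked above. The env var should be set before torch is imported.
import os
os.environ["TORCH_TRACE"] = "/tmp/torchtrace"  # directory for structured trace logs

import torch
# ... run the training script / compiled workload here (elided) ...

# Then, from a shell:
#   tlparse /tmp/torchtrace
```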

YouJiacheng (Contributor, Author) commented May 6, 2025

> The setup for the speedrun is fairly stable: this regression is over 1770+ training steps, and I was able to consistently reproduce runs within 0.3% when I tried it.

Yep, the regression is fairly reproducible. I actually reported the regression from a larger run with 5960 steps. The typical runtime std is within 0.1%.

For example, these runs on 03/30 have a std <0.03%, so 1.2% is roughly 40σ.
[image omitted]

Note that this code is different from the 1770-step small run.
Since 5960 steps take too much time, I think you can set the number of training steps to ~1000 to get a decent reproduction.
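
For concreteness, the arithmetic behind the "~1.2% ≈ 40σ" framing, using the midpoints of the runtime ranges reported at the top of the issue:

```python
# Rough arithmetic behind the "~1.2% regression, ~40 sigma" framing above.
before = (1469.4 + 1470.5) / 2   # seconds, 03/11 nightly
after = (1486.0 + 1487.4) / 2    # seconds, 03/12 nightly

rel = (after - before) / before
print(f"regression: {rel:.2%}")  # ~1.1%

run_std = 0.0003 * before        # <0.03% run-to-run std reported above
print(f"sigma multiple: {(after - before) / run_std:.0f}")  # on the order of 40 sigma
```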

xmfan added the upstream triton label May 6, 2025
xmfan (Member) commented May 7, 2025

Profiling the first 20 iterations, we see about a 2% regression per iteration from the bump. Half of it seems to come from triton_tem_fused_zeros_7 (flex attention). There are no changes to the kernel code: https://www.diffchecker.com/Hmt40fkF/
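
A self-contained sketch of this kind of per-kernel profiling (the model and step below are toy stand-ins, not the speedrun loop); in the real run, the flex-attention template kernel shows up near the top of the self-CUDA-time table:

```python
# Hedged sketch: profile the first ~20 iterations of a (toy, stand-in) compiled
# training loop and rank kernels by self CUDA time.
import torch
from torch.profiler import ProfilerActivity, profile, schedule

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
).cuda().bfloat16()
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
step_fn = torch.compile(lambda x: model(x).square().mean())

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=2, warmup=3, active=15),  # ~20 iterations total
) as prof:
    for _ in range(20):
        loss = step_fn(torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16))
        loss.backward()
        opt.step()
        opt.zero_grad()
        prof.step()

# Per-kernel self CUDA time; in the regression this is where the extra
# ~1% per iteration shows up.
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20))
```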

pytorch-bot added the module: higher order operators and module: pt2-dispatcher labels May 7, 2025
YouJiacheng (Contributor, Author) commented May 8, 2025

Oh, this matches my speculation:

  1. Such a large regression is unlikely to be caused by a pointwise op, and GEMMs are not routed to Triton, so that leaves flex attention.
  2. In the latest record we increased the max window size (and average window size) of flex attention, and I observed that the regression is more significant in the latest record than in previous records (see the window-size sketch after this comment).

(I planned to do some profiling on Monday to verify my speculation, but I was a bit busy and forgot to do so...)
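
To make point 2 concrete, a minimal flex_attention sketch with a sliding-window mask (this is not the speedrun's actual mask_mod, and WINDOW is a hypothetical value): a larger window keeps more KV blocks alive in the block mask, so any per-block slowdown in the flex-attention template kernel costs proportionally more.

```python
# Hedged sketch of a sliding-window causal mask for flex_attention.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 1, 8, 4096, 64
WINDOW = 1024  # hypothetical max window size

def sliding_window_causal(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    in_window = q_idx - kv_idx <= WINDOW  # larger WINDOW -> more KV blocks computed
    return causal & in_window

block_mask = create_block_mask(sliding_window_causal, B, H, S, S)
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
```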

xmfan (Member) commented May 8, 2025

@davidberard98 is digging further into the Triton side.

davidberard98 (Contributor) commented:
https://gist.github.com/davidberard98/f10db5520c96111254e614b53db9f501

Here are some scripts for separating out the affected kernels.

I'm getting pulled into some other tasks but I'll hopefully return to this in a day or two
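
Not the scripts from the gist, but a minimal sketch of the general approach for timing a suspect kernel in isolation with Triton's benchmarking helper (suspect_kernel_call is a placeholder workload, not the extracted flex-attention kernel):

```python
# Hedged sketch: time a GPU callable in isolation with Triton's benchmark helper.
import torch
from triton.testing import do_bench

def suspect_kernel_call():
    # placeholder workload; in practice, invoke the extracted Triton kernel
    # with the same launch arguments Inductor used
    a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    return torch.nn.functional.gelu(a)

ms = do_bench(suspect_kernel_call)  # time per call, in milliseconds
print(f"{ms:.3f} ms per call")
```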

davidberard98 (Contributor) commented May 14, 2025

At the TTIR level, the main difference I see is that some loop-invariant code has been placed inside the loop in Triton 3.3 (whereas in 3.2, the code was pulled outside of the loop).

However, I would imagine that the LICM (loop-invariant code motion) pass should be able to recognize this and hoist the loop-invariant code out of the loop.

Here's the TTIR (both files in the same gist): https://gist.github.com/davidberard98/5e77aa6e0206b20acee4a21535fa3ba3
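
For readers unfamiliar with the pass, a toy Triton sketch (not the flex-attention kernel) of what hoisting loop-invariant code means; the TTIR linked above shows similar invariant code staying inside the loop under Triton 3.3.

```python
# Toy illustration of loop-invariant code motion (LICM) in a Triton kernel:
# `row_base` does not depend on the loop variable, so it can be hoisted out of
# the loop. This is NOT the flex-attention kernel, just a sketch of the idea.
import triton
import triton.language as tl

@triton.jit
def row_sum_unhoisted(x_ptr, out_ptr, n_cols, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    acc = tl.zeros([BLOCK], dtype=tl.float32)
    for start in range(0, n_cols, BLOCK):
        row_base = x_ptr + row * n_cols            # loop-invariant, recomputed every iteration
        offs = start + tl.arange(0, BLOCK)
        acc += tl.load(row_base + offs, mask=offs < n_cols, other=0.0).to(tl.float32)
    tl.store(out_ptr + row, tl.sum(acc, axis=0))

@triton.jit
def row_sum_hoisted(x_ptr, out_ptr, n_cols, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    row_base = x_ptr + row * n_cols                # hoisted: computed once, before the loop
    acc = tl.zeros([BLOCK], dtype=tl.float32)
    for start in range(0, n_cols, BLOCK):
        offs = start + tl.arange(0, BLOCK)
        acc += tl.load(row_base + offs, mask=offs < n_cols, other=0.0).to(tl.float32)
    tl.store(out_ptr + row, tl.sum(acc, axis=0))
```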

davidberard98 (Contributor) commented:
It turns out that the individual kernels I linked in https://gist.github.com/davidberard98/f10db5520c96111254e614b53db9f501 don't reproduce a difference when I test them.
