[ATen][CUDA] Implement 128 bit vectorization v2 #145746
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145746
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures as of commit 575b250 with merge base 0674ab7.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Did we confirm there is no regression on A100, per the comments on the previous PR?
Yes, the regression is due to a compiler (NVCC) bug. Moreover, I discovered the bug is present on H100 as well. I have omitted vec8/vec16 for 1-byte data on all archs to work around the bug.
What versions of NVCC are affected? We are potentially planning to drop support for old NVCC versions (11.8/12.4).
I have not checked 11.8, but 12.4-12.6 are affected. 12.8 is doing fine if
```cpp
if constexpr (io_sizes == 1) {
  return 16;
} else {
  return 8;
}
```
Is elems_per_thread = 8 all-around better than the 4 we mostly used previously?
I observed little to no difference. The biggest improvement comes from vec8.
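To make the numbers concrete, here is a hedged sketch (with made-up helper names, not the ATen code) relating elems_per_thread to the number of 128-bit memory transactions per thread: with 4-byte elements, elems_per_thread = 8 means two 16-byte accesses per thread, while the io_sizes == 1 branch above (16 elements of a 1-byte type) fills exactly one.

```cuda
#include <cstdio>

// Illustration only: relate elems_per_thread to the number of 128-bit
// (16-byte) vectorized accesses each thread performs.
constexpr int kVecBytes = 16;

constexpr int vec_accesses_per_thread(int elems_per_thread, int elem_size_bytes) {
  return (elems_per_thread * elem_size_bytes) / kVecBytes;
}

int main() {
  // 4-byte float, elems_per_thread = 8  -> 32 bytes -> 2 x 128-bit accesses
  std::printf("float, ept=8:  %d vec accesses\n", vec_accesses_per_thread(8, 4));
  // 2-byte half,  elems_per_thread = 8  -> 16 bytes -> 1 x 128-bit access
  std::printf("half,  ept=8:  %d vec accesses\n", vec_accesses_per_thread(8, 2));
  // 1-byte type,  elems_per_thread = 16 -> 16 bytes -> 1 x 128-bit access
  std::printf("uint8, ept=16: %d vec accesses\n", vec_accesses_per_thread(16, 1));
  return 0;
}
```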
```cpp
uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
vec_size = std::min<uint16_t>(vec_size, max_vec_size);
if (sizeof(cpp_type) < 2) {
  vec_size = std::min<uint16_t>(vec_size, 4);
}
```
Why are you setting the max vec size to 4 here for 1-byte data types? Is it to work around that bug? Can you leave a comment then?
Correct. This is a workaround for that bug. I have left a comment that explains it.
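For illustration, a minimal sketch (with assumed parameter names, not the actual ATen source or the exact comment wording) of what that commented workaround can look like:

```cuda
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Sketch: choose the vector width from the element size, then cap it.
inline uint16_t pick_vec_size(size_t elem_size, uint16_t max_vec_size) {
  // 128-bit (16-byte) accesses: number of elements per vectorized load/store.
  uint16_t vec_size = 16 / static_cast<uint16_t>(elem_size);
  vec_size = std::min<uint16_t>(vec_size, max_vec_size);
  if (elem_size < 2) {
    // Workaround: an NVCC (12.4-12.6) bug causes a performance regression on
    // the vec8/vec16 paths for 1-byte element types, so cap them at vec4.
    vec_size = std::min<uint16_t>(vec_size, 4);
  }
  return vec_size;
}
```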
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The lint failures are unrelated. @pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 2 checks: Lint / lintrunner-noclang / linux-job, Lint / Test tools / linux-job. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This is a rebase of my previous PR pytorch#141959. Description from the original PR:

This PR implements 128-bit vectorization. It improves the performance of contiguous elementwise ops by 4-10% on Hopper H100.

<details>
<summary>The benchmark code used</summary>

```Python
import time
import torch
from torch.profiler import profile, ProfilerActivity

def benchmark(function, dtype=torch.float32, check_numerics=True, print_profile=False):
    device = torch.device("cuda")
    shapes = []
    for p in range(24, 30):
        shape = 1<<p
        shapes.append(shape)
    for shape in shapes:
        for _ in range(6):
            x = torch.randn(shape, device=device, dtype=dtype)
            y = function(x)
        if print_profile:
            x = torch.randn(shape, device=device, dtype=dtype)
            with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
                y = function(x)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        x = torch.randn(shape, device=device, dtype=dtype)
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        for _ in range(6):
            y = function(x)
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        perf_time = (t2 - t1) / 6
        print(f"{function.__name__}, {dtype}, {shape}, {perf_time}")
        if check_numerics:
            x_cpu = x.cpu()
            y_cpu = function(x_cpu).cuda()
            try:
                torch.testing.assert_allclose(y_cpu, y)
            except AssertionError as error:
                print("An exception occurred:", error)

def main():
    ops = [
        torch.relu,
        torch.sigmoid,
        torch.tanh,
        torch.nn.functional.gelu,
        torch.sin,
        torch.exp,
    ]
    dtypes = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]
    for op in ops:
        for dtype in dtypes:
            benchmark(op, dtype=dtype)
            torch.cuda.empty_cache()

if __name__ == "__main__":
    main()
```

</details>

<details>
<summary>Results</summary>

| op | dtype | size | time after (s) | time before (s) | % improvement |
| ---- | ---- | ---- | ---- | ---- | ---- |
| relu | torch.float16 | 33554432 | 4.84E-05 | 5.06E-05 | 4.66296539127052 |
| relu | torch.float16 | 67108864 | 9.22E-05 | 9.64E-05 | 4.56491432752297 |
| relu | torch.float16 | 134217728 | 0.000180343495837102 | 0.000187981834945579 | 4.23543919508829 |
| relu | torch.float16 | 268435456 | 0.000355071155354381 | 0.000370856161074092 | 4.44558942107169 |
| relu | torch.float16 | 536870912 | 0.000704489842367669 | 0.000736006341564159 | 4.47366268483987 |
| relu | torch.bfloat16 | 16777216 | 3.03E-05 | 3.04E-05 | 0.166504085842689 |
| relu | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.45848238875716 |
| relu | torch.bfloat16 | 67108864 | 9.32E-05 | 9.65E-05 | 3.56122651631445 |
| relu | torch.bfloat16 | 134217728 | 0.000180805509444326 | 0.000187998676362137 | 3.97840029317567 |
| relu | torch.bfloat16 | 268435456 | 0.000356242332297067 | 0.000371279485989362 | 4.22104627356745 |
| relu | torch.bfloat16 | 536870912 | 0.000708114336399982 | 0.000736773828975856 | 4.04729732229083 |
| relu | torch.float32 | 16777216 | 5.61E-05 | 5.61E-05 | 0.0442587268354941 |
| relu | torch.float32 | 33554432 | 9.33E-05 | 9.30E-05 | -0.259070913799022 |
| relu | torch.float32 | 67108864 | 0.000181321326332788 | 0.000181289506144822 | -0.0175490597877115 |
| relu | torch.float32 | 134217728 | 0.000356896334172537 | 0.000356570177245885 | -0.0913870206618981 |
| relu | torch.float32 | 268435456 | 0.000709421835684528 | 0.000707465515006334 | -0.275762681635911 |
| relu | torch.float32 | 536870912 | 0.00141372415237129 | 0.00141036518228551 | -0.237597276678471 |
| sigmoid | torch.float16 | 16777216 | 3.10E-05 | 3.16E-05 | 2.10012593866895 |
| sigmoid | torch.float16 | 33554432 | 4.91E-05 | 5.23E-05 | 6.37710600666122 |
| sigmoid | torch.float16 | 67108864 | 9.30E-05 | 0.000100057009452333 | 7.61866144555331 |
| sigmoid | torch.float16 | 134217728 | 0.000180928347011407 | 0.000194982004662355 | 7.76752669390248 |
| sigmoid | torch.float16 | 268435456 | 0.000355658994521946 | 0.00038468533117945 | 8.16128288742412 |
| sigmoid | torch.float16 | 536870912 | 0.000705982849467546 | 0.000764021339515845 | 8.22094900634937 |
| sigmoid | torch.bfloat16 | 16777216 | 3.08E-05 | 3.17E-05 | 2.90965915673149 |
| sigmoid | torch.bfloat16 | 33554432 | 4.87E-05 | 5.24E-05 | 7.63503884668234 |
| sigmoid | torch.bfloat16 | 67108864 | 9.33E-05 | 0.000100019678939134 | 7.21238137428013 |
| sigmoid | torch.bfloat16 | 134217728 | 0.000180786165098349 | 0.000194868014659733 | 7.78922964250206 |
| sigmoid | torch.bfloat16 | 268435456 | 0.000355564659306159 | 0.000384909333661199 | 8.25297835063321 |
| sigmoid | torch.bfloat16 | 536870912 | 0.000705831005082776 | 0.000764102345177283 | 8.2557070566308 |
| sigmoid | torch.float32 | 16777216 | 4.93E-05 | 5.65E-05 | 14.5314136197766 |
| sigmoid | torch.float32 | 33554432 | 9.32E-05 | 9.31E-05 | -0.120169865610833 |
| sigmoid | torch.float32 | 67108864 | 0.000181328505277634 | 0.000180455681402236 | -0.481349512069855 |
| sigmoid | torch.float32 | 134217728 | 0.000357362829769651 | 0.000356093340087682 | -0.35523831137877 |
| sigmoid | torch.float32 | 268435456 | 0.000708921831877281 | 0.000707052337626616 | -0.263709504574663 |
| sigmoid | torch.float32 | 536870912 | 0.00141358317341656 | 0.0014090768333214 | -0.318788464654745 |
| tanh | torch.float16 | 16777216 | 3.03E-05 | 3.03E-05 | -0.0912564658661808 |
| tanh | torch.float16 | 33554432 | 4.90E-05 | 5.07E-05 | 3.46644442974484 |
| tanh | torch.float16 | 67108864 | 9.30E-05 | 9.68E-05 | 3.99871369815531 |
| tanh | torch.float16 | 134217728 | 0.00018052199933057 | 0.000188717152923346 | 4.53969799978138 |
| tanh | torch.float16 | 268435456 | 0.000355684508879979 | 0.000373026006855071 | 4.8755280430115 |
| tanh | torch.float16 | 536870912 | 0.000706660988119741 | 0.000740105014604827 | 4.73268328765002 |
| tanh | torch.bfloat16 | 16777216 | 2.99E-05 | 3.03E-05 | 1.21049563135981 |
| tanh | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.48836101041744 |
| tanh | torch.bfloat16 | 67108864 | 9.28E-05 | 9.69E-05 | 4.39944918036626 |
| tanh | torch.bfloat16 | 134217728 | 0.000180710999605556 | 0.000189167990659674 | 4.67984299382829 |
| tanh | torch.bfloat16 | 268435456 | 0.000356062994493792 | 0.000372666652159144 | 4.66312363882606 |
| tanh | torch.bfloat16 | 536870912 | 0.000707100164921333 | 0.000740134331863374 | 4.67178040408393 |
| tanh | torch.float32 | 16777216 | 5.61E-05 | 5.64E-05 | 0.439595755746353 |
| tanh | torch.float32 | 33554432 | 9.31E-05 | 9.31E-05 | 0.00287633090228212 |
| tanh | torch.float32 | 67108864 | 0.000181465332085888 | 0.000180895323865116 | -0.31411411437098 |
| tanh | torch.float32 | 134217728 | 0.000356963835656643 | 0.000356073161431899 | -0.249513854283251 |
| tanh | torch.float32 | 268435456 | 0.000709201170442005 | 0.00070707315656667 | -0.300057862849997 |
| tanh | torch.float32 | 536870912 | 0.00141367283261692 | 0.00141030051357423 | -0.238550176877922 |
| gelu | torch.float16 | 16777216 | 2.73E-05 | 3.17E-05 | 15.921079070745 |
| gelu | torch.float16 | 33554432 | 5.06E-05 | 5.55E-05 | 9.76345374333098 |
| gelu | torch.float16 | 67108864 | 9.65E-05 | 0.000106600326641152 | 10.4308039074712 |
| gelu | torch.float16 | 134217728 | 0.000187776672343413 | 0.000208565829476962 | 11.0712139447915 |
| gelu | torch.float16 | 268435456 | 0.000370216167842348 | 0.000412251994324227 | 11.3544005187205 |
| gelu | torch.float16 | 536870912 | 0.000737301345604161 | 0.000819394170927505 | 11.1342296895002 |
| gelu | torch.bfloat16 | 16777216 | 3.02E-05 | 3.08E-05 | 1.78405479367653 |
| gelu | torch.bfloat16 | 33554432 | 5.13E-05 | 5.69E-05 | 10.9929393318302 |
| gelu | torch.bfloat16 | 67108864 | 9.76E-05 | 0.00010968199543034 | 12.3420807512356 |
| gelu | torch.bfloat16 | 134217728 | 0.000189661824454864 | 0.000214487663470209 | 13.0895287371091 |
| gelu | torch.bfloat16 | 268435456 | 0.000374197009174774 | 0.000423670164309442 | 13.2211519391275 |
| gelu | torch.bfloat16 | 536870912 | 0.000743675006863972 | 0.000842577001700799 | 13.299088166737 |
| gelu | torch.float32 | 16777216 | 5.06E-05 | 5.04E-05 | -0.413385894716413 |
| gelu | torch.float32 | 33554432 | 9.31E-05 | 9.32E-05 | 0.134157041722546 |
| gelu | torch.float32 | 67108864 | 0.000181480175039421 | 0.000180836669945469 | -0.354586992112075 |
| gelu | torch.float32 | 134217728 | 0.000356874331676712 | 0.000356305002545317 | -0.159532104402047 |
| gelu | torch.float32 | 268435456 | 0.000708909006789327 | 0.000706991491218408 | -0.270488250615287 |
| gelu | torch.float32 | 536870912 | 0.00141321367118508 | 0.00140937082081412 | -0.271922813181618 |
| sin | torch.float16 | 16777216 | 3.04E-05 | 3.11E-05 | 2.21834939018859 |
| sin | torch.float16 | 33554432 | 4.85E-05 | 5.23E-05 | 7.72165512511596 |
| sin | torch.float16 | 67108864 | 9.31E-05 | 9.98E-05 | 7.24947099480072 |
| sin | torch.float16 | 134217728 | 0.000180371008658161 | 0.000194791161144773 | 7.99471744039613 |
| sin | torch.float16 | 268435456 | 0.000355454161763191 | 0.000384903668115536 | 8.28503630574026 |
| sin | torch.float16 | 536870912 | 0.000705183832906187 | 0.000764360166310022 | 8.39161799270973 |
| sin | torch.bfloat16 | 16777216 | 3.11E-05 | 3.10E-05 | -0.257677954940036 |
| sin | torch.bfloat16 | 33554432 | 4.89E-05 | 5.24E-05 | 7.34808420323539 |
| sin | torch.bfloat16 | 67108864 | 9.26E-05 | 0.000100248667877167 | 8.22347488801205 |
| sin | torch.bfloat16 | 134217728 | 0.000180674154156198 | 0.00019567032965521 | 8.30012215584937 |
| sin | torch.bfloat16 | 268435456 | 0.000355360486234228 | 0.000386023331278314 | 8.62865913118873 |
| sin | torch.bfloat16 | 536870912 | 0.00070483615854755 | 0.000766805159704139 | 8.79197248964745 |
| sin | torch.float32 | 16777216 | 5.67E-05 | 5.64E-05 | -0.441348534920039 |
| sin | torch.float32 | 33554432 | 9.34E-05 | 9.30E-05 | -0.496458540364117 |
| sin | torch.float32 | 67108864 | 0.000181706990891447 | 0.000180556671693921 | -0.633062708199702 |
| sin | torch.float32 | 134217728 | 0.000356894995396336 | 0.000356046327700218 | -0.237791985616354 |
| sin | torch.float32 | 268435456 | 0.000708777321657787 | 0.000707602652255446 | -0.165731798471427 |
| sin | torch.float32 | 536870912 | 0.00141263716310884 | 0.00140912582476934 | -0.248566187496451 |
| exp | torch.float16 | 16777216 | 3.00E-05 | 3.04E-05 | 1.40099098901014 |
| exp | torch.float16 | 33554432 | 4.86E-05 | 5.03E-05 | 3.44611943643906 |
| exp | torch.float16 | 67108864 | 9.37E-05 | 9.55E-05 | 1.96412400380129 |
| exp | torch.float16 | 134217728 | 0.000180913504057874 | 0.000187193179347863 | 3.47109262113439 |
| exp | torch.float16 | 268435456 | 0.00035607748820136 | 0.000369079003576189 | 3.65131630210701 |
| exp | torch.float16 | 536870912 | 0.000707551507124056 | 0.000732363162872692 | 3.50669251620789 |
| exp | torch.bfloat16 | 16777216 | 2.98E-05 | 3.04E-05 | 1.74345594341654 |
| exp | torch.bfloat16 | 33554432 | 4.88E-05 | 5.04E-05 | 3.40217856534821 |
| exp | torch.bfloat16 | 67108864 | 9.32E-05 | 9.62E-05 | 3.29219958210226 |
| exp | torch.bfloat16 | 134217728 | 0.000180999826019009 | 0.000187239318620414 | 3.44723679499521 |
| exp | torch.bfloat16 | 268435456 | 0.000355944503098726 | 0.000369370992605885 | 3.77207384585864 |
| exp | torch.bfloat16 | 536870912 | 0.000707135167128096 | 0.000733066000975668 | 3.66702648277075 |
| exp | torch.float32 | 16777216 | 4.89E-05 | 5.63E-05 | 15.1245314346532 |
| exp | torch.float32 | 33554432 | 9.34E-05 | 9.31E-05 | -0.259945454477446 |
| exp | torch.float32 | 67108864 | 0.000181152504713585 | 0.000180474346658836 | -0.374357536939058 |
| exp | torch.float32 | 134217728 | 0.000356771342922002 | 0.000355627329554409 | -0.3206573034212 |
| exp | torch.float32 | 268435456 | 0.000708404501589636 | 0.00070713268360123 | -0.179532736671163 |
| exp | torch.float32 | 536870912 | 0.00141283582585553 | 0.00140944866385932 | -0.23974208002295 |

</details>

Pull Request resolved: pytorch#145746
Approved by: https://github.com/eqy, https://github.com/ngimel

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
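For readers who want a concrete picture of the technique, below is a self-contained, hedged sketch of 128-bit vectorized elementwise access in CUDA. It is not the ATen implementation (which goes through ATen's vectorized elementwise kernel machinery); it only illustrates how a 16-byte aligned aggregate turns per-element loads and stores into 128-bit ones. All names here are made up for the sketch.

```cuda
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// A 16-byte aligned aggregate so the compiler can emit 128-bit loads/stores.
template <typename T, int N>
struct alignas(16) Vec {
  T val[N];
};

struct Relu {
  __device__ float operator()(float x) const { return x > 0.f ? x : 0.f; }
};

// Each thread handles one 128-bit chunk (N elements). For brevity the sketch
// assumes numel is a multiple of N; real kernels handle the tail separately.
template <typename T, int N, typename Func>
__global__ void vectorized_unary_kernel(const T* in, T* out, int64_t numel, Func f) {
  using V = Vec<T, N>;
  const int64_t num_vecs = numel / N;
  const int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i >= num_vecs) {
    return;
  }
  V v = reinterpret_cast<const V*>(in)[i];  // one 128-bit load
#pragma unroll
  for (int k = 0; k < N; ++k) {
    v.val[k] = f(v.val[k]);
  }
  reinterpret_cast<V*>(out)[i] = v;  // one 128-bit store
}

int main() {
  constexpr int N = 4;  // 4 x float (4 bytes) = 16 bytes = 128 bits
  const int64_t numel = int64_t{1} << 24;
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, numel * sizeof(float));
  cudaMalloc(&out, numel * sizeof(float));
  cudaMemset(in, 0, numel * sizeof(float));

  const int64_t num_vecs = numel / N;
  const int threads = 256;
  const int blocks = static_cast<int>((num_vecs + threads - 1) / threads);
  vectorized_unary_kernel<float, N><<<blocks, threads>>>(in, out, numel, Relu{});
  cudaDeviceSynchronize();
  std::printf("kernel status: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaFree(in);
  cudaFree(out);
  return 0;
}
```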
…)" This reverts commit e84bf88.
IMO it should be guarded with __CUDA_ARCH__ >= 90 to avoid compile time increases for older architectures.
I think it helps perf even on H100.
I've made an off-by-one error in my calculations anyway; CUDA_ARCH > 10 covers Fermi and above :)
I apologize for the delay, but it seems like I will need to do some refactoring to achieve it, because
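For reference, a hedged sketch of the kind of device-side guard being discussed (hypothetical names, not the actual patch; sm_90 corresponds to __CUDA_ARCH__ == 900, and the macro is only defined when compiling device code):

```cuda
// Sketch: pick the maximum vector width at compile time, per architecture.
// __CUDA_ARCH__ is only defined during device compilation, so host code needs
// a separate runtime check (see the host-side sketch further down).
__host__ __device__ constexpr int max_vec_bytes() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  return 16;  // 128-bit vectorization on Hopper (sm_90) and newer
#else
  return 8;   // keep the previous 64-bit width elsewhere to limit
              // compile time and binary size
#endif
}
```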
…)" This reverts commit e84bf88.
[CUDA] Only use vec128 if CUDA version is newer than 12.8 (#150705)

Addresses feedback requested at #145746.
Pull Request resolved: #150705
Approved by: https://github.com/atalman
(cherry picked from commit 5228986)
Co-authored-by: Nikita Shulga <nshulga@meta.com>
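A minimal sketch of that kind of toolkit-version gate (assumed constant name; CUDART_VERSION encodes the CUDA runtime version, e.g. 12080 for CUDA 12.8):

```cuda
#include <cuda_runtime_api.h>  // provides CUDART_VERSION (e.g. 12080 for CUDA 12.8)

// Sketch: only enable the widest (128-bit) vectorization when building with a
// toolkit that does not hit the NVCC issue discussed above; otherwise fall
// back to the previous 64-bit width.
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12080
constexpr int kMaxVecBytes = 16;
#else
constexpr int kMaxVecBytes = 8;
#endif
```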
[ATen][CUDA] Optimize 128 bit vectorization (#148320)

Fixes #147376. As requested in #145746 (review), this PR prevents sm80 and older architectures from using the vec8 kernels, due to long compilation times and large binary size.
Pull Request resolved: #148320
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman
(cherry picked from commit 72337bd)
Co-authored-by: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com>
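And a hedged sketch of the matching host-side check (hypothetical helper, not ATen's actual dispatch), using the device's compute capability since __CUDA_ARCH__ is unavailable in host code:

```cuda
#include <cuda_runtime.h>

// Sketch: allow up to vec8 only on sm_90 and newer; cap at vec4 on sm_80 and
// older to keep compile time and binary size down (per the follow-up PR above).
inline int max_vec_elems_for_device(int device) {
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device);
  return (prop.major >= 9) ? 8 : 4;
}
```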
cc @msaroufim @ptrblck @eqy @manuelcandales @SherlockNoMad @angelayi