[ATen][CUDA] Implement 128 bit vectorization v2 #145746

Closed
Aidyn-A wants to merge 5 commits.

Conversation

@Aidyn-A (Collaborator) commented Jan 27, 2025

This is a rebase of my previous PR #141959.

Description from the original PR:

This PR implements 128-bit vectorization. It improves the performance of contiguous elementwise ops by 4-10% on Hopper H100.
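
As background, the sketch below (a minimal illustration, not the ATen kernel added by this PR) shows the idea behind 128-bit vectorization for a contiguous elementwise op: each thread moves 16 bytes per memory instruction through an aligned float4 instead of a single 4-byte float. The kernel name and launch snippet are hypothetical.

```cuda
// Minimal sketch of 128-bit vectorized loads/stores for a contiguous elementwise op.
// Illustrative only; this is not the ATen kernel introduced by this PR.
#include <cuda_runtime.h>

__global__ void relu_vec4(const float4* __restrict__ in,
                          float4* __restrict__ out,
                          int n_vec4) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_vec4) {
    // One 16-byte (128-bit) load, four elementwise ops, one 16-byte store.
    float4 v = in[i];
    v.x = fmaxf(v.x, 0.f);
    v.y = fmaxf(v.y, 0.f);
    v.z = fmaxf(v.z, 0.f);
    v.w = fmaxf(v.w, 0.f);
    out[i] = v;
  }
}

// Launch sketch (assumes the tensor length is a multiple of 4 and 16-byte aligned):
//   int n_vec4 = n / 4;
//   relu_vec4<<<(n_vec4 + 255) / 256, 256>>>(
//       reinterpret_cast<const float4*>(x), reinterpret_cast<float4*>(y), n_vec4);
```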

The benchmark code used:

```python
import time
import torch
from torch.profiler import profile, ProfilerActivity


def benchmark(function, dtype=torch.float32, check_numerics=True, print_profile=False):
    device = torch.device("cuda")

    shapes = []
    for p in range(24, 30):
        shape = 1<<p
        shapes.append(shape)

    for shape in shapes:
        for _ in range(6):
            x = torch.randn(shape, device=device, dtype=dtype)
            y = function(x)

        if print_profile:
            x = torch.randn(shape, device=device, dtype=dtype)
            with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
                y = function(x)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

        x = torch.randn(shape, device=device, dtype=dtype)
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        for _ in range(6):
            y = function(x)
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        perf_time = (t2 - t1) / 6

        print(f"{function.__name__}, {dtype}, {shape}, {perf_time}")
        if check_numerics:
            x_cpu = x.cpu()
            y_cpu = function(x_cpu).cuda()
            try:
                torch.testing.assert_allclose(y_cpu, y)
            except AssertionError as error:
                print("An exception occurred:", error)


def main():
    ops = [
            torch.relu,
            torch.sigmoid,
            torch.tanh,
            torch.nn.functional.gelu,
            torch.sin,
            torch.exp,
    ]

    dtypes = [
            torch.float16,
            torch.bfloat16,
            torch.float32,
    ]

    for op in ops:
        for dtype in dtypes:
            benchmark(op, dtype=dtype)
            torch.cuda.empty_cache()

if __name__ == "__main__":
    main()
```

Results:
| op | dtype | size | time after (s) | time before (s) | % improvement |
| ---- | ---- | ---- | ---- | ---- | ---- |
| relu | torch.float16 | 33554432 | 4.84E-05 | 5.06E-05 | 4.66296539127052 |
| relu | torch.float16 | 67108864 | 9.22E-05 | 9.64E-05 | 4.56491432752297 |
| relu | torch.float16 | 134217728 | 0.000180343495837102 | 0.000187981834945579 | 4.23543919508829 |
| relu | torch.float16 | 268435456 | 0.000355071155354381 | 0.000370856161074092 | 4.44558942107169 |
| relu | torch.float16 | 536870912 | 0.000704489842367669 | 0.000736006341564159 | 4.47366268483987 |
| relu | torch.bfloat16 | 16777216 | 3.03E-05 | 3.04E-05 | 0.166504085842689 |
| relu | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.45848238875716 |
| relu | torch.bfloat16 | 67108864 | 9.32E-05 | 9.65E-05 | 3.56122651631445 |
| relu | torch.bfloat16 | 134217728 | 0.000180805509444326 | 0.000187998676362137 | 3.97840029317567 |
| relu | torch.bfloat16 | 268435456 | 0.000356242332297067 | 0.000371279485989362 | 4.22104627356745 |
| relu | torch.bfloat16 | 536870912 | 0.000708114336399982 | 0.000736773828975856 | 4.04729732229083 |
| relu | torch.float32 | 16777216 | 5.61E-05 | 5.61E-05 | 0.0442587268354941 |
| relu | torch.float32 | 33554432 | 9.33E-05 | 9.30E-05 | -0.259070913799022 |
| relu | torch.float32 | 67108864 | 0.000181321326332788 | 0.000181289506144822 | -0.0175490597877115 |
| relu | torch.float32 | 134217728 | 0.000356896334172537 | 0.000356570177245885 | -0.0913870206618981 |
| relu | torch.float32 | 268435456 | 0.000709421835684528 | 0.000707465515006334 | -0.275762681635911 |
| relu | torch.float32 | 536870912 | 0.00141372415237129 | 0.00141036518228551 | -0.237597276678471 |
| sigmoid | torch.float16 | 16777216 | 3.10E-05 | 3.16E-05 | 2.10012593866895 |
| sigmoid | torch.float16 | 33554432 | 4.91E-05 | 5.23E-05 | 6.37710600666122 |
| sigmoid | torch.float16 | 67108864 | 9.30E-05 | 0.000100057009452333 | 7.61866144555331 |
| sigmoid | torch.float16 | 134217728 | 0.000180928347011407 | 0.000194982004662355 | 7.76752669390248 |
| sigmoid | torch.float16 | 268435456 | 0.000355658994521946 | 0.00038468533117945 | 8.16128288742412 |
| sigmoid | torch.float16 | 536870912 | 0.000705982849467546 | 0.000764021339515845 | 8.22094900634937 |
| sigmoid | torch.bfloat16 | 16777216 | 3.08E-05 | 3.17E-05 | 2.90965915673149 |
| sigmoid | torch.bfloat16 | 33554432 | 4.87E-05 | 5.24E-05 | 7.63503884668234 |
| sigmoid | torch.bfloat16 | 67108864 | 9.33E-05 | 0.000100019678939134 | 7.21238137428013 |
| sigmoid | torch.bfloat16 | 134217728 | 0.000180786165098349 | 0.000194868014659733 | 7.78922964250206 |
| sigmoid | torch.bfloat16 | 268435456 | 0.000355564659306159 | 0.000384909333661199 | 8.25297835063321 |
| sigmoid | torch.bfloat16 | 536870912 | 0.000705831005082776 | 0.000764102345177283 | 8.2557070566308 |
| sigmoid | torch.float32 | 16777216 | 4.93E-05 | 5.65E-05 | 14.5314136197766 |
| sigmoid | torch.float32 | 33554432 | 9.32E-05 | 9.31E-05 | -0.120169865610833 |
| sigmoid | torch.float32 | 67108864 | 0.000181328505277634 | 0.000180455681402236 | -0.481349512069855 |
| sigmoid | torch.float32 | 134217728 | 0.000357362829769651 | 0.000356093340087682 | -0.35523831137877 |
| sigmoid | torch.float32 | 268435456 | 0.000708921831877281 | 0.000707052337626616 | -0.263709504574663 |
| sigmoid | torch.float32 | 536870912 | 0.00141358317341656 | 0.0014090768333214 | -0.318788464654745 |
| tanh | torch.float16 | 16777216 | 3.03E-05 | 3.03E-05 | -0.0912564658661808 |
| tanh | torch.float16 | 33554432 | 4.90E-05 | 5.07E-05 | 3.46644442974484 |
| tanh | torch.float16 | 67108864 | 9.30E-05 | 9.68E-05 | 3.99871369815531 |
| tanh | torch.float16 | 134217728 | 0.00018052199933057 | 0.000188717152923346 | 4.53969799978138 |
| tanh | torch.float16 | 268435456 | 0.000355684508879979 | 0.000373026006855071 | 4.8755280430115 |
| tanh | torch.float16 | 536870912 | 0.000706660988119741 | 0.000740105014604827 | 4.73268328765002 |
| tanh | torch.bfloat16 | 16777216 | 2.99E-05 | 3.03E-05 | 1.21049563135981 |
| tanh | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.48836101041744 |
| tanh | torch.bfloat16 | 67108864 | 9.28E-05 | 9.69E-05 | 4.39944918036626 |
| tanh | torch.bfloat16 | 134217728 | 0.000180710999605556 | 0.000189167990659674 | 4.67984299382829 |
| tanh | torch.bfloat16 | 268435456 | 0.000356062994493792 | 0.000372666652159144 | 4.66312363882606 |
| tanh | torch.bfloat16 | 536870912 | 0.000707100164921333 | 0.000740134331863374 | 4.67178040408393 |
| tanh | torch.float32 | 16777216 | 5.61E-05 | 5.64E-05 | 0.439595755746353 |
| tanh | torch.float32 | 33554432 | 9.31E-05 | 9.31E-05 | 0.00287633090228212 |
| tanh | torch.float32 | 67108864 | 0.000181465332085888 | 0.000180895323865116 | -0.31411411437098 |
| tanh | torch.float32 | 134217728 | 0.000356963835656643 | 0.000356073161431899 | -0.249513854283251 |
| tanh | torch.float32 | 268435456 | 0.000709201170442005 | 0.00070707315656667 | -0.300057862849997 |
| tanh | torch.float32 | 536870912 | 0.00141367283261692 | 0.00141030051357423 | -0.238550176877922 |
| gelu | torch.float16 | 16777216 | 2.73E-05 | 3.17E-05 | 15.921079070745 |
| gelu | torch.float16 | 33554432 | 5.06E-05 | 5.55E-05 | 9.76345374333098 |
| gelu | torch.float16 | 67108864 | 9.65E-05 | 0.000106600326641152 | 10.4308039074712 |
| gelu | torch.float16 | 134217728 | 0.000187776672343413 | 0.000208565829476962 | 11.0712139447915 |
| gelu | torch.float16 | 268435456 | 0.000370216167842348 | 0.000412251994324227 | 11.3544005187205 |
| gelu | torch.float16 | 536870912 | 0.000737301345604161 | 0.000819394170927505 | 11.1342296895002 |
| gelu | torch.bfloat16 | 16777216 | 3.02E-05 | 3.08E-05 | 1.78405479367653 |
| gelu | torch.bfloat16 | 33554432 | 5.13E-05 | 5.69E-05 | 10.9929393318302 |
| gelu | torch.bfloat16 | 67108864 | 9.76E-05 | 0.00010968199543034 | 12.3420807512356 |
| gelu | torch.bfloat16 | 134217728 | 0.000189661824454864 | 0.000214487663470209 | 13.0895287371091 |
| gelu | torch.bfloat16 | 268435456 | 0.000374197009174774 | 0.000423670164309442 | 13.2211519391275 |
| gelu | torch.bfloat16 | 536870912 | 0.000743675006863972 | 0.000842577001700799 | 13.299088166737 |
| gelu | torch.float32 | 16777216 | 5.06E-05 | 5.04E-05 | -0.413385894716413 |
| gelu | torch.float32 | 33554432 | 9.31E-05 | 9.32E-05 | 0.134157041722546 |
| gelu | torch.float32 | 67108864 | 0.000181480175039421 | 0.000180836669945469 | -0.354586992112075 |
| gelu | torch.float32 | 134217728 | 0.000356874331676712 | 0.000356305002545317 | -0.159532104402047 |
| gelu | torch.float32 | 268435456 | 0.000708909006789327 | 0.000706991491218408 | -0.270488250615287 |
| gelu | torch.float32 | 536870912 | 0.00141321367118508 | 0.00140937082081412 | -0.271922813181618 |
| sin | torch.float16 | 16777216 | 3.04E-05 | 3.11E-05 | 2.21834939018859 |
| sin | torch.float16 | 33554432 | 4.85E-05 | 5.23E-05 | 7.72165512511596 |
| sin | torch.float16 | 67108864 | 9.31E-05 | 9.98E-05 | 7.24947099480072 |
| sin | torch.float16 | 134217728 | 0.000180371008658161 | 0.000194791161144773 | 7.99471744039613 |
| sin | torch.float16 | 268435456 | 0.000355454161763191 | 0.000384903668115536 | 8.28503630574026 |
| sin | torch.float16 | 536870912 | 0.000705183832906187 | 0.000764360166310022 | 8.39161799270973 |
| sin | torch.bfloat16 | 16777216 | 3.11E-05 | 3.10E-05 | -0.257677954940036 |
| sin | torch.bfloat16 | 33554432 | 4.89E-05 | 5.24E-05 | 7.34808420323539 |
| sin | torch.bfloat16 | 67108864 | 9.26E-05 | 0.000100248667877167 | 8.22347488801205 |
| sin | torch.bfloat16 | 134217728 | 0.000180674154156198 | 0.00019567032965521 | 8.30012215584937 |
| sin | torch.bfloat16 | 268435456 | 0.000355360486234228 | 0.000386023331278314 | 8.62865913118873 |
| sin | torch.bfloat16 | 536870912 | 0.00070483615854755 | 0.000766805159704139 | 8.79197248964745 |
| sin | torch.float32 | 16777216 | 5.67E-05 | 5.64E-05 | -0.441348534920039 |
| sin | torch.float32 | 33554432 | 9.34E-05 | 9.30E-05 | -0.496458540364117 |
| sin | torch.float32 | 67108864 | 0.000181706990891447 | 0.000180556671693921 | -0.633062708199702 |
| sin | torch.float32 | 134217728 | 0.000356894995396336 | 0.000356046327700218 | -0.237791985616354 |
| sin | torch.float32 | 268435456 | 0.000708777321657787 | 0.000707602652255446 | -0.165731798471427 |
| sin | torch.float32 | 536870912 | 0.00141263716310884 | 0.00140912582476934 | -0.248566187496451 |
| exp | torch.float16 | 16777216 | 3.00E-05 | 3.04E-05 | 1.40099098901014 |
| exp | torch.float16 | 33554432 | 4.86E-05 | 5.03E-05 | 3.44611943643906 |
| exp | torch.float16 | 67108864 | 9.37E-05 | 9.55E-05 | 1.96412400380129 |
| exp | torch.float16 | 134217728 | 0.000180913504057874 | 0.000187193179347863 | 3.47109262113439 |
| exp | torch.float16 | 268435456 | 0.00035607748820136 | 0.000369079003576189 | 3.65131630210701 |
| exp | torch.float16 | 536870912 | 0.000707551507124056 | 0.000732363162872692 | 3.50669251620789 |
| exp | torch.bfloat16 | 16777216 | 2.98E-05 | 3.04E-05 | 1.74345594341654 |
| exp | torch.bfloat16 | 33554432 | 4.88E-05 | 5.04E-05 | 3.40217856534821 |
| exp | torch.bfloat16 | 67108864 | 9.32E-05 | 9.62E-05 | 3.29219958210226 |
| exp | torch.bfloat16 | 134217728 | 0.000180999826019009 | 0.000187239318620414 | 3.44723679499521 |
| exp | torch.bfloat16 | 268435456 | 0.000355944503098726 | 0.000369370992605885 | 3.77207384585864 |
| exp | torch.bfloat16 | 536870912 | 0.000707135167128096 | 0.000733066000975668 | 3.66702648277075 |
| exp | torch.float32 | 16777216 | 4.89E-05 | 5.63E-05 | 15.1245314346532 |
| exp | torch.float32 | 33554432 | 9.34E-05 | 9.31E-05 | -0.259945454477446 |
| exp | torch.float32 | 67108864 | 0.000181152504713585 | 0.000180474346658836 | -0.374357536939058 |
| exp | torch.float32 | 134217728 | 0.000356771342922002 | 0.000355627329554409 | -0.3206573034212 |
| exp | torch.float32 | 268435456 | 0.000708404501589636 | 0.00070713268360123 | -0.179532736671163 |
| exp | torch.float32 | 536870912 | 0.00141283582585553 | 0.00140944866385932 | -0.23974208002295 |

cc @msaroufim @ptrblck @eqy @manuelcandales @SherlockNoMad @angelayi

Aidyn-A added the module: cuda, topic: not user facing, and module: core aten labels on Jan 27, 2025
Aidyn-A requested review from eqy and ngimel on January 27, 2025
Aidyn-A self-assigned this on Jan 27, 2025
Aidyn-A requested a review from syed-ahmed as a code owner on January 27, 2025
pytorch-bot commented Jan 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145746

❌ 2 New Failures

As of commit 575b250 with merge base 0674ab7.

@Skylion007 (Collaborator) commented:

Did we confirm no regression on A100, per comments on the previous PR?

@Aidyn-A (Collaborator, Author) commented Jan 27, 2025

> Did we confirm no regression on A100, per comments on the previous PR?

Yes, the regression is due to a compiler (NVCC) bug. Moreover, I discovered that the bug is present on H100 as well. I have omitted vec8/vec16 for 1-byte data on all archs to work around the bug.

@Skylion007 (Collaborator) commented Jan 27, 2025

> Yes, the regression is due to a compiler (NVCC) bug. Moreover, I discovered that the bug is present on H100 as well. I have omitted vec8/vec16 for 1-byte data on all archs to work around the bug.

What versions of NVCC are affected? We are potentially planning to drop old NVCC support (11.8/12.4).

@Aidyn-A (Collaborator, Author) commented Jan 27, 2025

> What versions of NVCC are affected? We are potentially planning to drop old NVCC support (11.8/12.4).

I have not checked 11.8, but 12.4-12.6 are affected. 12.8 is fine if the nvvm-latest flag is enforced, though I am not certain whether it is generally safe to force nvvm-latest on sm_80-90.

```cpp
if constexpr (io_sizes == 1) {
  return 16;
} else {
  return 8;
}
```

Reviewer (Collaborator):

Is elems_per_thread = 8 all-around better than the 4 we mostly used previously?

@Aidyn-A (Collaborator, Author):

I observed little to no difference. The biggest improvement comes from vec8.
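
For context on this exchange, here is a minimal sketch (with assumed names such as kElemsPerThread and grid_size_for, not ATen's actual ones) of how elements-per-thread feeds into the launch configuration: raising it from 4 to 8 halves the grid size for a given tensor, while the vector width decides how many of those elements move per memory instruction.

```cuda
// Illustrative launch math; names and values are assumptions, not ATen's exact code.
// Each thread handles kElemsPerThread elements, read in vec-sized chunks
// (e.g. 4 floats or 8 halves per 128-bit access).
constexpr int kThreadsPerBlock = 128;
constexpr int kElemsPerThread  = 8;

inline long long grid_size_for(long long numel) {
  const long long block_work_size =
      static_cast<long long>(kThreadsPerBlock) * kElemsPerThread;
  return (numel + block_work_size - 1) / block_work_size;  // ceil-divide
}
```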

```cpp
uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
vec_size = std::min<uint16_t>(vec_size, max_vec_size);
if (sizeof(cpp_type) < 2) {
  vec_size = std::min<uint16_t>(vec_size, 4);
```

Reviewer (Collaborator):

Why are you setting the max vec size to 4 here for 1-byte data types? Is it to work around that bug? Can you leave a comment then?

@Aidyn-A (Collaborator, Author):

Correct, this is a workaround for that bug. I have left a comment that explains it.
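
To make the workaround concrete, here is a hedged sketch of what such a commented clamp could look like; pick_vec_size and max_vec_size are illustrative names, and this is not the actual ATen code.

```cuda
// Illustrative sketch (not the ATen implementation): how a dispatcher might pick
// a vector width, capping 1-byte dtypes at vec4 to side-step the NVCC codegen bug
// discussed above.
#include <algorithm>
#include <cstdint>

template <typename cpp_type>
uint16_t pick_vec_size(uint16_t max_vec_size) {
  // 128-bit (16-byte) accesses: e.g. 4 x float, 8 x half, 16 x int8.
  uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
  vec_size = std::min<uint16_t>(vec_size, max_vec_size);
  // Workaround: vec8/vec16 on 1-byte types trips an NVCC bug (observed with
  // CUDA 12.4-12.6 in this PR's discussion), so clamp those dtypes to vec4.
  if (sizeof(cpp_type) < 2) {
    vec_size = std::min<uint16_t>(vec_size, 4);
  }
  return vec_size;
}
```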

Aidyn-A added the ciflow/trunk label on Jan 29, 2025
@Aidyn-A (Collaborator, Author) commented Jan 29, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 hours).

@Aidyn-A (Collaborator, Author) commented Jan 31, 2025

The lint failures are unrelated.

@pytorchbot merge -i

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged while ignoring the following 2 checks: Lint / lintrunner-noclang / linux-job, Lint / Test tools / linux-job

mori360 pushed a commit to mori360/pytorch that referenced this pull request Feb 6, 2025
Pull Request resolved: pytorch#145746
Approved by: https://github.com/eqy, https://github.com/ngimel

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
atalman added a commit to atalman/pytorch that referenced this pull request Feb 19, 2025
@malfet (Contributor) left a comment:

IMO it should be guarded with __CUDA_ARCH__ >= 90 to avoid compile-time increases for older architectures.

@ngimel (Collaborator) commented Feb 25, 2025

I think it helps perf even on H100.

@malfet (Contributor) commented Feb 25, 2025

> I think it helps perf even on H100.

I've made an off-by-one error in my calculations anyway; CUDA_ARCH > 10 are Fermi and above :)

@Aidyn-A (Collaborator, Author) commented Feb 28, 2025

> IMO it should be guarded with __CUDA_ARCH__ >= 90 to avoid compile-time increases for older architectures.

I apologize for the delay, but it seems I will need to do some refactoring to achieve this, because __CUDA_ARCH__ is available only in kernels and device functions.
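
As a sketch of the constraint being discussed (illustrative only, not the eventual refactor): __CUDA_ARCH__ can only gate code inside device functions, so host-side dispatch would instead have to query the device's compute capability at runtime. The kernel and helper below are hypothetical.

```cuda
// __CUDA_ARCH__ is defined only while compiling device code, so an arch guard has
// to live inside the kernel (or be replaced by a host-side runtime check).
#include <cuda_runtime.h>

__global__ void elementwise_kernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  // Device-side guard: this branch is compiled only for sm_90 and newer.
  out[i] = in[i] * 2.0f;
#else
  out[i] = in[i] + in[i];
#endif
}

// Host code cannot see __CUDA_ARCH__; a runtime query is one way to dispatch instead.
bool device_is_sm90_or_newer(int device = 0) {
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device);
  return prop.major >= 9;
}
```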

atalman added a commit to atalman/pytorch that referenced this pull request Apr 4, 2025
pytorchmergebot pushed a commit that referenced this pull request Apr 8, 2025
Addresses feedback requested at #145746.
Pull Request resolved: #150705
Approved by: https://github.com/atalman
pytorchbot pushed a commit that referenced this pull request Apr 8, 2025
Addresses feedback requested at #145746.
Pull Request resolved: #150705
Approved by: https://github.com/atalman

(cherry picked from commit 5228986)
atalman pushed a commit that referenced this pull request Apr 8, 2025
[CUDA] Only use vec128 if CUDA version is newer than 12.8 (#150705)

Addresses feedback requested at #145746.
Pull Request resolved: #150705
Approved by: https://github.com/atalman

(cherry picked from commit 5228986)

Co-authored-by: Nikita Shulga <nshulga@meta.com>
malfet added a commit that referenced this pull request Apr 9, 2025
timocafe pushed a commit to timocafe/pytorch that referenced this pull request Apr 16, 2025
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
pytorchmergebot pushed a commit that referenced this pull request May 2, 2025
Fixes #147376.
As per request: #145746 (review)
This PR prevents sm80 and older architectures from using vec8 kernels, due to long compilation times and large binary size.

Pull Request resolved: #148320
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman
pytorchbot pushed a commit that referenced this pull request May 6, 2025
Fixes #147376.
As per request: #145746 (review)
This PR prevents sm80 and older architectures from using vec8 kernels, due to long compilation times and large binary size.

Pull Request resolved: #148320
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman

(cherry picked from commit 72337bd)
malfet pushed a commit that referenced this pull request May 7, 2025
[ATen][CUDA] Optimize 128 bit vectorization (#148320)

Fixes #147376.
As per request: #145746 (review)
This PR prevents sm80 and older architectures from using vec8 kernels, due to long compilation times and large binary size.

Pull Request resolved: #148320
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman

(cherry picked from commit 72337bd)

Co-authored-by: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com>
Labels: ci-no-td, ciflow/trunk, Merged, module: core aten, module: cuda, open source, Reverted, topic: not user facing