welfordreduce slows down forward layernorm in a bunch of cases #120184
@Chillee


🐛 Describe the bug

"Compiled (welford)" is the default execution; the other plot is with Welford reductions turned off.

[image: benchmark plot]

import torch
from torch import nn
import torch.nn.functional as F

def bench(f, name=None, iters=1000, warmup=5, display=True, profile=False):
    # `iters` is unused: do_bench chooses its own repeat count.
    from triton.testing import do_bench

    # Warm up first so compilation is not included in the timing.
    for _ in range(warmup):
        f()
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

    # do_bench reports milliseconds; convert to microseconds.
    us_per_iter = do_bench(f) * 1000

    if name is None:
        res = us_per_iter
    else:
        res = f"{name}: {us_per_iter:.3f}us"

    if display:
        print(res)
    return res

# Sweep the normalized dimension D from 256 to 4096 in steps of 256.
for D in range(256, 4096 + 1, 256):
    inp = torch.randn(2048*128, D, dtype=torch.bfloat16, device='cuda')
    mod = nn.LayerNorm([D], dtype=torch.bfloat16, device='cuda')

    comp_mod = torch.compile(mod, dynamic=False)
    print(D)
    bench(lambda: mod(inp), name="eager")
    bench(lambda: comp_mod(inp), name="compiled")
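
For anyone unfamiliar with the term: a Welford reduction computes mean and variance in a single pass by merging running (mean, M2, count) states, instead of computing the mean first and the squared deviations second. Below is a minimal pure-Python sketch of both approaches, for illustration only. The function names are made up here, this is not the Triton code Inductor generates, and whether the non-Welford path is a plain two-pass computation is an assumption.

```python
import math

def welford_combine(a, b):
    """Merge two partial (mean, m2, count) states (pairwise Welford update)."""
    mean_a, m2_a, n_a = a
    mean_b, m2_b, n_b = b
    n = n_a + n_b
    if n == 0:
        return (0.0, 0.0, 0)
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    # m2 is the running sum of squared deviations from the mean.
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return (mean, m2, n)

def welford_mean_var(xs):
    """Single pass: fold each element in as a one-element state."""
    state = (0.0, 0.0, 0)
    for x in xs:
        state = welford_combine(state, (float(x), 0.0, 1))
    mean, m2, n = state
    return mean, m2 / n  # biased (population) variance, as LayerNorm uses

def two_pass_mean_var(xs):
    """Two passes: mean first, then squared deviations from it."""
    n = len(xs)
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n
    return mean, var

data = [1.0, 2.0, 4.0, 7.0]
w_mean, w_var = welford_mean_var(data)
t_mean, t_var = two_pass_mean_var(data)
print(w_mean, w_var, t_mean, t_var)
```

The combine step carries three values per partial result and does more arithmetic per merge than a plain sum, which is one place a reduction like this can cost more than a simpler one.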

cc: @peterbell10, @eellison, @shunting314

Versions

N/A

cc @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @bertmaher @int3 @davidberard98 @nmacchioni @embg @msaroufim @bdhirsh @zou3519 @aakhundov

Labels: module: inductor, oncall: pt2, triaged, upstream triton