Open
Labels
module: inductor, oncall: pt2, triaged, upstream triton
Description
🐛 Describe the bug
"Compiled (welford)" is the default execution; the other plot is with Welford reductions turned off.
import torch
from torch import nn


def bench(f, name=None, iters=1000, warmup=5, display=True, profile=False):
    from triton.testing import do_bench

    # Warm up first so one-time costs (e.g. compilation) are not measured.
    for _ in range(warmup):
        f()
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

    # do_bench reports milliseconds; convert to microseconds.
    us_per_iter = do_bench(lambda: f()) * 1000
    if name is None:
        res = us_per_iter
    else:
        res = f"{name}: {us_per_iter:.3f}us"
    if display:
        print(res)
    return res


for D in range(256, 4096 + 1, 256):
    inp = torch.randn(2048 * 128, D, dtype=torch.bfloat16, device='cuda')
    mod = nn.LayerNorm([D], dtype=torch.bfloat16, device='cuda')
    comp_mod = torch.compile(mod, dynamic=False)
    print(D)
    bench(lambda: mod(inp))       # eager
    bench(lambda: comp_mod(inp))  # compiled
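For context on the term used above: a Welford reduction computes the mean and variance in a single pass over the data, whereas the non-Welford path computes them as separate sum-style reductions. The snippet below is only a minimal pure-Python sketch of the two approaches, not the Triton code Inductor generates. The function names `welford_mean_var` and `two_pass_mean_var` are made up for illustration.

```python
import math


def welford_mean_var(values):
    """Single-pass mean and biased (population) variance via Welford's update."""
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for x in values:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)
    return mean, m2 / count


def two_pass_mean_var(values):
    """Reference: one pass for the mean, a second pass for the variance."""
    n = len(values)
    mean = sum(values) / n
    var = sum((x - mean) ** 2 for x in values) / n
    return mean, var


row = [1.0, 2.0, 4.0, 7.0]
w_mean, w_var = welford_mean_var(row)
t_mean, t_var = two_pass_mean_var(row)
print(w_mean, w_var, t_mean, t_var)
```

Both functions return the biased variance, which is the estimator LayerNorm uses.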
cc: @peterbell10, @eellison, @shunting314
Versions
N/A
cc @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @bertmaher @int3 @davidberard98 @nmacchioni @embg @msaroufim @bdhirsh @zou3519 @aakhundov