[not for review] benchmark script by bobrenjc93 · Pull Request #152596 · pytorch/pytorch

[not for review] benchmark script #152596


Closed
bobrenjc93 wants to merge 11 commits
48 changes: 48 additions & 0 deletions mg1.py
@@ -0,0 +1,48 @@
from triton.testing import do_bench

import torch
from torch._inductor.utils import fresh_inductor_cache


@torch.compile(
    options={
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
    },
    dynamic=False,
)
def inductor_matmul(m, a, b):
    torch._check(a.shape[0] == b.shape[1])
    # passing in m to have different compiled regions
    return (m, torch.mm(a, b))


for m in [2, 4, 8, 16]:
    with fresh_inductor_cache():
        print(f"m={m}")
        k = 1280
        dynamic_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
        static_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
        dynamic_b = torch.randn(k, m, device="cuda", dtype=torch.bfloat16)
        static_b = torch.randn(k, m, device="cuda", dtype=torch.bfloat16)
        torch._dynamo.decorators.mark_dynamic(
            dynamic_a,  # s0
            0,
            backend_specializations=[
                # hint, specialization
                (2, lambda x0: x0 == 2),
                (4, lambda x0: x0 == 4),
                (8, lambda x0: x0 == 8),
                (16, lambda x0: x0 == 16),
            ],
        )
        torch._dynamo.decorators.mark_dynamic(
            dynamic_b,
            1,
        )
        # Warm up (compile) once, then benchmark the static-shape path.
        inductor_matmul(m, static_a, static_b)
        ms = do_bench(lambda: inductor_matmul(m, static_a, static_b))
        print("static ms taken:", ms)
        # Warm up (compile) once, then benchmark the dynamic-shape path.
        inductor_matmul(m, dynamic_a, dynamic_b)
        ms = do_bench(lambda: inductor_matmul(m, dynamic_a, dynamic_b))
        print("dynamic ms taken:", ms)
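For context: the script above exercises a backend_specializations argument to mark_dynamic that this PR series experiments with; the stock public API takes only the tensor and the dimension index. A minimal sketch of that baseline dynamic-shapes path, assuming only the standard torch._dynamo.mark_dynamic(tensor, dim) signature and no specialization hints (function and variable names here are illustrative), might look like:

import torch

@torch.compile(dynamic=False)
def mm(a, b):
    return torch.mm(a, b)

k = 1280
a = torch.randn(4, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(k, 4, device="cuda", dtype=torch.bfloat16)

# Mark dim 0 of `a` as dynamic so its size becomes symbolic;
# no per-size backend specializations are registered here.
torch._dynamo.mark_dynamic(a, 0)
mm(a, b)  # compiles a kernel with a symbolic size for dim 0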