Basic SDP benchmark harness by cpuhrsch · Pull Request #86729 · pytorch/pytorch · GitHub

Basic SDP benchmark harness #86729


Closed
cpuhrsch wants to merge 4 commits
Changes from 1 commit
Linting
cpuhrsch committed Oct 11, 2022
commit bcb44f5635779d1252489131bcc1a686647e7672
37 changes: 23 additions & 14 deletions benchmarks/transformer/sdp.py
@@ -1,4 +1,5 @@
 import torch
+import itertools
 import numpy as np


@@ -48,6 +49,7 @@ def forward(self, query, key, value, mask):
         # Match return signature of nn.MHA
         return self.out_proj(attn), None
 
+
 def build_composite_mha_from_nn_mha(pt):
     assert pt._qkv_same_embed_dim
     in_proj_weight = pt.in_proj_weight
@@ -70,17 +72,9 @@ def benchmark_torch_function(iters, f, *args, **kwargs):
     torch.cuda.synchronize()
     return (start_event.elapsed_time(end_event) * 1.0e-3) / iters
 
-if __name__ == "__main__":
-    iters = 100
-    seed = 123
-    np.random.seed(seed)
-    torch.manual_seed(seed)
 
-    batch_size = 64
-    D = 1024
-    H = 4
+def run_timing(batch_size, D, H, L):
     dropout_p = 0.0
-    max_sequence_length = 128
     mask = None
 
     pt = torch.nn.MultiheadAttention(
@@ -89,16 +83,31 @@ def benchmark_torch_function(iters, f, *args, **kwargs):
     npt = pt.eval().half().cuda()
     cpt = build_composite_mha_from_nn_mha(npt)
 
-    x = torch.randn(batch_size, max_sequence_length, D)
+    x = torch.randn(batch_size, L, D)
     x = x.half().cuda()
 
     pt_output, _ = pt(x, x, x, mask)
     cp_output, _ = cpt(x, x, x, mask)
 
-    assert torch.allclose(pt_output, cp_output)
+    # First order sanity check. Not a replacement for rigorous tests.
+    assert torch.allclose(pt_output, cp_output, atol=1e-3, rtol=1e-3)
 
-    with torch.inference_mode():
-        for H in [1, 2, 4, 8, 16, 32]:
+    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=True):
+        with torch.inference_mode():
             pt_time = benchmark_torch_function(iters, npt, x, x, x, mask) * 1e3
             cp_time = benchmark_torch_function(iters, cpt, x, x, x, mask) * 1e3
-            print(f"H: {H:2.0f} pt_time: {pt_time:4.2f}ms cp_time: {cp_time:4.2f}ms")
+            print(f"L: {L:4.0f} H: {H:2.0f} D: {D:4.0f} ", end="")
+            print(f"pt_time: {pt_time:4.2f}ms ", end="")
+            print(f"cp_time: {cp_time:4.2f}ms ", end="")
+            print(f"speedup: {pt_time / cp_time:4.2f}x")
+
+
+if __name__ == "__main__":
+    iters = 1
+    seed = 123
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    batch_size = 64
+    for H, L in itertools.product([1, 2, 4, 8, 16, 32], [64, 128, 256]):
+        run_timing(batch_size, 1024, H, L)
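
For reference, the hunk above only shows the tail of benchmark_torch_function (the final synchronize and the millisecond-to-second conversion). Below is a minimal sketch of how such a CUDA-event timer is typically written; the signature, the synchronize, and the return line come from the diff, while the event setup, warm-up call, and timing loop are assumptions rather than the file's actual body.

import torch

# Sketch of the timing helper; everything except the signature and the last
# two lines is assumed, not copied from benchmarks/transformer/sdp.py.
def benchmark_torch_function(iters, f, *args, **kwargs):
    f(*args, **kwargs)  # warm-up call (assumed) so one-time setup is not measured
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(iters):
        f(*args, **kwargs)
    end_event.record()
    torch.cuda.synchronize()
    # elapsed_time is reported in milliseconds; return seconds per iteration
    return (start_event.elapsed_time(end_event) * 1.0e-3) / iters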
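The commit also relaxes the output comparison to atol=1e-3, rtol=1e-3, presumably because the composite and fused fp16 attention paths can differ by small rounding error. The snippet below is purely illustrative of that tolerance choice; the tensors and the injected perturbation are made up and are not part of the benchmark.

import torch

# Two hypothetical implementations whose outputs differ by a small bounded amount.
a = torch.randn(64, 128, 1024)
b = a + 2e-4 * torch.rand_like(a)  # deviation kept well below 1e-3

assert not torch.allclose(a, b)                     # default rtol=1e-5 rejects it
assert torch.allclose(a, b, atol=1e-3, rtol=1e-3)   # the benchmark's tolerances accept it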