Commit 03e7f0b

[Inductor] Add fused_attention pattern matcher with additional clone (#108141) (#108327)
A previous PR, #106274, decomposes `aten.dropout` and creates a `clone()` when the model is in `eval()` mode or when `p=0`. This causes many SDPA-related models to fail to match the fused_attention pattern matchers. This PR adds new fused_attention pattern matchers with an additional `clone()` to re-enable SDPA op matching.

Pull Request resolved: #108141
Approved by: https://github.com/jgong5, https://github.com/eellison
1 parent c0e7239 commit 03e7f0b
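For context, a minimal sketch (not part of the commit; names are illustrative) of the kind of attention code this targets: with the decomposition from #106274, a no-op dropout lowers to a clone() between the softmax and the final matmul, which the pre-existing patterns did not account for.

import math
import torch
from torch._dynamo.utils import counters  # incremented by the fused_attention replacements

def attention_eval_mode(query, key, value):
    attn = (
        torch.matmul(query, key.transpose(-2, -1))
        .div(math.sqrt(key.shape[-1]))
        .softmax(dim=-1)
    )
    # With p=0.0 (or in eval()), this dropout decomposes to a clone() in the traced graph.
    attn = torch.nn.functional.dropout(attn, p=0.0)
    return attn.matmul(value)

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
compiled = torch.compile(attention_eval_mode)
compiled(q, k, v)
# With the new clone-aware patterns registered, this counter should now be non-zero
# (subject to the device/dtype checks applied at pattern registration).
print(counters["inductor"]["fuse_attention"])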

File tree

2 files changed: 166 additions & 0 deletions

test/inductor/test_fused_attention.py

Lines changed: 70 additions & 0 deletions
@@ -443,6 +443,58 @@ def dot_prod_attention(
 
         self._check_common(dot_prod_attention, contains=False, has_dropout=True)
 
+    @skipIfRocm
+    def _test_sdpa_rewriter_13(self):
+        def dot_prod_attention(
+            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+        ) -> torch.Tensor:
+            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
+            return (
+                torch.matmul(query, key.transpose(-2, -1))
+                .div(math.sqrt(key.shape[-1]))
+                .softmax(dim=-1)
+                .clone()
+                .matmul(value)
+            )
+
+        self._check_common(dot_prod_attention)
+        self._check_common(checkpoint_wrapper(dot_prod_attention))
+
+    @skipIfRocm
+    def _test_sdpa_rewriter_14(self):
+        def dot_prod_attention(
+            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+        ) -> torch.Tensor:
+            return (
+                torch.matmul(query, key.transpose(-2, -1))
+                .mul(1.0 / math.sqrt(key.shape[-1]))
+                .softmax(dim=-1)
+                .clone()
+                .matmul(value)
+            )
+
+        self._check_common(dot_prod_attention)
+        self._check_common(checkpoint_wrapper(dot_prod_attention))
+
+    @skipIfRocm
+    def _test_sdpa_rewriter_15(self):
+        def dot_prod_attention(
+            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+        ) -> torch.Tensor:
+            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
+            q = query.transpose(1, 2)
+            k = key.transpose(1, 2)
+            v = value.transpose(1, 2)
+            return (
+                torch.matmul(q, k.transpose(-2, -1))
+                .div(math.sqrt(key.shape[-1]))
+                .softmax(dim=-1)
+                .clone()
+                .matmul(v)
+            )
+
+        self._check_common(dot_prod_attention)
+
 
 if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_SDPA:

@@ -493,6 +545,15 @@ class SDPAPatternRewriterCudaTests(TestSDPAPatternRewriterTemplate):
         test_sdpa_rewriter_12_cuda = (
             TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
         )
+        test_sdpa_rewriter_13_cuda = (
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13
+        )
+        test_sdpa_rewriter_14_cuda = (
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
+        )
+        test_sdpa_rewriter_15_cuda = (
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
+        )
 
 
 if HAS_CPU:

@@ -517,6 +578,15 @@ class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
         test_sdpa_rewriter_12_cpu = (
             TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
         )
+        test_sdpa_rewriter_13_cpu = (
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13
+        )
+        test_sdpa_rewriter_14_cpu = (
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
+        )
+        test_sdpa_rewriter_15_cpu = (
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
+        )
 
 
 if __name__ == "__main__":
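To exercise just the new tests locally, one would typically run something like the following (assuming a pytest-based workflow; the exact invocation is not part of the commit):

pytest test/inductor/test_fused_attention.py -k "sdpa_rewriter_13 or sdpa_rewriter_14 or sdpa_rewriter_15"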

torch/_inductor/fx_passes/fuse_attention.py

Lines changed: 96 additions & 0 deletions
@@ -303,6 +303,81 @@ def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p):
     )
 
 
+def _sfdp_pattern_13(query, key, value, inv_scale):
+    # dropout would create a clone() if eval() or p = 0
+    return (
+        torch.matmul(query, key.transpose(-2, -1))
+        .div(inv_scale)
+        .softmax(dim=-1)
+        .clone()
+        .matmul(value)
+    )
+
+
+def _sfdp_replacement_13(query, key, value, inv_scale):
+    counters["inductor"]["fuse_attention"] += 1
+    return aten.scaled_dot_product_attention(
+        query.contiguous(),
+        key.contiguous(),
+        value.contiguous(),
+        attn_mask=None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=1.0 / inv_scale,
+    )
+
+
+def _sfdp_pattern_14(query, key, value, scale_factor):
+    # dropout would create a clone() if eval() or p = 0
+    return (
+        torch.matmul(query, key.transpose(-2, -1))
+        .mul(scale_factor)
+        .softmax(dim=-1)
+        .clone()
+        .matmul(value)
+    )
+
+
+def _sfdp_replacement_14(query, key, value, scale_factor):
+    counters["inductor"]["fuse_attention"] += 1
+    return aten.scaled_dot_product_attention(
+        query.contiguous(),
+        key.contiguous(),
+        value.contiguous(),
+        attn_mask=None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=scale_factor,
+    )
+
+
+def _sfdp_pattern_15(query, key, value, inv_scale):
+    # dropout would create a clone() if eval() or p = 0
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    return (
+        torch.matmul(q, k.transpose(-2, -1))
+        .div(inv_scale)
+        .softmax(dim=-1)
+        .clone()
+        .matmul(v)
+    )
+
+
+def _sfdp_replacement_15(query, key, value, inv_scale):
+    counters["inductor"]["fuse_attention"] += 1
+    return aten.scaled_dot_product_attention(
+        query.transpose(1, 2),
+        key.transpose(1, 2),
+        value.transpose(1, 2),
+        attn_mask=None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=1.0 / inv_scale,
+    )
+
+
 def _sfdp_params_check(match):
     assert all(k in match.kwargs for k in ("query", "key", "value"))
     query = match.kwargs["query"].meta["val"]

@@ -450,6 +525,27 @@ def _sfdp_init():
             d,
             _sfdp_scale_factor_check(aten.div.Tensor),
         ),
+        (
+            _sfdp_pattern_13,
+            _sfdp_replacement_13,
+            [g(), g(), g(), c()],
+            {},
+            _sfdp_scale_factor_check(aten.div.Tensor),
+        ),
+        (
+            _sfdp_pattern_14,
+            _sfdp_replacement_14,
+            [g(), g(), g(), c()],
+            {},
+            _sfdp_scale_factor_check(aten.mul.Tensor),
+        ),
+        (
+            _sfdp_pattern_15,
+            _sfdp_replacement_15,
+            [g(), g(), g(), c()],
+            {},
+            _sfdp_scale_factor_check(aten.div.Tensor),
+        ),
     ]:
         args = [*args, *workaround.values()]
         register_replacement(
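As a sanity check on the scale handling, a sketch not taken from the commit (it assumes a PyTorch build where torch.nn.functional.scaled_dot_product_attention accepts the scale argument): the div-based patterns divide the attention logits by inv_scale, so their replacements pass scale=1.0 / inv_scale, while the mul-based pattern passes scale_factor through directly.

import math
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
inv_scale = math.sqrt(k.shape[-1])

# Eager computation matching _sfdp_pattern_13 (the clone is omitted; it is a numerical no-op).
pattern_out = (
    torch.matmul(q, k.transpose(-2, -1))
    .div(inv_scale)
    .softmax(dim=-1)
    .matmul(v)
)

# SDPA call mirroring _sfdp_replacement_13.
sdpa_out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=1.0 / inv_scale
)

torch.testing.assert_close(pattern_out, sdpa_out, rtol=1e-4, atol=1e-4)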
