SDPA: CUDNN backend error w/ q_seq_len = 1 #138529

Open
drisspg opened this issue Oct 22, 2024 · 2 comments
Labels
module: cuda (Related to torch.cuda, and CUDA support in general)
module: sdpa (All things related to torch.nn.functional.scaled_dot_product_attention)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@drisspg
Contributor
drisspg commented Oct 22, 2024

Summary

Repro script

import torch
import torch.nn as nn
import torch.nn.functional as F


q = torch.randn(1, 16, 1, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(1, 16, 2**16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(1, 16, 2**16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
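# Note: q has seq_len 1 while k and v have seq_len 2**16; the q_seq_len == 1 case is what trips the cuDNN backend.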


from torch.nn.attention import sdpa_kernel, SDPBackend    

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
    out.backward(torch.ones_like(out))

Error:

/home/drisspg/meta/pytorch/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got an innermost stride of 0 in grad_out, which is unsupported. Materializing a contiguous tensor which will increase memory usage... (Triggered internally at /home/drisspg/meta/pytorch/aten/src/ATen/native/cudnn/MHA.cpp:664.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/home/drisspg/meta/scripts/sdpa/repro_gqa.py", line 15, in <module>
    out.sum().backward()
  File "/home/drisspg/meta/pytorch/torch/_tensor.py", line 624, in backward
    torch.autograd.backward(
  File "/home/drisspg/meta/pytorch/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/drisspg/meta/pytorch/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans support the graph.
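
Until this is fixed, one workaround is to route this shape to a different SDPA backend. A minimal sketch under the same setup as the repro (choosing flash/efficient attention as the fallback is my assumption, not something prescribed in this issue):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 16, 1, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(1, 16, 2**16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(1, 16, 2**16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Allow flash and memory-efficient attention for this shape instead of cuDNN.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)
    out.backward(torch.ones_like(out))
```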

cc @ptrblck @msaroufim @eqy @mikaylagawarecki

@drisspg drisspg added the module: cuda (Related to torch.cuda, and CUDA support in general) and module: multi-headed-attention labels Oct 22, 2024
@drisspg
Contributor Author
drisspg commented Oct 22, 2024

Frontend log: frontendlog.txt

Backend log: backendlog.txt

@eqy
Collaborator
eqy commented Oct 22, 2024

Looks like it's not happy with sequence length 1 rather than the mismatched s_q vs. s_kv; forwarding to the cuDNN team...
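
A caller-side guard along those lines keeps small-query shapes away from cuDNN until the dispatcher does this itself. This is a sketch: the helper name is made up, and the choice of fallback backends is an assumption, not something stated in this thread.

```python
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

def sdpa_avoid_cudnn_for_short_queries(q, k, v, **kwargs):
    # Hypothetical helper: skip cuDNN attention when the query sequence
    # length is 1, the case this issue reports as unsupported.
    if q.size(-2) == 1:
        backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
    else:
        backends = [SDPBackend.CUDNN_ATTENTION]
    with sdpa_kernel(backends):
        return F.scaled_dot_product_attention(q, k, v, **kwargs)
```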

@drisspg drisspg changed the title from "CuDNN Backend cuDNN Frontend error" to "SDPA: CUDNN backend error w/ q_seq_len = 1" Oct 22, 2024
drisspg added a commit that referenced this issue Oct 22, 2024
# Summary
Currently we have a `cudnn_order` that says that on H100, with a new enough cuDNN backend (we ship a 9.1 version in OSS), we try to run cuDNN attention first. We have already encountered a few bugs with the release of 2.5:

1. #138529
2. huggingface/diffusers#9704
3. #138354

In light of the above, we are going to make the cuDNN backend opt-in rather than on by default.

This can be done easily with the context manager for choosing backends, i.e.:
```python
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```

This PR puts the cuDNN backend at the lowest precedence in the backend list, meaning that the Math backend will always be chosen ahead of it unless the other backends are explicitly disabled via the context manager, i.e. cuDNN attention now has to be opted into.


Cc atalman

cc mikaylagawarecki

[ghstack-poisoned]
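
For reference, `sdpa_kernel` also accepts a list of backends, so callers who opt into cuDNN attention can keep other backends enabled as fallbacks. A sketch, not taken from the PR; `q`, `k`, and `v` are assumed to be defined as in the repro above:

```python
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Opt in to cuDNN attention while keeping flash attention enabled as a
# fallback for inputs cuDNN declines to handle.
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)
```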
pytorchbot pushed a commit that referenced this issue Oct 22, 2024
Pull Request resolved: #138522
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet

(cherry picked from commit 9a9a0ab)
kit1980 pushed a commit that referenced this issue Oct 22, 2024
[SDPA-CUDNN] Make CuDNN Attention Opt in (#138522)

Pull Request resolved: #138522
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet

(cherry picked from commit 9a9a0ab)

Co-authored-by: drisspg <drisspguessous@gmail.com>
@bdhirsh bdhirsh added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Oct 23, 2024
SamGinzburg pushed a commit that referenced this issue Oct 28, 2024
Pull Request resolved: #138522
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet
pytorchmergebot pushed a commit that referenced this issue Oct 29, 2024
Forwarded #138529 to the cuDNN team, but for now we want to avoid dispatching to unsupported cases.

Pull Request resolved: #138531
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this issue Oct 29, 2024
Pull Request resolved: pytorch#138531
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this issue Nov 5, 2024
Pull Request resolved: pytorch#138531
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
@drisspg drisspg added the module: sdpa (All things related to torch.nn.functional.scaled_dot_product_attention) label and removed the module: multi-headed-attention label Nov 27, 2024