SDPA: CUDNN backend error w/ q_seq_len = 1 #138529
Labels
module: cuda
Related to torch.cuda, and CUDA support in general
module: sdpa
All things related to torch.nn.functional.scaled_dot_product_attention
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Comments
Frontend log: frontendlog.txt, Backend log: backendlog.txt
Looks like it's not happy with sequence length 1, rather than the mismatched s_q vs. s_kv; forwarding to cuDNN...
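A minimal sketch of the failing shape pattern described above (not the attached repro script; the batch size, head count, head dim, and dtype are illustrative assumptions):

```Python
# Hypothetical repro sketch: query sequence length 1 vs. a longer key/value,
# forced onto the cuDNN SDPA backend via the context manager.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 8, 1, 64, device="cuda", dtype=torch.float16)    # s_q = 1
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)  # s_kv = 128
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)  # cuDNN backend errors when s_q == 1
```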
drisspg added a commit that referenced this issue on Oct 22, 2024
# Summary
Currently we have a `cudnn_order` that says on H100 w/ new enough CuDNN backend (we ship a 9.1 version in OSS) try to run CuDNN attention first. We have already encountered a few bugs with the release of 2.5:

1. #138529
2. huggingface/diffusers#9704
3. #138354

In light of the above we are going to make the CuDNN backend opt-in by default. This can be done easily with the context manager for choosing backends, i.e.:

```Python
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```

This PR puts the CuDNN backend as the lowest precedence in the backend list, meaning that the Math backend will always be chosen unless disabled (which is done via the context manager).

Cc atalman cc mikaylagawarecki

[ghstack-poisoned]
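As a side note on opting out: besides the per-call context manager shown in the commit message above, a process-wide toggle can be used. The sketch below assumes `torch.backends.cuda.enable_cudnn_sdp` is available in your PyTorch build.

```Python
# Sketch (assumes the enable_cudnn_sdp toggle exists in your build):
# disable the cuDNN SDPA backend globally instead of per call.
import torch

torch.backends.cuda.enable_cudnn_sdp(False)
print(torch.backends.cuda.cudnn_sdp_enabled())  # expected: False
```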
pytorchbot pushed a commit that referenced this issue on Oct 22, 2024
# Summary
Currently we have a `cudnn_order` that says on H100 w/ new enough CuDNN backend (we ship a 9.1 version in OSS) try to run CuDNN attention first. We have already encountered a few bugs with the release of 2.5:

1. #138529
2. huggingface/diffusers#9704
3. #138354

In light of the above we are going to make the CuDNN backend opt-in by default. This can be done easily with the context manager for choosing backends, i.e.:

```Python
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```

This PR puts the CuDNN backend as the lowest precedence in the backend list, meaning that the Math backend will always be chosen unless disabled (which is done via the context manager).

Cc @atalman

Pull Request resolved: #138522
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet
(cherry picked from commit 9a9a0ab)
kit1980 pushed a commit that referenced this issue on Oct 22, 2024
[SDPA-CUDNN] Make CuDNN Attention Opt in (#138522)

# Summary
Currently we have a `cudnn_order` that says on H100 w/ new enough CuDNN backend (we ship a 9.1 version in OSS) try to run CuDNN attention first. We have already encountered a few bugs with the release of 2.5:

1. #138529
2. huggingface/diffusers#9704
3. #138354

In light of the above we are going to make the CuDNN backend opt-in by default. This can be done easily with the context manager for choosing backends, i.e.:

```Python
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```

This PR puts the CuDNN backend as the lowest precedence in the backend list, meaning that the Math backend will always be chosen unless disabled (which is done via the context manager).

Cc @atalman

Pull Request resolved: #138522
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet
(cherry picked from commit 9a9a0ab)
Co-authored-by: drisspg <drisspguessous@gmail.com>
SamGinzburg pushed a commit that referenced this issue on Oct 28, 2024
# Summary
Currently we have a `cudnn_order` that says on H100 w/ new enough CuDNN backend (we ship a 9.1 version in OSS) try to run CuDNN attention first. We have already encountered a few bugs with the release of 2.5:

1. #138529
2. huggingface/diffusers#9704
3. #138354

In light of the above we are going to make the CuDNN backend opt-in by default. This can be done easily with the context manager for choosing backends, i.e.:

```Python
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```

This PR puts the CuDNN backend as the lowest precedence in the backend list, meaning that the Math backend will always be chosen unless disabled (which is done via the context manager).

Cc @atalman

Pull Request resolved: #138522
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet
pytorchmergebot pushed a commit that referenced this issue on Oct 29, 2024
Forwarded #138529 to the cuDNN team, but for now we want to avoid dispatching to unsupported cases.

Pull Request resolved: #138531
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
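Until the dispatch-side fix is in a release, a user-side workaround might look like the hedged sketch below: route queries with sequence length 1 away from the cuDNN backend. The helper name and backend list are illustrative assumptions, not PyTorch's internal dispatch logic.

```Python
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

def sdpa_skip_cudnn_for_short_query(q, k, v, **kwargs):
    # Hypothetical helper: avoid the cuDNN backend when s_q == 1,
    # the case this issue reports as unsupported.
    if q.size(-2) == 1:
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION,
                          SDPBackend.EFFICIENT_ATTENTION,
                          SDPBackend.MATH]):
            return F.scaled_dot_product_attention(q, k, v, **kwargs)
    return F.scaled_dot_product_attention(q, k, v, **kwargs)
```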
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this issue on Oct 29, 2024
Forwarded pytorch#138529 to the cuDNN team, but for now we want to avoid dispatching to unsupported cases.

Pull Request resolved: pytorch#138531
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this issue on Nov 5, 2024
Forwarded pytorch#138529 to the cuDNN team, but for now we want to avoid dispatching to unsupported cases.

Pull Request resolved: pytorch#138531
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Summary
Repro script
Error:
cc @ptrblck @msaroufim @eqy @mikaylagawarecki