[ROCm] CK Flash Attention Backend by xw285cornell · Pull Request #143695 · pytorch/pytorch · GitHub

[ROCm] CK Flash Attention Backend #143695


Closed · 39 commits

Conversation

@xw285cornell (Contributor) commented on Dec 21, 2024:

Replaces #138947 for re-import.

Replaces ROCm#1592

This PR contains the initial implementation of SDPA with the composable_kernel (CK) backend. The CK path can be forced by calling torch.backends.cuda.preferred_rocm_fa_library("ck"). Similarly, you can force the incumbent aotriton implementation by passing in "aotriton" or "default". As you'd expect, not setting this option results in aotriton being used as the backend. In the case of CK, if PyTorch deems flash attention usable, it will use the CK path in all the same places aotriton would have been used. This PR makes no changes to the heuristics that select which attention scheme to use (i.e. flash attention vs. memory-efficient attention vs. math). The CK path is only invoked when flash attention is both enabled at build time (via USE_FLASH_ATTENTION) and selected at runtime by the existing heuristics.

Files located at pytorch/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha* have been pulled from https://github.com/Dao-AILab/flash-attention, courtesy of the hard work of @tridao, who is credited as a co-author.

NOTE: To use this backend, the user MUST set USE_CK_FLASH_ATTENTION=1 in their environment when building PyTorch.
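For illustration, a minimal sketch of the runtime selection described above, assuming a ROCm build of PyTorch compiled with both USE_FLASH_ATTENTION=1 and USE_CK_FLASH_ATTENTION=1; the tensor shapes are arbitrary and not part of this PR:

```python
import torch
import torch.nn.functional as F

# Force the composable_kernel (CK) flash attention path on ROCm.
# Passing "aotriton" or "default" instead keeps the incumbent aotriton backend.
torch.backends.cuda.preferred_rocm_fa_library("ck")

# A plain SDPA call: if the existing heuristics enable and select flash
# attention, the CK kernels are used everywhere aotriton would have been.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
out = F.scaled_dot_product_attention(q, k, v)
```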

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

pytorch-bot (bot) commented on Dec 21, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143695

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 9f5531f with merge base bb5e439:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot (bot) added the ciflow/inductor, ciflow/rocm, module: dynamo, module: rocm, and release notes: distributed (c10d) labels on Dec 21, 2024.
@facebook-github-bot (Contributor) commented:

@xw285cornell has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

pytorch-bot (bot) added the ciflow/trunk label on Dec 21, 2024.
xw285cornell added the topic: not user facing label on Dec 21, 2024.
@xw285cornell (Contributor, Author) commented:

@jithunnair-amd the CUDA build keeps timing out; do you know what's going on?

@LunNova commented on Dec 22, 2024:

On a 128-thread build with a ROCm 6.3 stack on an otherwise idle EPYC Milan system:

- Build with USE_CK_FLASH_ATTENTION=1 for PYTORCH_ROCM_ARCH=gfx908;gfx90a took ~42 min
- Build with USE_CK_FLASH_ATTENTION unset for PYTORCH_ROCM_ARCH=gfx908;gfx90a took ~19.5 min
- Build without this PR for PYTORCH_ROCM_ARCH=gfx908;gfx90a took ~19.3 min
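(For illustration only: a hedged sketch of scripting such a timed source build from Python. The MAX_JOBS value, checkout path, and setup.py invocation below are assumptions for the sketch, not part of the timing report above.)

```python
import os
import subprocess
import time

# Assumed: a pytorch checkout at ./pytorch with a ROCm toolchain installed.
env = dict(
    os.environ,
    USE_CK_FLASH_ATTENTION="1",        # drop this line to build without CK FA
    PYTORCH_ROCM_ARCH="gfx908;gfx90a",
    MAX_JOBS="128",                    # match the 128-thread build above
)

start = time.time()
subprocess.run(["python", "setup.py", "develop"], cwd="pytorch", env=env, check=True)
print(f"build took {(time.time() - start) / 60:.1f} min")
```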

@jithunnair-amd (Collaborator) commented:

@pytorchbot rebase

@pytorchmergebot (Collaborator) commented:

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator) commented:

Successfully rebased rocm_ck_sdpa onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout rocm_ck_sdpa && git pull --rebase)

xw285cornell requested a review from albanD on January 3, 2025 at 06:38.
@xw285cornell (Contributor, Author) commented:

@albanD any chance you can give an exception to this PR? It adds the SDPA kernel instances into the codebase (we take a similar approach for NVIDIA's flash attention), and we will move to a pre-built binary rather than building from source (for OSS) in the near future.

@facebook-github-bot (Contributor) commented:

@xw285cornell has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@malfet (Contributor) left a comment:


This PR makes me sad, but it is probably fine as a temporary solution.

@facebook-github-bot (Contributor) commented:

@pytorchbot merge -i

(Initiating merge automatically since Phabricator Diff has merged, merging with -i because oss signals were bypassed internally)


@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged while ignoring the following check: Lint / pr-sanity-checks

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@jnolck commented on Jan 6, 2025:

This PR makes me sad, but it is probably fine as a temporary solution.

Me too. I thought this PR included a new CK-based FA implementation for Navi cards. Even getting the old howiejay implementation in here would have been nice.

caaatch22 pushed a commit to caaatch22/pytorch that referenced this pull request Jan 6, 2025
Pull Request resolved: pytorch#143695
Approved by: https://github.com/malfet

Co-authored-by: Andy Lugo <Andy.LugoReyes@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
jithunnair-amd added a commit to ROCm/pytorch that referenced this pull request Feb 19, 2025
jithunnair-amd added a commit to ROCm/pytorch that referenced this pull request Feb 20, 2025
jithunnair-amd added a commit to ROCm/pytorch that referenced this pull request Feb 20, 2025
alugorey added a commit to alugorey/pytorch that referenced this pull request Mar 10, 2025
alugorey added a commit to alugorey/pytorch that referenced this pull request Mar 20, 2025
alugorey added a commit to alugorey/pytorch that referenced this pull request Mar 24, 2025
Labels
ciflow/inductor, ciflow/rocm, ciflow/trunk, Merged, module: dynamo, module: rocm, release notes: distributed (c10d), skip-pr-sanity-checks, topic: not user facing
8 participants