vulkan: KHR_coopmat flash attention #13506
On the Radeon RX 7800 XT with the radv driver, performance with FA is now slightly faster than without FA.
ggml_vulkan: 0 = AMD Radeon RX 7800 XT (RADV NAVI32) (radv) | uma: 0 | fp16: 1 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: KHR_coopmat
With the amdvlk driver, performance with FA is also quite close to non-FA.
ggml_vulkan: 0 = AMD Radeon RX 7800 XT (AMD open-source driver) | uma: 0 | fp16: 1 | warp size: 64 | shared memory: 32768 | int dot: 1 | matrix cores: KHR_coopmat
AMD 9070XT, master -fa 1
build: de4c07f (5359)
AMD 9070XT, PR -fa 1
build: 3dc90d7 (5344)
AMD 9070XT, PR -fa 0
build: 3dc90d7 (5344)
Intel Arc A770, master -fa 1
build: de4c07f (5359)
Intel Arc A770, PR -fa 1
build: 3dc90d7 (5344)
Intel Arc A770, PR -fa 0
build: 3dc90d7 (5344)
With the -fa 0 flag, the performance of PR and master is identical, so I've only provided the values for PR with -fa 0.
RTX 3090
Coopmat2:
Coopmat1:
Scalar:
Other GPUs
I sadly don't have an AMD GPU with coopmat support, so I can't test that, but I know AMD supports the 16x16x16 shape. Intel does not, but that is not really a problem until coopmat matmul starts working on Intel. Here are the shapes supported by my Intel A770:
LGTM
This shader uses coopmat1 to do the Q*K^T multiply. The P*V multiply is more difficult for various reasons so I haven't done it. Performance for this shader is around 2.5x better than for the scalar shader when doing prompt processing. Some of the benefit may be from other optimizations like staging through shared memory, or splitting by rows.
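For readers unfamiliar with the two multiplies mentioned above, here is a minimal scalar CPU sketch of blocked attention for a single query row, using the usual online-softmax rescaling. It is not the shader from this PR: the function name `attend_row`, the row-major layout, and the block size parameter are all assumptions made for illustration. It only shows where the Q*K^T step (the part the shader moves to coopmat1) and the P*V step sit in the loop.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical reference: attention output for ONE query row.
// q holds D values; k and v hold n_kv rows of D values each (row-major).
// K/V are consumed in blocks of `block` rows, flash-attention style.
std::vector<float> attend_row(const std::vector<float>& q,
                              const std::vector<float>& k,
                              const std::vector<float>& v,
                              std::size_t n_kv, std::size_t block) {
    const std::size_t d = q.size();
    assert(d > 0 && n_kv > 0 && block > 0);
    assert(k.size() == n_kv * d && v.size() == n_kv * d);
    const float scale = 1.0f / std::sqrt(static_cast<float>(d));

    std::vector<float> out(d, 0.0f);
    float running_max = -std::numeric_limits<float>::infinity();
    float running_sum = 0.0f;

    for (std::size_t start = 0; start < n_kv; start += block) {
        const std::size_t end = std::min(start + block, n_kv);

        // Step 1: S = Q*K^T for this block (the multiply done with coopmat1 in the shader).
        std::vector<float> s(end - start);
        float block_max = running_max;
        for (std::size_t j = start; j < end; ++j) {
            float dot = 0.0f;
            for (std::size_t i = 0; i < d; ++i) dot += q[i] * k[j * d + i];
            s[j - start] = dot * scale;
            block_max = std::max(block_max, s[j - start]);
        }

        // Step 2: rescale what was accumulated so far to the new running maximum.
        const float correction = std::exp(running_max - block_max);
        running_sum *= correction;
        for (std::size_t i = 0; i < d; ++i) out[i] *= correction;

        // Step 3: P = exp(S - max), then accumulate P*V (the multiply left as-is).
        for (std::size_t j = start; j < end; ++j) {
            const float p = std::exp(s[j - start] - block_max);
            running_sum += p;
            for (std::size_t i = 0; i < d; ++i) out[i] += p * v[j * d + i];
        }
        running_max = block_max;
    }

    // Final softmax normalization.
    for (std::size_t i = 0; i < d; ++i) out[i] /= running_sum;
    return out;
}
```

The result should not depend on the block size, which is what makes it possible to split the work by rows and stage blocks through shared memory.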
The new paths are used only when the device supports the 16x16x16 coopmat shape and a subgroup size of 32. I think this is supported at least on AMD and maybe on Intel, but I don't have the hardware to check that.
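As a rough illustration of that requirement, the check could look something like the sketch below. It is not the code from this PR. `CoopmatShape` is a hypothetical stand-in for the `MSize`/`NSize`/`KSize` fields of `VkCooperativeMatrixPropertiesKHR`, and the subgroup size range is assumed to come from `minSubgroupSize`/`maxSubgroupSize` in `VkPhysicalDeviceSubgroupSizeControlProperties`; component types and scope are ignored here for brevity.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical mirror of the M/N/K sizes reported per supported coopmat shape.
struct CoopmatShape {
    uint32_t m;
    uint32_t n;
    uint32_t k;
};

// Returns true if a 16x16x16 shape is reported and a subgroup size of 32
// lies inside the device's [min, max] subgroup size range.
bool can_use_coopmat1_fa(const std::vector<CoopmatShape>& shapes,
                         uint32_t min_subgroup_size,
                         uint32_t max_subgroup_size) {
    assert(min_subgroup_size <= max_subgroup_size);

    bool has_16x16x16 = false;
    for (const CoopmatShape& s : shapes) {
        if (s.m == 16 && s.n == 16 && s.k == 16) {
            has_16x16x16 = true;
            break;
        }
    }

    const bool subgroup_32_ok =
        min_subgroup_size <= 32 && 32 <= max_subgroup_size;

    return has_16x16x16 && subgroup_32_ok;
}
```

A device failing either condition would fall back to the existing scalar path.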
Testing prompt processing with KHR_coopmat on RTX 4070, FA is a few percent faster than no FA now (rather than being several percent slower). Hopefully this is true across other hardware.
I want to move some of the duplicated code into a header file, but will do that in a later change.