
vulkan: KHR_coopmat flash attention #13506


Merged

merged 1 commit on May 14, 2025

Conversation

jeffbolznv
Collaborator

This shader uses coopmat1 to do the QK^T multiply. The PV multiply is more difficult for various reasons so I haven't done it. Performance for this shader is around 2.5x better than for the scalar shader when doing prompt processing. Some of the benefit may be from other optimizations like staging through shared memory, or splitting by rows.
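
To illustrate the technique, here is a minimal sketch of a subgroup-scope Q*K^T tile multiply with GL_KHR_cooperative_matrix (the bindings, head size `D`, and offsets are placeholder assumptions, not this PR's actual shader):

```glsl
#version 450
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable

layout(local_size_x = 32) in; // one subgroup per workgroup, size 32 assumed

layout(std430, binding = 0) readonly  buffer QBuf { float16_t q[]; };
layout(std430, binding = 1) readonly  buffer KBuf { float16_t k[]; };
layout(std430, binding = 2) writeonly buffer SBuf { float s[]; };

const uint D = 128; // head dimension, a placeholder for this sketch

void main() {
    // fp32 accumulator for one 16x16 tile of S = Q*K^T
    coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> acc =
        coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);

    // Walk the head dimension in 16-wide steps, accumulating 16x16x16 MMAs.
    for (uint d = 0; d < D; d += 16) {
        coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> qm;
        coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> km;
        coopMatLoad(qm, q, d, D, gl_CooperativeMatrixLayoutRowMajor);
        // Loading row-major K through a column-major view yields K^T.
        coopMatLoad(km, k, d, D, gl_CooperativeMatrixLayoutColumnMajor);
        acc = coopMatMulAdd(qm, km, acc);
    }
    coopMatStore(acc, s, 0, 16, gl_CooperativeMatrixLayoutRowMajor);
}
```

Per the description above, only the Q*K^T product goes through coopmat in this PR; the softmax and P*V multiply remain scalar.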

This change requires the 16x16x16 shape and a subgroup size of 32 to be supported before the new paths are used. I think this is supported at least on AMD and maybe Intel, but I don't have the hardware to check.
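
For reference, whether a device exposes that shape can be checked from host code via VK_KHR_cooperative_matrix's property query; a minimal sketch (function and variable names are illustrative, not the backend's actual detection code):

```cpp
#include <vector>
#include <vulkan/vulkan.h>

// Sketch: does the device expose a 16x16x16 fp16 coopmat shape at
// subgroup scope? (The subgroup size of 32 would be checked separately,
// e.g. via VkPhysicalDeviceSubgroupProperties.)
bool coopmat_16x16x16_fp16_supported(VkInstance inst, VkPhysicalDevice dev) {
    auto get_props = (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)
        vkGetInstanceProcAddr(inst, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
    if (!get_props) {
        return false; // VK_KHR_cooperative_matrix not available
    }

    uint32_t count = 0;
    get_props(dev, &count, nullptr);
    std::vector<VkCooperativeMatrixPropertiesKHR> props(count);
    for (auto &p : props) {
        p.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
    }
    get_props(dev, &count, props.data());

    // Scan the advertised shapes for the one the new shader path needs.
    for (const auto &p : props) {
        if (p.MSize == 16 && p.NSize == 16 && p.KSize == 16 &&
            p.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
            p.BType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
            p.scope == VK_SCOPE_SUBGROUP_KHR) {
            return true;
        }
    }
    return false;
}
```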

Testing prompt processing with KHR_coopmat on an RTX 4070, FA is now a few percent faster than no FA (rather than several percent slower). Hopefully this holds across other hardware.

I want to move some of the duplicated code into a header file, but will do that in a later change.

@jeffbolznv jeffbolznv requested a review from 0cc4m May 13, 2025 12:42
@github-actions github-actions bot added the Vulkan (Issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels May 13, 2025
@daniandtheweb
Contributor

On the Radeon RX 7800 XT, performance on the radv driver is now slightly faster with FA than without.

ggml_vulkan: 0 = AMD Radeon RX 7800 XT (RADV NAVI32) (radv) | uma: 0 | fp16: 1 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: KHR_coopmat

| model | size | params | backend | ngl | fa | test | t/s |
| ------------- | -------: | ------: | ------- | --: | -: | ----: | --------------: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 100 | 0 | pp512 | 1236.34 ± 24.93 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 100 | 0 | tg128 | 111.82 ± 0.75 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 100 | 1 | pp512 | 1250.67 ± 9.72 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 100 | 1 | tg128 | 114.26 ± 0.54 |

On the amdvlk driver, FA performance is also quite close to non-FA.

ggml_vulkan: 0 = AMD Radeon RX 7800 XT (AMD open-source driver) | uma: 0 | fp16: 1 | warp size: 64 | shared memory: 32768 | int dot: 1 | matrix cores: KHR_coopmat

| model | size | params | backend | ngl | fa | test | t/s |
| ------------- | -------: | ------: | ------- | --: | -: | ----: | --------------: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 100 | 0 | pp512 | 2095.51 ± 12.54 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 100 | 0 | tg128 | 97.08 ± 2.10 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 100 | 1 | pp512 | 1960.76 ± 9.44 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 100 | 1 | tg128 | 95.36 ± 0.11 |

@characharm
Contributor
characharm commented May 13, 2025

AMD 9070XT, master -fa 1

| model | size | params | backend | ngl | fa | test | t/s |
| -------------- | -------: | ------: | ---------- | --: | -: | ----: | -------------: |
| qwen2 14B Q4_0 | 7.95 GiB | 14.77 B | RPC,Vulkan | 99 | 1 | pp512 | 1044.03 ± 5.71 |
| qwen2 14B Q4_0 | 7.95 GiB | 14.77 B | RPC,Vulkan | 99 | 1 | tg128 | 46.39 ± 0.27 |

build: de4c07f (5359)

AMD 9070XT, PR -fa 1

| model | size | params | backend | ngl | fa | test | t/s |
| -------------- | -------: | ------: | ------- | --: | -: | ----: | --------------: |
| qwen2 14B Q4_0 | 7.95 GiB | 14.77 B | Vulkan | 99 | 1 | pp512 | 1745.20 ± 28.09 |
| qwen2 14B Q4_0 | 7.95 GiB | 14.77 B | Vulkan | 99 | 1 | tg128 | 50.03 ± 0.22 |

build: 3dc90d7 (5344)

AMD 9070XT, PR -fa 0

| model | size | params | backend | ngl | test | t/s |
| -------------- | -------: | ------: | ------- | --: | ----: | -------------: |
| qwen2 14B Q4_0 | 7.95 GiB | 14.77 B | Vulkan | 99 | pp512 | 2007.45 ± 5.23 |
| qwen2 14B Q4_0 | 7.95 GiB | 14.77 B | Vulkan | 99 | tg128 | 59.08 ± 0.21 |

build: 3dc90d7 (5344)

Intel Arc A770, master -fa 1

| model | size | params | backend | ngl | fa | test | t/s |
| -------------- | -------: | ------: | ---------- | --: | -: | ----: | ------------: |
| qwen2 14B Q4_0 | 7.95 GiB | 14.77 B | RPC,Vulkan | 99 | 1 | pp512 | 172.49 ± 0.10 |
| qwen2 14B Q4_0 | 7.95 GiB | 14.77 B | RPC,Vulkan | 99 | 1 | tg128 | 20.18 ± 0.02 |

build: de4c07f (5359)

Intel Arc A770, PR -fa 1

| model | size | params | backend | ngl | fa | test | t/s |
| -------------- | -------: | ------: | ------- | --: | -: | ----: | ------------: |
| qwen2 14B Q4_0 | 7.95 GiB | 14.77 B | Vulkan | 99 | 1 | pp512 | 172.14 ± 0.19 |
| qwen2 14B Q4_0 | 7.95 GiB | 14.77 B | Vulkan | 99 | 1 | tg128 | 20.17 ± 0.01 |

build: 3dc90d7 (5344)

Intel Arc A770, PR -fa 0

| model | size | params | backend | ngl | test | t/s |
| -------------- | -------: | ------: | ------- | --: | ----: | ------------: |
| qwen2 14B Q4_0 | 7.95 GiB | 14.77 B | Vulkan | 99 | pp512 | 368.90 ± 1.04 |
| qwen2 14B Q4_0 | 7.95 GiB | 14.77 B | Vulkan | 99 | tg128 | 25.00 ± 0.02 |

build: 3dc90d7 (5344)

With the -fa 0 flag, the performance of PR and master is identical, so I've only provided the values for PR with -fa 0.
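
(For context: these tables are llama-bench markdown output; an invocation along these lines, with a hypothetical model path, produces the pp512/tg128 rows above:)

```sh
# Hypothetical model path; -fa 0,1 benchmarks both without and with
# flash attention, and -ngl 99 offloads all layers to the GPU.
./llama-bench -m qwen2-14b-q4_0.gguf -ngl 99 -fa 0,1
```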

@0cc4m
Collaborator
0cc4m commented May 14, 2025

RTX 3090

Coopmat2:

| model | size | params | backend | ngl | fa | test | t/s |
| ------------- | -------: | ------: | ------- | --: | -: | ----: | --------------: |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | 0 | pp512 | 4278.92 ± 71.46 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | 0 | tg128 | 103.23 ± 5.64 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | 1 | pp512 | 4569.20 ± 11.59 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | 1 | tg128 | 105.69 ± 0.16 |

Coopmat1:

| model | size | params | backend | ngl | fa | test | t/s |
| ------------- | -------: | ------: | ------- | --: | -: | ----: | --------------: |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | 0 | pp512 | 3156.83 ± 7.41 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | 0 | tg128 | 103.27 ± 5.65 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | 1 | pp512 | 3166.53 ± 24.35 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | 1 | tg128 | 101.58 ± 0.14 |

Scalar:

| model | size | params | backend | ngl | fa | test | t/s |
| ------------- | -------: | ------: | ------- | --: | -: | ----: | -------------: |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | 0 | pp512 | 1925.90 ± 7.29 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | 0 | tg128 | 103.92 ± 4.39 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | 1 | pp512 | 1912.57 ± 4.31 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | 1 | tg128 | 104.93 ± 2.07 |

Other GPUs

Sadly, I don't have an AMD GPU with coopmat support, so I can't test that, but I know AMD supports the 16x16x16 shape. Intel does not, but that is not really a problem until coopmat matmul starts working on Intel.

Here are the shapes supported by my Intel A770:

M: 8 N: 8 K: 16 A: Float16 B: Float16 C: Float32 Result: Float32 saturatingAccumulation: 0 scope: Subgroup
M: 8 N: 8 K: 32 A: Sint8 B: Sint8 C: Sint32 Result: Sint32 saturatingAccumulation: 0 scope: Subgroup
M: 8 N: 8 K: 32 A: Uint8 B: Uint8 C: Uint32 Result: Uint32 saturatingAccumulation: 0 scope: Subgroup

Collaborator

@0cc4m 0cc4m left a comment

LGTM

This shader uses coopmat1 to do the Q*K^T multiply. The P*V multiply is more
difficult for various reasons so I haven't done it. Performance for this
shader is around 2.5x better than for the scalar shader when doing prompt
processing. Some of the benefit may be from other optimizations like staging
through shared memory, or splitting by rows.
@0cc4m 0cc4m merged commit 24e86ca into ggml-org:master May 14, 2025
44 checks passed
Silver267 pushed a commit to Silver267/llama.cpp that referenced this pull request May 14, 2025