vulkan: readd GGML_VULKAN_PERF #13761


Closed
netrunnereve wants to merge 2 commits from the vulkan_perf branch

Conversation

netrunnereve (Collaborator)

This re-adds the GGML_VULKAN_PERF feature after it was removed in #9118. It's set up to submit the ops individually, like GGML_VULKAN_CHECK_RESULTS does (I think I did that right).

./bin/llama-bench -m llama-2-7b.Q4_0.gguf -r 1 -p 512 -n 2
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon RX 470 Graphics (RADV POLARIS10) (radv) | uma: 0 | fp16: 0 | warp size: 64 | shared memory: 65536 | int dot: 0 | matrix cores: none
| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
----------------
Vulkan Timings:
ADD: 64 x 263.644 ms
CONT: 32 x 211.384 ms
CPY: 64 x 1339.26 ms
GET_ROWS: 2 x 87.714 ms
MUL: 97 x 438.105 ms
MUL_MAT m=11008 n=512 k=4096: 62 x 14361.3 ms
MUL_MAT m=128 n=512 k=512: 32 x 2424.7 ms
MUL_MAT m=4096 n=512 k=11008: 31 x 15690.5 ms
MUL_MAT m=4096 n=512 k=4096: 128 x 6855.94 ms
MUL_MAT m=512 n=512 k=128: 32 x 2358.41 ms
MUL_MAT_VEC m=11008 k=4096: 2 x 200.069 ms
MUL_MAT_VEC m=32000 k=4096: 1 x 727.618 ms
MUL_MAT_VEC m=4096 k=11008: 1 x 211.537 ms
RMS_NORM: 65 x 214.698 ms
ROPE: 64 x 362.017 ms
SILU: 32 x 330.783 ms
SOFT_MAX: 32 x 870.034 ms
----------------
Vulkan Timings:
ADD: 64 x 277.48 ms
CONT: 32 x 253.444 ms
CPY: 64 x 1323.65 ms
GET_ROWS: 2 x 69.106 ms
MUL: 97 x 441.004 ms
MUL_MAT m=11008 n=512 k=4096: 62 x 14390.8 ms
MUL_MAT m=128 n=512 k=512: 32 x 2845.28 ms
MUL_MAT m=4096 n=512 k=11008: 31 x 16561.4 ms
MUL_MAT m=4096 n=512 k=4096: 128 x 7122.14 ms
MUL_MAT m=512 n=512 k=128: 32 x 2458.84 ms
MUL_MAT_VEC m=11008 k=4096: 2 x 186.524 ms
MUL_MAT_VEC m=32000 k=4096: 1 x 702.008 ms
MUL_MAT_VEC m=4096 k=11008: 1 x 198.647 ms
RMS_NORM: 65 x 212.158 ms
ROPE: 64 x 303.754 ms
SILU: 32 x 311.434 ms
SOFT_MAX: 32 x 1235.61 ms
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |           pp512 |        187.09 ± 0.00 |
----------------
Vulkan Timings:
ADD: 64 x 36.659 ms
CONT: 32 x 35.54 ms
CPY: 64 x 35.763 ms
GET_ROWS: 2 x 33.407 ms
MUL: 97 x 37.877 ms
MUL_MAT_VEC m=11008 k=4096: 64 x 180.476 ms
MUL_MAT_VEC m=128 k=32: 32 x 50.035 ms
MUL_MAT_VEC m=32 k=128: 32 x 39.315 ms
MUL_MAT_VEC m=32000 k=4096: 1 x 666.888 ms
MUL_MAT_VEC m=4096 k=11008: 32 x 191.914 ms
MUL_MAT_VEC m=4096 k=4096: 128 x 96.082 ms
RMS_NORM: 65 x 46.071 ms
ROPE: 64 x 37.849 ms
SILU: 32 x 34.388 ms
SOFT_MAX: 32 x 35.824 ms
----------------
Vulkan Timings:
ADD: 64 x 39.125 ms
CONT: 32 x 36.44 ms
CPY: 64 x 35.159 ms
GET_ROWS: 2 x 33.951 ms
MUL: 97 x 35.89 ms
MUL_MAT_VEC m=11008 k=4096: 64 x 179.759 ms
MUL_MAT_VEC m=128 k=32: 32 x 56.564 ms
MUL_MAT_VEC m=32 k=128: 32 x 37.74 ms
MUL_MAT_VEC m=32000 k=4096: 1 x 664.134 ms
MUL_MAT_VEC m=4096 k=11008: 32 x 191.265 ms
MUL_MAT_VEC m=4096 k=4096: 128 x 95.2 ms
RMS_NORM: 65 x 43.697 ms
ROPE: 64 x 36.036 ms
SILU: 32 x 34.948 ms
SOFT_MAX: 32 x 36.142 ms
----------------
Vulkan Timings:
ADD: 64 x 38.788 ms
CONT: 32 x 36.099 ms
CPY: 64 x 37.976 ms
GET_ROWS: 2 x 33.572 ms
MUL: 97 x 40.582 ms
MUL_MAT_VEC m=11008 k=4096: 64 x 181.708 ms
MUL_MAT_VEC m=128 k=32: 32 x 50.636 ms
MUL_MAT_VEC m=32 k=128: 32 x 40.164 ms
MUL_MAT_VEC m=32000 k=4096: 1 x 666.921 ms
MUL_MAT_VEC m=4096 k=11008: 32 x 198.705 ms
MUL_MAT_VEC m=4096 k=4096: 128 x 97.665 ms
RMS_NORM: 65 x 46.091 ms
ROPE: 64 x 41.26 ms
SILU: 32 x 42.978 ms
SOFT_MAX: 32 x 36.567 ms
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |             tg2 |         17.35 ± 0.00 |

build: 596669e3 (5471)

The tests are passing with GGML_VULKAN_PERF turned on.
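
For readers skimming the diff, the host-side approach here boils down to timing each op's submit-and-wait on the CPU and aggregating by an op key (e.g. "MUL_MAT m=4096 n=512 k=4096"). A minimal sketch of that idea, with hypothetical names (vk_perf_time_op, vk_perf_print) rather than the actual ggml-vulkan symbols, and microseconds as the unit:

```cpp
#include <chrono>
#include <cstdio>
#include <map>
#include <string>

// Per-op stats: how many times the op ran and the summed wall time.
struct vk_perf_entry { size_t count = 0; double total_us = 0.0; };
static std::map<std::string, vk_perf_entry> vk_perf_data;

// Time one op by submitting it on its own and blocking until it completes,
// so the measurement covers exactly this op plus the submission overhead
// (which is why small ops look slow with this method, see below).
template <typename F>
static void vk_perf_time_op(const std::string & key, F && submit_and_wait) {
    const auto t0 = std::chrono::high_resolution_clock::now();
    submit_and_wait();
    const auto t1 = std::chrono::high_resolution_clock::now();
    vk_perf_entry & e = vk_perf_data[key];
    e.count++;
    e.total_us += std::chrono::duration<double, std::micro>(t1 - t0).count();
}

// Print the per-op table in the same shape as the output above, where the
// count is followed by the average time per occurrence.
static void vk_perf_print() {
    std::printf("----------------\nVulkan Timings:\n");
    for (const auto & [key, e] : vk_perf_data) {
        std::printf("%s: %zu x %.3f us\n", key.c_str(), e.count, e.total_us / e.count);
    }
    vk_perf_data.clear();
}
```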

@github-actions bot added the Vulkan (Issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels on May 24, 2025
@jeffbolznv (Collaborator)

The changes look reasonable to me. But is there a reason to prefer this over test-backend-ops perf? Having the results in TFLOPS or GB/s is easier to interpret.

@netrunnereve (Collaborator, Author)

> But is there a reason to prefer this over test-backend-ops perf?

This basically summarizes what's happening inside a real model, whereas test-backend-ops is more for checking specific ops. I get to see the exact ops that are run along with the matrix sizes and all that. It's also an easy way to make sure that most of the time is spent in mat mul or mat vec rather than doing something else.

@0cc4m self-requested a review on May 25, 2025 08:17
@jeffbolznv (Collaborator)

I think the overhead of submitting each node will affect the results, particularly for small nodes. Timestamp queries should be able to give more precise results. Can you try this branch? master...jeffbolznv:llama.cpp:query_pool. (I also noticed the units were wrong while doing this).
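
For reference, a timestamp query makes the GPU write its tick counter into a query-pool slot when a given pipeline stage drains, so per-op device time can be read back without changing how the graph is submitted. A rough sketch of the mechanics using vulkan.hpp (the binding the backend uses); this is illustrative rather than the code in the linked branch, and pool reuse plus the actual dispatch recording are elided:

```cpp
#include <vulkan/vulkan.hpp>
#include <cstdint>

// Record one timestamp before and one after a dispatch, then read the two
// tick values back once the command buffer has executed on the GPU.
double dispatch_time_us(vk::Device device, vk::CommandBuffer cmd,
                        float timestamp_period /* ns per tick, from device limits */) {
    // One pool with two query slots: before and after the dispatch.
    vk::QueryPool pool = device.createQueryPool({ {}, vk::QueryType::eTimestamp, 2 });

    cmd.resetQueryPool(pool, 0, 2);
    cmd.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, pool, 0);
    // ... record the compute dispatch for the op here ...
    cmd.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, pool, 1);

    // (submit cmd and wait for completion, elided)

    uint64_t ticks[2];
    (void) device.getQueryPoolResults(pool, 0, 2, sizeof(ticks), ticks,
        sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
    device.destroyQueryPool(pool);

    // Ticks are device-specific units; timestampPeriod converts them to ns.
    return double(ticks[1] - ticks[0]) * timestamp_period / 1000.0; // -> us
}
```

The pipeline-stage argument (eAllCommands here) is what the eComputeShader/eBottomOfPipe suggestion further down swaps out.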

@netrunnereve (Collaborator, Author)

I basically tested my PR by running test-backend-ops and comparing the numbers. So for mul_mat it looks something like this:

  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):             ----------------
Vulkan Timings:
MUL_MAT m=4096 n=512 k=14336: 1 x 21641.5 us
...
Vulkan Timings:
MUL_MAT m=4096 n=512 k=14336: 2 x 24416.2 us
----------------
Vulkan Timings:
MUL_MAT m=4096 n=512 k=14336: 2 x 21720.2 us
          46 runs - 22667.09 us/run -  60.13 GFLOP/run -   2.65 TFLOPS

I'm not sure why it shows a 2x at the end, but the times per run are pretty close. For short prompts or text generation it definitely runs way slower than it should, and as you said it's probably due to the fact that we're submitting one op at a time.

With your branch everything runs at its normal speed, but the numbers are way off. I would be ecstatic if my GPU could do a 4096x512x14336 matrix multiplication in 618 us, but sadly that's not the case 😞

  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):             ----------------
Vulkan Timings:
MUL_MAT m=4096 n=512 k=14336: 1 x 536.504 us
...
Vulkan Timings:
MUL_MAT m=4096 n=512 k=14336: 2 x 618.112 us
          46 runs - 22581.57 us/run -  60.13 GFLOP/run -   2.66 TFLOPS

@jeffbolznv (Collaborator)

Strange, I get accurate-looking numbers from the timestamps. Can you check if it's any better with eComputeShader or eBottomOfPipe? It sounds like a driver bug...

@netrunnereve (Collaborator, Author)

> Can you check if it's any better with eComputeShader or eBottomOfPipe?

Nope, nothing changed when I replaced eAllCommands with those.

> It sounds like a driver bug

It's worth mentioning that my Intel integrated chip also shows those really low timing numbers, so if it's a bug it's a pretty widespread one. Then again, this might be a Mesa thing.

@jeffbolznv (Collaborator)

Oh, I forgot to multiply by the timestampPeriod (which is 1.0 on NVIDIA), maybe that explains it. Can you try the latest again?
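
For reference, raw query values are in device-specific ticks and VkPhysicalDeviceLimits::timestampPeriod gives nanoseconds per tick, so the fix is a single multiply on the readback path, something like this (physical_device and ticks reuse the names from the sketch above):

```cpp
// timestampPeriod is ns per tick: 1.0 on NVIDIA, larger on other vendors.
const vk::PhysicalDeviceProperties props = physical_device.getProperties();
const double elapsed_ns = double(ticks[1] - ticks[0]) * props.limits.timestampPeriod;
```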

@jeffbolznv (Collaborator)

Sorry, I failed to push the branch when I commented a half hour ago, it's updated now.

@0cc4m (Collaborator) commented May 26, 2025

Yeah, that was it. timestampPeriod is 40 for AMD and 52.0833 for Intel. Now the numbers look reasonable.

Interestingly, my "crude" method of measuring still gives the correct general direction, at least for MUL_MAT(_VEC). For smaller ops it overestimates a lot due to submission/synchronization overhead. Good idea to use timestamp queries; I hadn't even heard of those.

The only thing that you could improve is to leave out noops; they don't provide any useful data (always something close to 1 us).

@jeffbolznv (Collaborator)

I didn't avoid inserting timestamps for the noops because it keeps the logic simple for finding the before/after timestamps. But I can probably use ggml_vk_is_empty before passing it to the logger.
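
A hypothetical shape of that filter; only ggml_vk_is_empty is named above, while the logger hook and the signatures are guesses for illustration:

```cpp
#include "ggml.h"

// Illustrative logger hook; the real code aggregates per-op timings.
static void vk_perf_log_node(const ggml_tensor * node, double us);

// Helper named in the comment above: true for nodes that record no GPU
// work (e.g. GGML_OP_NONE, views, permutes). Signature assumed here.
static bool ggml_vk_is_empty(ggml_tensor * node);

static void vk_perf_maybe_log(ggml_tensor * node, double us) {
    if (ggml_vk_is_empty(node)) {
        return; // noops always read back as something close to 1 us
    }
    vk_perf_log_node(node, us);
}
```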

@jeffbolznv (Collaborator)

OK, I've filtered them out in the latest commit. If we want to go ahead with the timestamp queries, I'll create a PR for it.

@0cc4m (Collaborator) commented May 26, 2025

It's definitely a smarter way of doing this, so if it's alright with @netrunnereve it should be preferred. Let's also wait for them to test it.

@netrunnereve (Collaborator, Author)

It works great now, so let me close this and @jeffbolznv can submit a new PR. I also agree that it's smarter to get the data from the driver without affecting how the ops are run.

BTW I find it strange that "x" got replaced with "in", as that implies that the time shown is the total time needed to complete all the ops. "10 in 100 us" implies that the entire block of 10 occurrences finished in 100 us, whereas "10 x 100 us" implies that each occurrence completed in 100 us.

@netrunnereve deleted the vulkan_perf branch on May 26, 2025 21:45
@jeffbolznv (Collaborator)

You're right, I thought it was a total, but it's an average. I'll revert that part.
