Large performance drop when using pipeline parallelism and layer splitting on multiple GPUs #13751 · ggml-org/llama.cpp

Closed
matbrez opened this issue May 24, 2025 · 7 comments · Fixed by #13814

@matbrez commented May 24, 2025

Problem description

The default value for GGML_SCHED_MAX_COPIES is 4. With that value, -sm layer performs significantly worse than -sm none.
Setting GGML_SCHED_MAX_COPIES to 1 brings -sm layer performance up to the level of -sm none and does not otherwise appear to hurt performance in this use case.
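
For reference, the two builds compared below can be produced roughly as follows (a minimal sketch; the build directory names are illustrative, the CMake options are the ones discussed above):

Default build (GGML_SCHED_MAX_COPIES=4, pipeline parallelism enabled):

> cmake -B build -DGGML_CUDA=ON
> cmake --build build --config Release

Build with a single scheduler copy (pipeline parallelism effectively disabled):

> cmake -B build-1 -DGGML_CUDA=ON -DGGML_SCHED_MAX_COPIES=1
> cmake --build build-1 --config Release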

Benchmarks

CMake options: -DGGML_CUDA=ON (default GGML_SCHED_MAX_COPIES=4). Here -sm layer performs significantly worse than -sm none:

> llama-bench.exe -m Qwen3-32B-Q4_K_M.gguf -ngl 99 -sm layer 
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
  Device 1: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |           pp512 |       3306.84 ± 3.18 |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |           tg128 |         40.06 ± 0.10 |

build: b775345d (5470)
> llama-bench.exe -m Qwen3-32B-Q4_K_M.gguf -ngl 99 -sm none
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
  Device 1: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl |    sm |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | --------------: | -------------------: |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |  none |           pp512 |      3315.12 ± 14.63 |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |  none |           tg128 |         61.38 ± 0.13 |

build: b775345d (5470)

CMake options: -DGGML_SCHED_MAX_COPIES=1 -DGGML_CUDA=ON. Here -sm layer performs the same as -sm none:

> llama-bench.exe -m Qwen3-32B-Q4_K_M.gguf -ngl 99 -sm layer
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
  Device 1: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |           pp512 |       3314.12 ± 8.22 |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |           tg128 |         60.46 ± 0.23 |

build: b775345d (5470)
> llama-bench.exe -m Qwen3-32B-Q4_K_M.gguf -ngl 99 -sm none
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
  Device 1: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl |    sm |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | --------------: | -------------------: |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |  none |           pp512 |      3328.29 ± 10.94 |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |  none |           tg128 |         61.30 ± 0.07 |

build: b775345d (5470)

Additional information

Using --override-tensors also appears to disable pipeline parallelism, even in builds with -DGGML_SCHED_MAX_COPIES=4: when running with -v, the line llama_context: pipeline parallelism enabled (n_copies=4) is not printed while --override-tensors is in use.

> llama-bench.exe -m Qwen3-32B-Q4_K_M.gguf -ngl 99 -sm layer -ot "blk\.(0|1|2|3|4|5|6|7|8|9|10|11|12|13|14|15|16|17|18|19|20|21|22|23|24|25|26|27|28|29|30|31|32)\.=CUDA0;blk\.(33|34|35|36|37|38|39|40|41|42|43|44|45|46|48|49|50|51|52|53|54|55|56|57|58|59|60|61|62|63)\.=CUDA1"
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
  Device 1: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | ot                    |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------------- | --------------: | -------------------: |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 | blk\.(0|1|2|3|4|5|6|7|8|9|10|11|12|13|14|15|16|17|18|19|20|21|22|23|24|25|26|27|28|29|30|31|32)\.=CUDA0;blk\.(33|34|35|36|37|38|39|40|41|42|43|44|45|46|48|49|50|51|52|53|54|55|56|57|58|59|60|61|62|63)\.=CUDA1 |           pp512 |      3311.54 ± 12.45 |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 | blk\.(0|1|2|3|4|5|6|7|8|9|10|11|12|13|14|15|16|17|18|19|20|21|22|23|24|25|26|27|28|29|30|31|32)\.=CUDA0;blk\.(33|34|35|36|37|38|39|40|41|42|43|44|45|46|48|49|50|51|52|53|54|55|56|57|58|59|60|61|62|63)\.=CUDA1 |           tg128 |         60.80 ± 0.07 |

build: b775345d (5470)
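
To check whether pipeline parallelism is active for a given combination of flags, the verbose log can be searched for the line quoted above (an illustrative one-liner, not from the original runs; llama-cli is assumed here, and the context log is printed right after the model loads, so the run can be interrupted afterwards):

> llama-cli.exe -m Qwen3-32B-Q4_K_M.gguf -ngl 99 -sm layer -v 2>&1 | findstr /C:"pipeline parallelism"

With the default build this prints llama_context: pipeline parallelism enabled (n_copies=4); adding the --override-tensors argument makes the line disappear.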

The model used is https://huggingface.co/Qwen/Qwen3-32B-GGUF/blob/main/Qwen3-32B-Q4_K_M.gguf but other model architectures also exhibit the same behaviour. I tested qwen3, qwen3moe, llama, and gemma3.

Disabling pipeline parallelism also improves performance for models that don't fit on a single GPU in the first place. For example, https://huggingface.co/Qwen/Qwen3-235B-A22B-GGUF/tree/main/Q4_K_M goes from 25 t/s to 60 t/s.

All tests were done on Windows with CUDA Toolkit 12.9.

> llama-cli.exe --version
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
  Device 1: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
version: 5470 (b775345d)
built with MSVC 19.42.34435.0 for x64
@quasar-of-mikus commented May 24, 2025

Same here. build: cf0a43b (5361), default value for GGML_SCHED_MAX_COPIES.
-sm layer

| model                          |       size |     params | backend    | ngl | ts           |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------ | --------------: | -------------------: |
| llama 13B Q8_0                 |  12.12 GiB |    12.25 B | CUDA       |  99 | 20.00/20.00  |           pp512 |      3178.42 ± 10.08 |
| llama 13B Q8_0                 |  12.12 GiB |    12.25 B | CUDA       |  99 | 20.00/20.00  |           tg128 |         51.48 ± 0.02 |

-sm none

| model                          |       size |     params | backend    | ngl |    sm | ts           |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | ------------ | --------------: | -------------------: |
| llama 13B Q8_0                 |  12.12 GiB |    12.25 B | CUDA       |  99 |  none | 20.00/20.00  |           pp512 |      3190.36 ± 10.45 |
| llama 13B Q8_0                 |  12.12 GiB |    12.25 B | CUDA       |  99 |  none | 20.00/20.00  |           tg128 |         59.04 ± 0.04 |

-sm layer -ot blk\.1\.=CUDA0

| model                          |       size |     params | backend    | ngl | ts           | ot                    |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------ | --------------------- | --------------: | -------------------: |
| llama 13B Q8_0                 |  12.12 GiB |    12.25 B | CUDA       |  99 | 20.00/20.00  | blk\.1\.=CUDA0        |           pp512 |       3178.07 ± 9.22 |
| llama 13B Q8_0                 |  12.12 GiB |    12.25 B | CUDA       |  99 | 20.00/20.00  | blk\.1\.=CUDA0        |           tg128 |         58.26 ± 0.09 |

-sm none -ot blk\.1\.=CUDA0

| model                          |       size |     params | backend    | ngl |    sm | ts           | ot                    |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | ------------ | --------------------- | --------------: | -------------------: |
| llama 13B Q8_0                 |  12.12 GiB |    12.25 B | CUDA       |  99 |  none | 20.00/20.00  | blk\.1\.=CUDA0        |           pp512 |      3195.58 ± 17.70 |
| llama 13B Q8_0                 |  12.12 GiB |    12.25 B | CUDA       |  99 |  none | 20.00/20.00  | blk\.1\.=CUDA0        |           tg128 |         59.04 ± 0.03 |
The llama-bench invocation used (shown here for the -sm none -ot case):

@echo off
set CUDA_VISIBLE_DEVICES=0,1
llama-bench.exe ^
-m "T:\models\mistral-nemo-storywriter-12b-240918-Q8_0.gguf" ^
-ts 20/20 ^
-ngl 99 ^
-sm none ^
-ot blk\.1\.=CUDA0

@slaren (Member) commented May 26, 2025

Likely a Windows issue. Try disabling "Hardware-accelerated GPU scheduling" under graphics settings.

@matbrez (Author) commented May 26, 2025

Disabling "Hardware-accelerated GPU scheduling" made no difference.

The same issue exists on Linux, although the difference is smaller.

Without -DGGML_SCHED_MAX_COPIES=1

$ ./build-4/bin/llama-bench -m Qwen3-32B-Q4_K_M.gguf -sm layer
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
  Device 1: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |           pp512 |       3447.32 ± 9.47 |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |           tg128 |         52.97 ± 0.15 |

build: cdf94a18 (5501)
$ ./build-4/bin/llama-bench -m Qwen3-32B-Q4_K_M.gguf -sm none
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
  Device 1: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl |    sm |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | --------------: | -------------------: |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |  none |           pp512 |      3404.88 ± 11.57 |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |  none |           tg128 |         63.90 ± 0.04 |

build: cdf94a18 (5501)

With -DGGML_SCHED_MAX_COPIES=1

$ ./build-1/bin/llama-bench -m Qwen3-32B-Q4_K_M.gguf -sm layer
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
  Device 1: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |           pp512 |       3434.38 ± 5.77 |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |           tg128 |         63.66 ± 0.03 |

build: cdf94a18 (5501)
$ ./build-1/bin/llama-bench -m Qwen3-32B-Q4_K_M.gguf -sm none
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
  Device 1: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl |    sm |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | --------------: | -------------------: |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |  none |           pp512 |      3383.85 ± 15.83 |
| qwen3 32B Q4_K - Medium        |  18.40 GiB |    32.76 B | CUDA       |  99 |  none |           tg128 |         63.91 ± 0.08 |

build: cdf94a18 (5501)

@slaren (Member) commented May 26, 2025

Are you using native Linux or WSL?

@matbrez (Author) commented May 26, 2025

Native.

$ uname -a
Linux u 6.11.0-26-generic #26~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Apr 17 19:20:47 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
$ nvidia-smi
Tue May 27 03:09:02 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.51.03              Driver Version: 575.51.03      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX PRO 6000 Blac...    Off |   00000000:01:00.0  On |                  Off |
| 30%   35C    P8             16W /  600W |     245MiB /  97887MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX PRO 6000 Blac...    Off |   00000000:03:00.0 Off |                  Off |
| 30%   31C    P8             15W /  600W |      15MiB /  97887MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

@slaren (Member) commented May 27, 2025

#13814 should fix this.

@matbrez (Author) commented May 27, 2025

I can confirm that #13814 fixes the issue. I am getting expected performance in the layer + pp case on that branch. Thanks.
