FlashMLA-3 for DeepSeek models on CUDA #386
Conversation
The issue was that I did not change the number of warps used for 3D matrix multiplications (wk_b * kv_cache, MoE), so we ended up using 4 warps for TG. By going to 1 warp in these cases, we get a significant boost in TG performance (tested with DeepSeek-Lite)
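A minimal sketch of the heuristic being described, with hypothetical names (this is not the actual ik_llama.cpp code): for the batched (3D) matrix multiplications, when the activations have a single column (TG), spreading each dot product over 4 warps mostly adds reduction/synchronization overhead, so 1 warp per output row is enough; for PP, with many columns, the extra warps pay off.

```cpp
#include <cstdint>

// Hypothetical helper, not the real ik_llama.cpp function: choose the number of
// warps for a batched (3D) mul_mat launch based on the number of activation
// columns (ne11). TG has a single column per matrix in the batch, PP has many.
static int choose_warps_for_mul_mat(int64_t n_cols) {
    return n_cols == 1 ? 1 : 4;   // 1 warp for TG, keep 4 warps for PP
}

// Illustrative launch configuration (names are made up for the sketch):
//   const int n_warps = choose_warps_for_mul_mat(src1_ne1);
//   dim3 block(32 * n_warps, 1, 1);      // 32 threads per warp
//   dim3 grid(n_dst_rows, 1, n_batch);   // one z-slice per matrix in the batch
```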
Will
If all attention calculations are done on the Ampere cards, it could work.
MLA 3 has faster sweep-bench speeds for me, but unfortunately DeepSeek 2.5 goes aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa. MLA 2 works.
I gave this a very quick try, though the model doesn't fit in VRAM+RAM, so it pulls almost 6 GB/s paging off a Gen5 PCIe NVMe drive. This is a 3090 Ti FE 24 GB VRAM GPU. tl;dr: something seems off with the response when using -mla 3 (details below).
PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
---|---|---|---|---|---|---|
512 | 128 | 0 | 25.314 | 20.23 | 30.434 | 4.21 |
512 | 128 | 512 | 29.349 | 17.45 | 13.466 | 9.51 |
512 | 128 | 1024 | 29.555 | 17.32 | 13.596 | 9.41 |
512 | 128 | 1536 | 29.643 | 17.27 | 13.747 | 9.31 |
512 | 128 | 2048 | 29.202 | 17.53 | 13.819 | 9.26 |
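For reference, the throughput columns are simply the token counts divided by the elapsed time of that phase; e.g. for the first row:

```latex
S_{PP} = \frac{512}{25.314\,\mathrm{s}} \approx 20.23\ \mathrm{t/s}, \qquad
S_{TG} = \frac{128}{30.434\,\mathrm{s}} \approx 4.21\ \mathrm{t/s}
```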
>>> User:
Count from 1 to 10 in French.
>>> Assistant:
Here’s how to count from 1 to 10 in French:
1. **Un** (uhn)
2. **Deux** (du^C
-mla 3
PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
---|---|---|---|---|---|---|
512 | 128 | 0 | 26.291 | 19.47 | 32.849 | 3.90 |
512 | 128 | 512 | 29.949 | 17.10 | 13.085 | 9.78 |
512 | 128 | 1024 | 30.523 | 16.77 | 13.026 | 9.83 |
512 | 128 | 1536 | 29.763 | 17.20 | 13.095 | 9.77 |
512 | 128 | 2048 | 30.382 | 16.85 | 13.171 | 9.72 |
>>> User:
Count from 1 to 10 in French.
>>> Assistant:
Here are the numbers from 1 to 10 in Please see the 1, 2, 3, 2, 8, 7, 3, 6, 1, 8, 3,
We can see that the string is 8, 8, 2, , 0, 0, 0, 0, 8, 7, 1, 1, 0,0, 0, ^C
OK, thanks for testing. Here is what I get with DeepSeek-Lite for @ubergarm's quants:
The difference is that Lite has 16 heads, while the big models have 128. So I guess something is not quite right with more than 16 heads.
To be honest, I don't understand the failure.
Recap
Observations
So, based on these observations, when we use the 192,128 CUDA kernel for PP and the 576,512 CUDA kernel for TG, it should be working. But it doesn't.
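To make the recap concrete, here is a simplified sketch of the dispatch being described (hypothetical names, not the actual code): with -mla 3 -fa, PP goes through the standard DeepSeek head sizes (K = 192, V = 128), while TG attends directly over the compressed KV cache, where the K head size is kv_lora_rank + rope dims = 512 + 64 = 576 and the V head size is 512.

```cpp
// Illustrative only: which FA head-size combination each phase is expected to
// use for DeepSeek with -mla 3 -fa. Names are hypothetical.
struct AttnHeadSizes { int k; int v; };

static AttnHeadSizes expected_mla_fa_kernel(int n_tokens_in_ubatch) {
    if (n_tokens_in_ubatch > 1) {
        return {192, 128};   // PP: qk_nope (128) + qk_rope (64) = 192, v = 128
    }
    return {576, 512};       // TG: kv_lora_rank (512) + qk_rope (64) = 576, v = kv_lora_rank = 512
}
```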
How many heads does 2.5 have? Maybe there is some difference. It's easier to run and more like Qwen in size. I will have to check the MLA 1 output; it could be a bug in FA. I also had some crash in MLA 2 after using it a while, but haven't reproduced it yet.
Should be the same as V3 with 128 heads.
I know. But the DeepSeek models are the only MLA models around. I have verified it works with DeepSeek-Lite (16B parameters, 16 heads). The next step up with MLA are the giant DeepSeek models that I cannot run myself.
Looks like my theory was correct. On my system MLA 1 also produces issues, probably as soon as FA kicks in. It may start out coherent for the first bit of tokens and then descends intooooooooooooooooooosddkkkkkkkkasd
I can test on ik_llama.cpp in some hours if I can replicate it on DeepSeek V3 0324 (I'm not home right now). On main llama.cpp I tested up to 64K CTX and it was working fine with the PR. If I understand correctly, I have to use the latest quants and then use -mla 3 -fa? Main llama.cpp uses the equivalent of -mla 2 -fa?
The mainline
I have no V3/R1 quants yet, so it may very well work there but not in 2.5.
Then the conclusion would be that I introduced a bug when porting the mainline PR adding the 576,512 kernel. But it does work with DeepSeek-Lite, so I'm not sure how to debug.
I can probably download DeepSeek-Lite since it's small. This is what I'm running: https://huggingface.co/bartowski/DeepSeek-V2.5-1210-GGUF/tree/main/DeepSeek-V2.5-1210-IQ4_XS
That would work as a test.
I just tried to load DeepSeek V3 Q2_K_XL, but I get an issue on the latest commit. This happens with both -mla 2 -fa and -mla 3 -fa. Not sure if I'm setting a parameter wrongly.
@Panchovix You are using a GGUF made for mainline llama.cpp MLA. As I wrote above, you need PR #394, which is an attempt to fix the incompatibility.
@ikawrakow ah, I'm dumb, thanks! Haha, gonna try the PR.
OK.. baby DeepSeek V2.0-chat, the ~16B one, right? Sort of inconclusive results. MLA 1: oooooooooooooooooooooooooo. MLA 2/3 + FP16 cache do not exhibit too many issues in a quick test. These quants are months and months old, so I'm not sure if anything is wrong with them; I also used IQ4_XS. Re-ran V2.5 without 8-bit cache and MLA 3 works now.
I just tested this model, which is near the maximum size I can go. It seems to work perfectly fine with -mla 3 -fa:
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 4080, compute capability 8.9, VMM: yes
Log start
main: build = 3673 (4084ca73)
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: seed = 1234
llama_model_loader: additional 2 GGUFs metadata loaded.
llama_model_loader: loaded meta data with 53 key-value pairs and 959 tensors from ./ds2.5/DeepSeek-V2.5-1210-IQ3_XXS-00001-of-00003.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = deepseek2
llama_model_loader: - kv 1: general.type str = model
llama_model_loader: - kv 2: general.name str = DeepSeek V2.5 1210
llama_model_loader: - kv 3: general.version str = V2.5-1210
llama_model_loader: - kv 4: general.basename str = DeepSeek
llama_model_loader: - kv 5: general.size_label str = 160x14B
llama_model_loader: - kv 6: general.license str = other
llama_model_loader: - kv 7: general.license.name str = deepseek
llama_model_loader: - kv 8: general.license.link str = https://github.com/deepseek-ai/DeepSe...
llama_model_loader: - kv 9: deepseek2.block_count u32 = 60
llama_model_loader: - kv 10: deepseek2.context_length u32 = 163840
llama_model_loader: - kv 11: deepseek2.embedding_length u32 = 5120
llama_model_loader: - kv 12: deepseek2.feed_forward_length u32 = 12288
llama_model_loader: - kv 13: deepseek2.attention.head_count u32 = 128
llama_model_loader: - kv 14: deepseek2.attention.head_count_kv u32 = 128
llama_model_loader: - kv 15: deepseek2.rope.freq_base f32 = 10000,000000
llama_model_loader: - kv 16: deepseek2.attention.layer_norm_rms_epsilon f32 = 0,000001
llama_model_loader: - kv 17: deepseek2.expert_used_count u32 = 6
llama_model_loader: - kv 18: general.file_type u32 = 23
llama_model_loader: - kv 19: deepseek2.leading_dense_block_count u32 = 1
llama_model_loader: - kv 20: deepseek2.vocab_size u32 = 102400
llama_model_loader: - kv 21: deepseek2.attention.q_lora_rank u32 = 1536
llama_model_loader: - kv 22: deepseek2.attention.kv_lora_rank u32 = 512
llama_model_loader: - kv 23: deepseek2.attention.key_length u32 = 192
llama_model_loader: - kv 24: deepseek2.attention.value_length u32 = 128
llama_model_loader: - kv 25: deepseek2.expert_feed_forward_length u32 = 1536
llama_model_loader: - kv 26: deepseek2.expert_count u32 = 160
llama_model_loader: - kv 27: deepseek2.expert_shared_count u32 = 2
llama_model_loader: - kv 28: deepseek2.expert_weights_scale f32 = 16,000000
llama_model_loader: - kv 29: deepseek2.rope.dimension_count u32 = 64
llama_model_loader: - kv 30: deepseek2.rope.scaling.type str = yarn
llama_model_loader: - kv 31: deepseek2.rope.scaling.factor f32 = 40,000000
llama_model_loader: - kv 32: deepseek2.rope.scaling.original_context_length u32 = 4096
llama_model_loader: - kv 33: deepseek2.rope.scaling.yarn_log_multiplier f32 = 0,100000
llama_model_loader: - kv 34: tokenizer.ggml.model str = gpt2
llama_model_loader: - kv 35: tokenizer.ggml.pre str = deepseek-llm
llama_model_loader: - kv 36: tokenizer.ggml.tokens arr[str,102400] = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv 37: tokenizer.ggml.token_type arr[i32,102400] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv 38: tokenizer.ggml.merges arr[str,99757] = ["Ġ Ġ", "Ġ t", "Ġ a", "i n", "h e...
llama_model_loader: - kv 39: tokenizer.ggml.bos_token_id u32 = 100000
llama_model_loader: - kv 40: tokenizer.ggml.eos_token_id u32 = 100001
llama_model_loader: - kv 41: tokenizer.ggml.padding_token_id u32 = 100001
llama_model_loader: - kv 42: tokenizer.ggml.add_bos_token bool = true
llama_model_loader: - kv 43: tokenizer.ggml.add_eos_token bool = false
llama_model_loader: - kv 44: tokenizer.chat_template str = {% if not add_generation_prompt is de...
llama_model_loader: - kv 45: general.quantization_version u32 = 2
llama_model_loader: - kv 46: quantize.imatrix.file str = /models_out/DeepSeek-V2.5-1210-GGUF/D...
llama_model_loader: - kv 47: quantize.imatrix.dataset str = /training_dir/calibration_datav3.txt
llama_model_loader: - kv 48: quantize.imatrix.entries_count i32 = 716
llama_model_loader: - kv 49: quantize.imatrix.chunks_count i32 = 139
llama_model_loader: - kv 50: split.no u16 = 0
llama_model_loader: - kv 51: split.count u16 = 3
llama_model_loader: - kv 52: split.tensors.count i32 = 959
llama_model_loader: - type f32: 300 tensors
llama_model_loader: - type q5_K: 1 tensors
llama_model_loader: - type iq3_xxs: 597 tensors
llama_model_loader: - type iq3_s: 61 tensors
llm_load_vocab: special tokens cache size = 18
llm_load_vocab: token to piece cache size = 0,6411 MB
llm_load_print_meta: format = GGUF V3 (latest)
llm_load_print_meta: arch = deepseek2
llm_load_print_meta: vocab type = BPE
llm_load_print_meta: n_vocab = 102400
llm_load_print_meta: n_merges = 99757
llm_load_print_meta: vocab_only = 0
llm_load_print_meta: n_ctx_train = 163840
llm_load_print_meta: n_embd = 5120
llm_load_print_meta: n_layer = 60
llm_load_print_meta: n_head = 128
llm_load_print_meta: n_head_kv = 128
llm_load_print_meta: n_rot = 64
llm_load_print_meta: n_swa = 0
llm_load_print_meta: n_swa_pattern = 1
llm_load_print_meta: n_embd_head_k = 192
llm_load_print_meta: n_embd_head_v = 128
llm_load_print_meta: n_gqa = 1
llm_load_print_meta: n_embd_k_gqa = 24576
llm_load_print_meta: n_embd_v_gqa = 16384
llm_load_print_meta: f_norm_eps = 0,0e+00
llm_load_print_meta: f_norm_rms_eps = 1,0e-06
llm_load_print_meta: f_clamp_kqv = 0,0e+00
llm_load_print_meta: f_max_alibi_bias = 0,0e+00
llm_load_print_meta: f_logit_scale = 0,0e+00
llm_load_print_meta: n_ff = 12288
llm_load_print_meta: n_expert = 160
llm_load_print_meta: n_expert_used = 6
llm_load_print_meta: causal attn = 1
llm_load_print_meta: pooling type = 0
llm_load_print_meta: rope type = 0
llm_load_print_meta: rope scaling = yarn
llm_load_print_meta: freq_base_train = 10000,0
llm_load_print_meta: freq_scale_train = 0,025
llm_load_print_meta: n_ctx_orig_yarn = 4096
llm_load_print_meta: rope_finetuned = unknown
llm_load_print_meta: ssm_d_conv = 0
llm_load_print_meta: ssm_d_inner = 0
llm_load_print_meta: ssm_d_state = 0
llm_load_print_meta: ssm_dt_rank = 0
llm_load_print_meta: model type = 236B
llm_load_print_meta: model ftype = IQ3_XXS - 3.0625 bpw
llm_load_print_meta: model params = 235,741 B
llm_load_print_meta: model size = 84,604 GiB (3,083 BPW)
llm_load_print_meta: repeating layers = 84,058 GiB (3,077 BPW, 234,693 B parameters)
llm_load_print_meta: general.name = DeepSeek V2.5 1210
llm_load_print_meta: BOS token = 100000 '<|begin▁of▁sentence|>'
llm_load_print_meta: EOS token = 100001 '<|end▁of▁sentence|>'
llm_load_print_meta: PAD token = 100001 '<|end▁of▁sentence|>'
llm_load_print_meta: LF token = 126 'Ä'
llm_load_print_meta: max token length = 256
llm_load_print_meta: n_layer_dense_lead = 1
llm_load_print_meta: n_lora_q = 1536
llm_load_print_meta: n_lora_kv = 512
llm_load_print_meta: n_ff_exp = 1536
llm_load_print_meta: n_expert_shared = 2
llm_load_print_meta: expert_weights_scale = 16,0
llm_load_print_meta: expert_weights_norm = 0
llm_load_print_meta: expert_gating_func = softmax
llm_load_print_meta: rope_yarn_log_mul = 0,1000
llm_load_tensors: ggml ctx size = 0,80 MiB
Tensor blk.1.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.1.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.1.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.2.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.2.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.2.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.3.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.3.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.3.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.4.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.4.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.4.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.5.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.5.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.5.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.6.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.6.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.6.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.7.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.7.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.7.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.8.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.8.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.8.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.9.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.9.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.9.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.10.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.10.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.10.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.11.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.11.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.11.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.12.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.12.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.12.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.13.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.13.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.13.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.14.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.14.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.14.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.15.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.15.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.15.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.16.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.16.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.16.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.17.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.17.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.17.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.18.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.18.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.18.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.19.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.19.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.19.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.20.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.20.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.20.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.21.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.21.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.21.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.22.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.22.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.22.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.23.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.23.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.23.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.24.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.24.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.24.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.25.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.25.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.25.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.26.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.26.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.26.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.27.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.27.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.27.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.28.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.28.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.28.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.29.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.29.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.29.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.30.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.30.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.30.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.31.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.31.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.31.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.32.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.32.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.32.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.33.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.33.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.33.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.34.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.34.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.34.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.35.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.35.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.35.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.36.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.36.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.36.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.37.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.37.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.37.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.38.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.38.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.38.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.39.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.39.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.39.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.40.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.40.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.40.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.41.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.41.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.41.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.42.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.42.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.42.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.43.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.43.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.43.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.44.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.44.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.44.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.45.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.45.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.45.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.46.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.46.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.46.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.47.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.47.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.47.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.48.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.48.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.48.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.49.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.49.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.49.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.50.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.50.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.50.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.51.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.51.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.51.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.52.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.52.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.52.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.53.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.53.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.53.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.54.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.54.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.54.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.55.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.55.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.55.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.56.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.56.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.56.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.57.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.57.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.57.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.58.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.58.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.58.ffn_up_exps.weight buffer type overriden to CPU
Tensor blk.59.ffn_gate_exps.weight buffer type overriden to CPU
Tensor blk.59.ffn_down_exps.weight buffer type overriden to CPU
Tensor blk.59.ffn_up_exps.weight buffer type overriden to CPU
llm_load_tensors: offloading 60 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 61/61 layers to GPU
llm_load_tensors: CPU buffer size = 37343,30 MiB
llm_load_tensors: CPU buffer size = 37866,68 MiB
llm_load_tensors: CPU buffer size = 10656,64 MiB
llm_load_tensors: CPU buffer size = 214,84 MiB
llm_load_tensors: CUDA0 buffer size = 5109,97 MiB
....................................................................................................
============ llm_load_tensors: need to compute 60 wk_b tensors
Computed blk.0.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.1.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.2.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.3.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.4.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.5.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.6.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.7.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.8.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.9.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.10.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.11.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.12.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.13.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.14.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.15.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.16.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.17.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.18.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.19.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.20.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.21.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.22.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.23.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.24.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.25.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.26.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.27.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.28.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.29.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.30.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.31.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.32.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.33.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.34.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.35.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.36.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.37.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.38.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.39.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.40.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.41.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.42.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.43.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.44.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.45.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.46.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.47.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.48.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.49.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.50.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.51.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.52.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.53.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.54.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.55.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.56.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.57.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.58.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
Computed blk.59.attn_v_b.weight as 128 x 512 x 128 and stored in buffer CUDA0
llama_new_context_with_model: n_ctx = 32768
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: mla_attn = 3
llama_new_context_with_model: attn_max_b = 0
llama_new_context_with_model: fused_moe = 0
llama_new_context_with_model: ser = -1, 0
llama_new_context_with_model: freq_base = 10000,0
llama_new_context_with_model: freq_scale = 0,025
llama_kv_cache_init: layer 0: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 1: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 2: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 3: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 4: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 5: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 6: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 7: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 8: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 9: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 10: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 11: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 12: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 13: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 14: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 15: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 16: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 17: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 18: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 19: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 20: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 21: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 22: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 23: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 24: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 25: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 26: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 27: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 28: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 29: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 30: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 31: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 32: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 33: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 34: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 35: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 36: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 37: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 38: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 39: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 40: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 41: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 42: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 43: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 44: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 45: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 46: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 47: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 48: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 49: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 50: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 51: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 52: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 53: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 54: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 55: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 56: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 57: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 58: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: layer 59: n_embd_head_qk_rope = 64, kv_lora_rank = 512
llama_kv_cache_init: CUDA0 KV buffer size = 2160,00 MiB
llama_new_context_with_model: KV self size = 2160,00 MiB, c^KV (f16): 2160,00 MiB, kv^T: not used
llama_new_context_with_model: CUDA_Host output buffer size = 0,39 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 6346,00 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 74,01 MiB
llama_new_context_with_model: graph nodes = 3290
llama_new_context_with_model: graph splits = 179
main: chat template example: You are a helpful assistant
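As a sanity check of the logged KV buffer size: with mla_attn = 3 only the compressed latent (c^KV plus the RoPE part) is cached, i.e. kv_lora_rank + rope dims = 512 + 64 = 576 values per token per layer, stored as f16. That reproduces the reported 2160 MiB exactly:

```latex
\underbrace{60}_{\text{layers}} \times \underbrace{32768}_{n_{\text{ctx}}} \times \underbrace{576}_{512+64} \times \underbrace{2\,\text{B}}_{\text{f16}}
= 2{,}264{,}924{,}160\ \text{B} = 2160\ \text{MiB}
```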
OK, PR #400 should fix the quantized KV cache.
Yes, this seems to work in my quick testing of the big DeepSeek-R1-IQ2_K_R4 hybrid CPU+GPU setup on my local rig for both
However, I noticed for both
Thanks for working through all the combinations!
Thanks for testing. I'm not sure if the
I've never gotten
SER should hopefully work correctly now, see PR #404
I just tried out PR #404, which is now
Also got a segfault when hitting
Full command logs
It could be that this model is too small: all attention layers are Q8_0 for the GPU, and for the CPU ffn_down is IQ3_K_R4 and ffn_(gate|up) are IQ2_K_R4. Still works okay without
EDIT
DeepSeek 2.5 seems to work with Q8 cache; TG/PP is slightly faster than with F16. Unfortunately, sometimes a GPU gets stuck at 100% in Task Manager and the bench or server halts and then sits. GPU power draw is not consistent with 100% usage, of course. It could be due to my undervolting or something else? The F16 sweep completes successfully and is definitely "heavier" on resources, so I'm not sure anymore.
OK, thanks. The PR fixes things for me, but it seems there is still a bug lurking somewhere. I'll keep looking.
There have been reports about problems with FA in mainline as well. As I took the DeepSeek implementation from there, I guess
Now that you mention it, that's the kind of error I'd get with llama-server. It would eventually fail and segfault there with synchronization listed as the fault. I assumed it was due to me undervolting. Setting a lower max GPU clock along with the clock offset (the only way to do it on Linux) caused it to happen less often. Perhaps it's only coincidental. I haven't yet tested EXL2 tensor parallel, and that is a much higher GPU load on the same settings. If it dumps on me again, I'll try to grab the error.
This PR in mainline `llama.cpp` is a CUDA flash attention (FA) implementation that also works with a K head size of 576 and a V head size of 512, as required for DeepSeek models with MLA enabled. Caveat: it only works on Ampere or newer Nvidia GPUs. I have taken it and adapted it to the `ik_llama.cpp` environment, but I only use it for the 576,512 case (for other head sizes it is slower than the existing implementation). This finally allows `-mla 3 -fa`, which is the recommended option for DeepSeek models when running on the CPU, to also work on CUDA.

This results in massive DeepSeek TG performance gains for long contexts when self attention is computed on the GPU. The graph below shows an example for a `Q4_0`-quantized DeepSeek-Lite model that I can fully load on my RTX-4080 GPU with 16 GB VRAM. I have used u-batches of 4096 with `sweep-bench` to more quickly cover the context range of up to 65k tokens. The main branch results use `-mla 2 -fa`, the PR uses `-mla 3 -fa`. No FA (which is what happens for TG when `mla = 2`) is slightly faster at zero context. I considered special-casing `N_KV <= 256`, but then decided against it: fewer than 256 tokens in the KV cache has no relevance in actual usage, other than bragging about the great performance one got on Reddit and elsewhere after saying 'Hello' to the LLM and checking the performance stats. At 60k tokens the PR is 3X faster than the main branch for TG!

I'm not including a comparison to the mainline PR as it has not been merged, so things can change (but for the curious, mainline TG is slightly faster for small `N_KV` and slower for large `N_KV`, while PP continues to be far behind `ik_llama.cpp`).

Main branch, mla = 2, fa

PR, mla = 3, fa

The PR also adds a tweak to matrix-vector multiplications that leads to minor TG performance gains for MoE models other than DeepSeek. As an example, the next graph shows TG performance for `IQ2_XS`-quantized Qwen3-30B-A3B (so it fully loads on my 16 GB GPU) using `-fmoe -fa -ub 2048`.

Testing with DeepSeek-V3/R1 will be greatly appreciated. Very few can run these models fully offloaded to the GPU, but I do expect non-negligible performance gains at long context also for hybrid GPU/CPU inference (where self attention is computed on the GPU). Checking that it works correctly is of course most important.
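For readers who just want the gist of the change, here is a simplified sketch of the selection logic described above (hypothetical names; the real implementation lives in the CUDA FA code): the ported kernel is used only for the 576,512 MLA head-size combination and only on Ampere or newer GPUs, while all other head sizes keep the existing implementation.

```cpp
// Simplified sketch, not the actual code: gate the new FlashMLA kernel on the
// DeepSeek MLA TG shape (K head size 576, V head size 512) and on compute
// capability >= 8.0 (Ampere or newer); everything else falls back to the
// existing FA path.
static bool use_flash_mla_576_512(int head_size_k, int head_size_v, int compute_capability) {
    const bool mla_shape       = (head_size_k == 576 && head_size_v == 512);
    const bool ampere_or_newer = (compute_capability >= 800);   // e.g. 890 for the RTX 4080 above
    return mla_shape && ampere_or_newer;
}
```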