DeepSeek V2/V3 with `-mla` option by jukofyork · Pull Request #12725 · ggml-org/llama.cpp · GitHub

DeepSeek V2/V3 with -mla option #12725

Closed
Changes from 1 commit. All 24 commits in this PR:
- b4c169f: Initial commit with all but the MLA graph code done (jukofyork, Apr 2, 2025)
- 10207b4: Fixes (jukofyork, Apr 2, 2025)
- ea3c05b: Just make `uint32_t n_embd_k` and `uint32_t n_embd_v` (jukofyork, Apr 2, 2025)
- 1f604a7: First working version (jukofyork, Apr 2, 2025)
- 1de077b: Fixed return bug in `DeepseekV2Model` (jukofyork, Apr 2, 2025)
- 7f92e7b: Minor fixes (jukofyork, Apr 2, 2025)
- 319e3ef: More fixes (jukofyork, Apr 2, 2025)
- ee4b389: Renamed `wv_b` to `wv_decompress` to avoid confusion with `_b` biases (jukofyork, Apr 2, 2025)
- c00cd9e: Better `_compressed` variable names (jukofyork, Apr 2, 2025)
- 55ad3a7: Minor comment and variable name fixes (jukofyork, Apr 2, 2025)
- 0c86f56: Moved `build_attn_mla` to better location (jukofyork, Apr 2, 2025)
- b0c8a43: Removed `gguf.MODEL_TENSOR.ATTN_K_B` from `prepare_tensors()` for now (jukofyork, Apr 2, 2025)
- 8c329bc: Bumped `wkv_b` and `wk_b` to use F32 (jukofyork, Apr 2, 2025)
- 68302ee: Use `ggml_mul_mat_set_prec` `GGML_PREC_F32` by default for now (jukofyork, Apr 2, 2025)
- 937a48d: Better/shorter variable names and more tidying up of code (jukofyork, Apr 2, 2025)
- 1fd0aab: Fixed `kv_cmpr_pe` name (jukofyork, Apr 2, 2025)
- 4fb439f: Added `n_embd_head_k` as constant (jukofyork, Apr 2, 2025)
- f9a0ef4: Fixed to use `build_attn_mha()` now (jukofyork, Apr 3, 2025)
- 5fe402a: `mla_attn` on then not `flash_attn` so we can run `-fa` for draft models (jukofyork, Apr 3, 2025)
- 9b862f9: "flash_attn is not compatible with mla_attn" --> flash_attn off (jukofyork, Apr 3, 2025)
- 8e23e0d: Fixed subtle bug caused by `-mla` for speculative models (jukofyork, Apr 3, 2025)
- b384086: Removed need for `v_b_proj` storing. Tidied all ggml_row_size for quants (jukofyork, Apr 4, 2025)
- 5dbf99c: Removed both calls to `ggml_mul_mat_set_prec` for MLA and non-MLA cases (jukofyork, Apr 4, 2025)
- f0d514a: Merge branch 'ggml-org:master' into mainline-llama-cpp-master--mla (jukofyork, Apr 5, 2025)
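
The theme running through these commits is caching DeepSeek's compressed KV latent instead of full per-head K/V. As a rough, standalone illustration of why that matters (a sketch only; the head counts and dimensions below are assumptions based on the published DeepSeek-V2 configuration, not values read from this PR):

```cpp
// Back-of-the-envelope KV-cache size per token per layer: full per-head K/V
// versus the compressed MLA latent (plus the small RoPE'd key part).
// All dimensions are assumed DeepSeek-V2 values, not taken from this PR.
#include <cstdio>

int main() {
    const int n_head       = 128; // attention heads
    const int qk_nope_dim  = 128; // non-RoPE part of each key/query head
    const int qk_rope_dim  =  64; // RoPE part of the key/query
    const int v_head_dim   = 128; // value head size
    const int kv_lora_rank = 512; // compressed KV latent width

    // Conventional cache: full K and V for every head.
    const int full_kv = n_head * ((qk_nope_dim + qk_rope_dim) + v_head_dim);

    // MLA cache: one shared latent plus one shared RoPE'd key part.
    const int mla_kv = kv_lora_rank + qk_rope_dim;

    printf("full per-head K/V: %d elements per token per layer\n", full_kv);
    printf("MLA compressed   : %d elements per token per layer\n", mla_kv);
    printf("reduction        : ~%.0fx\n", (double) full_kv / mla_kv);
    return 0;
}
```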
Minor fixes
jukofyork committed Apr 2, 2025
commit 7f92e7b6c64dbcdab0808e8915fbcfab7c748a24
src/llama-graph.h (11 additions, 0 deletions)

@@ -523,6 +523,17 @@ struct llm_graph_context {
             float         kq_scale,
             int           il) const;
 
+    ggml_tensor * build_attn_mla(
+            llm_graph_input_attn_kv_unified * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wv_b,
+            ggml_tensor * wo,
+            ggml_tensor * q_cur, // [n_embd_k, n_tokens, n_head]
+            ggml_tensor * k_cur, // [n_embd_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_v, n_tokens]
+            float         kq_scale,
+            int           il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(
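
The shape comments in the new declaration are the interesting part: `q_cur` carries a head dimension while `k_cur` and `v_cur` do not, i.e. every head scores against the same cached (compressed) keys, much as in multi-query attention. The toy program below only illustrates that scoring pattern; it is not the PR's implementation, uses no ggml, and all sizes and values are invented for the example.

```cpp
// Toy illustration of attention scores when queries are per-head but the
// cached keys are shared by all heads, matching the shape convention in the
// build_attn_mla() comments (q: [n_embd_k, n_tokens, n_head], k: [n_embd_k, n_tokens]).
// Plain C++ with made-up sizes; not the PR's code.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int d      = 4; // per-head query width == shared key width ("n_embd_k")
    const int n_tok  = 3; // cached positions
    const int n_head = 2; // query heads

    // One query vector per head (a single new token), q[h][i].
    std::vector<std::vector<float>> q(n_head, std::vector<float>(d, 1.0f));
    q[1][0] = -1.0f; // make the two heads differ

    // One shared key per cached position, k[t][i], reused by every head.
    std::vector<std::vector<float>> k(n_tok, std::vector<float>(d, 0.5f));
    k[2][3] = 2.0f;

    const float scale = 1.0f / std::sqrt((float) d);

    for (int h = 0; h < n_head; ++h) {
        for (int t = 0; t < n_tok; ++t) {
            float s = 0.0f;
            for (int i = 0; i < d; ++i) {
                s += q[h][i] * k[t][i]; // same k[t] for every head
            }
            printf("head %d, pos %d: score = %.3f\n", h, t, s * scale);
        }
    }
    return 0;
}
```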
src/llama-kv-cache.cpp (3 additions, 1 deletion)

@@ -11,6 +11,8 @@
 #include <map>
 #include <stdexcept>
 
+#include <inttypes.h>
+
 llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }
 
@@ -95,7 +97,7 @@ bool llama_kv_cache_unified::init(
         buft = ggml_backend_cpu_buffer_type();
     }
 
-    LLAMA_LOG_DEBUG("%s: layer %3ld: n_embd_k = %ld, n_embd_v = %d, dev = %s\n", __func__,
+    LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k = %" PRId64 ", n_embd_v = %" PRId64 ", dev = %s\n", __func__,
            i, n_embd_k, n_embd_v, dev_name);
 
     ggml_context * ctx = ctx_for_buft(buft);
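
The second hunk is a portability fix for the format string: the corrected line prints the sizes with the `<inttypes.h>` `PRId64` macros so the format specifier always matches a 64-bit argument, where `%ld` is only correct on platforms whose `long` happens to be 64 bits. A small standalone illustration of those macros (names and values here are examples only, not taken from the PR):

```cpp
// Minimal demonstration of the PRId64 format macro used by the fixed log line.
// PRId64 expands to the right length modifier for int64_t on each platform
// ("ld" on typical 64-bit Linux, "lld" on Windows), so the format string and
// the argument types always agree.
#include <cinttypes>
#include <cstdint>
#include <cstdio>

int main() {
    int64_t n_embd_k = 576; // example values only
    int64_t n_embd_v = 512;
    printf("n_embd_k = %" PRId64 ", n_embd_v = %" PRId64 "\n", n_embd_k, n_embd_v);
    return 0;
}
```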
src/llama-model.cpp (1 addition, 1 deletion)

@@ -9524,7 +9524,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_mla();
+        auto * inp_attn = llm_graph_input_attn_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;