kv-cache : separate recurrent vs non-recurrent impl by ggerganov · Pull Request #12799 · ggml-org/llama.cpp · GitHub

kv-cache : separate recurrent vs non-recurrent impl #12799


Merged

merged 29 commits on May 2, 2025

Changes from 1 commit

Commits (29):
22bda48
kv-cache : serparate recurrent vs non-recurrent impl (wip)
ggerganov Apr 7, 2025
8145799
kv-cache : init -> contructor + add llama_memory_params
ggerganov Apr 15, 2025
49aa8b8
kv-cache : fix callback reference
ggerganov Apr 15, 2025
838b3cc
context : llama_kv_cache -> llama_memory_i
ggerganov Apr 17, 2025
8e4d3ba
context : move memory creation logic to model
ggerganov Apr 17, 2025
7fec081
llama : remove reference of memory during encode
ggerganov Apr 17, 2025
59af92b
kv-cache : hide padding details in the implementation
ggerganov Apr 23, 2025
6413b93
kv-cache : add ubatch_next()
ggerganov Apr 23, 2025
e869515
context : simplify sbatch logic
ggerganov Apr 23, 2025
ae2cd00
kv-cache : hide defrag logic in the implementation
ggerganov Apr 23, 2025
fdb7206
context : hide kv cache details in implementation
ggerganov Apr 23, 2025
13d69a5
build : fix
ggerganov Apr 23, 2025
5ef7559
cont : another fix
ggerganov Apr 23, 2025
6b50ba7
kv-cache : simplify interface (wip)
ggerganov Apr 24, 2025
cb02ac8
kv-cache : use separate KV cell structs for unified/recurrent
ggerganov Apr 24, 2025
f584750
kv-cache : clean-up
ggerganov Apr 24, 2025
458f2a5
model : better llama_model::create_model() signature
ggerganov Apr 24, 2025
92e626b
kv-cache : fix recurrent seq_rm()
ggerganov Apr 25, 2025
43cbf38
kv-cache : replace `struct callbacks` with `llama_model &`
ggerganov Apr 30, 2025
6619832
kv-cache : replace `struct graph_params` with `llama_context &`
ggerganov Apr 30, 2025
95a9f8b
kv-cache : fix offload check
ggerganov Apr 30, 2025
8737e65
context : avoid passing unique_ptr
ggerganov Apr 30, 2025
c9bddfc
kv-cache : avoid using the backends from the llama_context
ggerganov Apr 30, 2025
09195eb
kv-cache : more consistent debug logs [no ci]
ggerganov Apr 30, 2025
58e1d40
kv-cache : do not pass the full llama_context for kv graphs
ggerganov Apr 30, 2025
903e46f
kv-cache : remove comment
ggerganov May 2, 2025
00cde5f
kv-cache : ggml_rope_ext_inplace -> ggml_rope_ext
ggerganov May 2, 2025
7e79a42
kv-cache : fix recurrent multi-user case
ggerganov May 2, 2025
5883c90
memory : remove comments [no ci]
ggerganov May 2, 2025
The diff below shows a single commit from this PR:

commit 43cbf38bfe7c69086d1b942f992e6ba1f094f8a7 (ggerganov committed May 2, 2025)
kv-cache : replace struct callbacks with llama_model &
ggml-ci
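This commit is an interface refactor rather than new functionality: instead of receiving a small `struct callbacks` holding `std::function` hooks (`get_rope_factors`, `get_buft`), the KV-cache implementations now take a `const llama_model &` and query the model directly. That is also what allows `graph_params.arch` and the `/*.arch =*/` argument to `update()` to be dropped, since the cache can read `model.arch` itself. A minimal standalone sketch of the pattern (illustrative only; the toy names below are not llama.cpp code):

#include <cstdio>
#include <functional>

struct toy_model {
    int rope_factor(int il) const { return 2 * il; } // stand-in for get_rope_factors()
};

// before: the cache reached the model only through injected callbacks
struct cache_with_callbacks {
    struct callbacks {
        std::function<int(int)> get_rope_factor;
    };
    callbacks cbs;
    int shift(int il) const { return cbs.get_rope_factor(il); }
};

// after: the cache holds a const reference to the model and calls it directly
struct cache_with_model_ref {
    const toy_model & model;
    int shift(int il) const { return model.rope_factor(il); }
};

int main() {
    toy_model model;

    cache_with_callbacks a{ { [&model](int il) { return model.rope_factor(il); } } };
    cache_with_model_ref b{ model };

    std::printf("callbacks: %d, model ref: %d\n", a.shift(3), b.shift(3)); // both print 6
    return 0;
}

The trade-off is a tighter coupling between the cache and `llama_model`, in exchange for one less struct to thread through the constructors and no per-call `std::function` indirection.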
1 change: 0 additions & 1 deletion src/llama-context.cpp
@@ -440,7 +440,6 @@ void llama_context::kv_self_update() {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

need_reserve = kv_self->update({
- /*.arch =*/ model.arch,
/*.cparams =*/ cparams,
/*.sched =*/ sched.get(),
/*.backends =*/ backends,
55 changes: 38 additions & 17 deletions src/llama-kv-cache.cpp
@@ -22,13 +22,13 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
}

llama_kv_cache_unified::llama_kv_cache_unified(
- const llama_hparams & hparams,
- callbacks cbs,
- ggml_type type_k,
- ggml_type type_v,
- bool v_trans,
- uint32_t kv_size,
- uint32_t padding) : cbs(std::move(cbs)), hparams(hparams), v_trans(v_trans), padding(padding) {
+ const llama_model & model,
+ ggml_type type_k,
+ ggml_type type_v,
+ bool v_trans,
+ bool offload,
+ uint32_t kv_size,
+ uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
const int32_t n_layer = hparams.n_layer;

has_shift = false;
@@ -81,7 +81,18 @@ llama_kv_cache_unified::llama_kv_cache_unified(
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

- ggml_backend_buffer_type_t buft = this->cbs.get_buft(i);
+ const char * dev_name = "CPU";
+
+ ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
+ if (!offload) {
+     auto * dev = model.dev_layer(i);
+     buft = ggml_backend_dev_buffer_type(dev);
+
+     dev_name = ggml_backend_dev_name(dev);
+ }
+
+ LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", i, dev_name);

ggml_context * ctx = ctx_for_buft(buft);
if (!ctx) {
@@ -588,7 +599,6 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
float freq_base,
float freq_scale,
ggml_backend_buffer * bbuf) const {
- const auto & arch = params.arch;
const auto & cparams = params.cparams;
const auto & backends = params.backends;
const auto & sched = params.sched;
@@ -604,7 +614,7 @@

// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
- const float yarn_attn_factor = arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+ const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;

ggml_tensor * tmp;

@@ -697,7 +707,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;

- ggml_tensor * rope_factors = cbs.get_rope_factors(n_ctx_per_seq, il);
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

ggml_tensor * k =
ggml_view_3d(ctx, k_l[il],
@@ -1377,11 +1387,11 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
//

llama_kv_cache_recurrent::llama_kv_cache_recurrent(
- const llama_hparams & hparams,
- callbacks cbs,
- ggml_type type_k,
- ggml_type type_v,
- uint32_t kv_size) : cbs(std::move(cbs)), hparams(hparams) {
+ const llama_model & model,
+ ggml_type type_k,
+ ggml_type type_v,
+ bool offload,
+ uint32_t kv_size) : hparams(model.hparams) {
const int32_t n_layer = hparams.n_layer;

LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
@@ -1429,7 +1439,18 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

- ggml_backend_buffer_type_t buft = this->cbs.get_buft(i);
+ const char * dev_name = "CPU";
+
+ ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
+ if (!offload) {
+     auto * dev = model.dev_layer(i);
+     buft = ggml_backend_dev_buffer_type(dev);
+
+     dev_name = ggml_backend_dev_name(dev);
+ }
+
+ LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", i, dev_name);

ggml_context * ctx = ctx_for_buft(buft);
if (!ctx) {
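The two constructor hunks above also inline what the old `get_buft` callback used to do in llama-model.cpp: for each layer, choose either a host (CPU) buffer type or the layer's device buffer type depending on whether KV offload is enabled, and log the chosen device. A standalone sketch of that per-layer selection (illustrative only, hypothetical names); note that this intermediate commit writes the condition as `if (!offload)`, and a later commit in the PR, "kv-cache : fix offload check", revisits that check:

#include <cstdio>
#include <string>
#include <vector>

enum class buffer_type { host, device };

struct layer_device {
    std::string name; // e.g. "CUDA0"
};

int main() {
    const bool offload = true; // corresponds to cparams.offload_kqv
    const std::vector<layer_device> devices = { {"CUDA0"}, {"CUDA0"}, {"CUDA1"} };

    for (size_t il = 0; il < devices.size(); ++il) {
        // default to host memory, switch to the layer's device when offloading
        buffer_type  buft     = buffer_type::host;
        const char * dev_name = "CPU";

        if (offload) {
            buft     = buffer_type::device;
            dev_name = devices[il].name.c_str();
        }

        std::printf("layer %3zu: dev = %s (%s buffer)\n", il, dev_name,
                    buft == buffer_type::host ? "host" : "device");
    }
    return 0;
}

The same loop appears twice in the diff, once in `llama_kv_cache_unified` and once in `llama_kv_cache_recurrent`, since each cache now owns its buffer placement instead of asking the model through a callback.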
41 changes: 15 additions & 26 deletions src/llama-kv-cache.h
@@ -15,19 +15,10 @@ struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;
struct llama_sbatch;
+ struct llama_model;

struct llama_kv_cache : public llama_memory_i {
- // can be used to query data from the model if needed
- struct callbacks {
-     std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
-
-     // get the buffer type of layer il, can be used to offload KV cache layers to a different device
-     std::function<ggml_backend_buffer_type_t (int il)> get_buft;
- };
-
struct graph_params {
- const llm_arch arch;
-
const llama_cparams & cparams;

const ggml_backend_sched_t & sched;
@@ -139,13 +130,13 @@ class llama_kv_cache_unified : public llama_kv_cache {
static uint32_t get_padding(const llama_cparams & cparams);

llama_kv_cache_unified(
- const llama_hparams & hparams,
- callbacks cbs,
- ggml_type type_k,
- ggml_type type_v,
- bool v_trans,
- uint32_t kv_size,
- uint32_t padding);
+ const llama_model & model,
+ ggml_type type_k,
+ ggml_type type_v,
+ bool v_trans,
+ bool offload,
+ uint32_t kv_size,
+ uint32_t padding);

~llama_kv_cache_unified() = default;

@@ -208,14 +199,13 @@ class llama_kv_cache_unified : public llama_kv_cache {
// computed before each graph build
uint32_t n = 0;

- callbacks cbs;
-
std::vector<kv_cell> cells;

std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;

private:
+ const llama_model & model;
const llama_hparams & hparams;

bool has_shift = false;
@@ -312,11 +302,11 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
};

llama_kv_cache_recurrent(
- const llama_hparams & hparams,
- callbacks cbs,
- ggml_type type_k,
- ggml_type type_v,
- uint32_t kv_size);
+ const llama_model & model,
+ ggml_type type_k,
+ ggml_type type_v,
+ bool offload,
+ uint32_t kv_size);

~llama_kv_cache_recurrent() = default;

Expand Down Expand Up @@ -370,8 +360,6 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

- callbacks cbs;
-
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
@@ -388,6 +376,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
std::vector<ggml_tensor *> v_l;

private:
+ //const llama_model & model;
const llama_hparams & hparams;

// commit/restore cache
76 changes: 25 additions & 51 deletions src/llama-model.cpp
@@ -4445,6 +4445,19 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
return it->second;
}

+ ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
+     // choose long/short freq factors based on the context size
+     if (layers[il].rope_freqs != nullptr) {
+         return layers[il].rope_freqs;
+     }
+
+     if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+         return layers[il].rope_long;
+     }
+
+     return layers[il].rope_short;
+ }
+
struct llm_build_llama : public llm_graph_context {
llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -4485,7 +4498,7 @@ struct llm_build_llama : public llm_graph_context {
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -4710,7 +4723,7 @@ struct llm_build_deci : public llm_graph_context {
} else if (n_head > 0) {
// self-attention
// rope freq factors for llama3; may return nullptr for llama2 and other models
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -7192,7 +7205,7 @@ struct llm_build_phi3 : public llm_graph_context {
// self-attention
{
// rope freq factors for 128k context
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

ggml_tensor* attn_norm_output = build_norm(inpL,
model.layers[il].attn_norm,
@@ -7944,7 +7957,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;

- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

// norm
cur = build_norm(inpL,
@@ -9012,7 +9025,7 @@ struct llm_build_cohere2 : public llm_graph_context {
// self-attention
{
// rope freq factors for 128k context
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -9950,7 +9963,7 @@ struct llm_build_deepseek : public llm_graph_context {
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -11314,7 +11327,7 @@ struct llm_build_exaone : public llm_graph_context {
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12695,7 +12708,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
- ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+ ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12818,28 +12831,6 @@ struct llm_build_bailingmoe : public llm_graph_context {
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * res;

- const bool offload = cparams.offload_kqv;
-
- auto get_buft = [this, offload](int il) {
-     const char * dev_name = "CPU";
-
-     ggml_backend_buffer_type_t buft;
-     if (offload) {
-         auto * dev = dev_layer(il);
-         buft = ggml_backend_dev_buffer_type(dev);
-
-         dev_name = ggml_backend_dev_name(dev);
-     } else {
-         buft = ggml_backend_cpu_buffer_type();
-     }
-
-     LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", il, dev_name);
-
-     return buft;
- };
-
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

switch (arch) {
case LLM_ARCH_MAMBA:
case LLM_ARCH_RWKV6:
@@ -12848,13 +12839,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
case LLM_ARCH_ARWKV7:
{
res = new llama_kv_cache_recurrent(
- hparams,
- {
-     /*.get_rope_factors =*/ nullptr,
-     /*.get_buft =*/ get_buft,
- },
+ *this,
GGML_TYPE_F32,
GGML_TYPE_F32,
+ cparams.offload_kqv,
std::max((uint32_t) 1, cparams.n_seq_max));
} break;
default:
@@ -12866,25 +12854,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

res = new llama_kv_cache_unified(
- hparams,
- {
-     /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
-         // choose long/short freq factors based on the context size
-         if (layers[il].rope_freqs != nullptr) {
-             return layers[il].rope_freqs;
-         }
-
-         if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
-             return layers[il].rope_long;
-         }
-
-         return layers[il].rope_short;
-     },
-     /*.get_buft =*/ get_buft,
- },
+ *this,
params.type_k,
params.type_v,
!cparams.flash_attn,
+ cparams.offload_kqv,
cparams.n_ctx,
padding);
}
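With the callbacks gone, `llama_model::create_memory()` reduces to a plain factory: the recurrent architectures (MAMBA, RWKV6, RWKV6QWEN2, RWKV7, ARWKV7) get a `llama_kv_cache_recurrent`, everything else gets a padded `llama_kv_cache_unified`, and both constructors now receive `*this` plus the relevant `cparams` fields. A standalone sketch of that dispatch shape (illustrative only, hypothetical types):

#include <cstdio>
#include <memory>

enum class arch { llama, mamba, rwkv6, rwkv7 };

struct memory_i {
    virtual ~memory_i() = default;
    virtual const char * kind() const = 0;
};

struct kv_cache_unified : memory_i {
    const char * kind() const override { return "unified"; }
};

struct kv_cache_recurrent : memory_i {
    const char * kind() const override { return "recurrent"; }
};

// dispatch on the architecture, mirroring the switch in create_memory()
std::unique_ptr<memory_i> create_memory(arch a) {
    switch (a) {
        case arch::mamba:
        case arch::rwkv6:
        case arch::rwkv7:
            return std::make_unique<kv_cache_recurrent>(); // state-space / RNN-style models
        default:
            return std::make_unique<kv_cache_unified>();   // regular attention KV cache
    }
}

int main() {
    std::printf("llama -> %s\n", create_memory(arch::llama)->kind()); // unified
    std::printf("mamba -> %s\n", create_memory(arch::mamba)->kind()); // recurrent
    return 0;
}

Keeping this dispatch inside `llama_model` rather than in `llama_context` follows the earlier commit in this PR, "context : move memory creation logic to model".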
2 changes: 2 additions & 0 deletions src/llama-model.h
@@ -395,6 +395,8 @@ struct llama_model {

const struct ggml_tensor * get_tensor(const char * name) const;

+ ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;

// note: can mutate `cparams`
// TODO: move this to new llm_arch_model_i interface
llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;