kv-cache : separate recurrent vs non-recurrent impl by ggerganov · Pull Request #12799 · ggml-org/llama.cpp · GitHub

kv-cache : separate recurrent vs non-recurrent impl #12799


Merged · 29 commits · May 2, 2025
Commits (29) — changes shown below are from 1 commit:
22bda48
kv-cache : separate recurrent vs non-recurrent impl (wip)
ggerganov Apr 7, 2025
8145799
kv-cache : init -> constructor + add llama_memory_params
ggerganov Apr 15, 2025
49aa8b8
kv-cache : fix callback reference
ggerganov Apr 15, 2025
838b3cc
context : llama_kv_cache -> llama_memory_i
ggerganov Apr 17, 2025
8e4d3ba
context : move memory creation logic to model
ggerganov Apr 17, 2025
7fec081
llama : remove reference of memory during encode
ggerganov Apr 17, 2025
59af92b
kv-cache : hide padding details in the implementation
ggerganov Apr 23, 2025
6413b93
kv-cache : add ubatch_next()
ggerganov Apr 23, 2025
e869515
context : simplify sbatch logic
ggerganov Apr 23, 2025
ae2cd00
kv-cache : hide defrag logic in the implementation
ggerganov Apr 23, 2025
fdb7206
context : hide kv cache details in implementation
ggerganov Apr 23, 2025
13d69a5
build : fix
ggerganov Apr 23, 2025
5ef7559
cont : another fix
ggerganov Apr 23, 2025
6b50ba7
kv-cache : simplify interface (wip)
ggerganov Apr 24, 2025
cb02ac8
kv-cache : use separate KV cell structs for unified/recurrent
ggerganov Apr 24, 2025
f584750
kv-cache : clean-up
ggerganov Apr 24, 2025
458f2a5
model : better llama_model::create_model() signature
ggerganov Apr 24, 2025
92e626b
kv-cache : fix recurrent seq_rm()
ggerganov Apr 25, 2025
43cbf38
kv-cache : replace `struct callbacks` with `llama_model &`
ggerganov Apr 30, 2025
6619832
kv-cache : replace `struct graph_params` with `llama_context &`
ggerganov Apr 30, 2025
95a9f8b
kv-cache : fix offload check
ggerganov Apr 30, 2025
8737e65
context : avoid passing unique_ptr
ggerganov Apr 30, 2025
c9bddfc
kv-cache : avoid using the backends from the llama_context
ggerganov Apr 30, 2025
09195eb
kv-cache : more consistent debug logs [no ci]
ggerganov Apr 30, 2025
58e1d40
kv-cache : do not pass the full llama_context for kv graphs
ggerganov Apr 30, 2025
903e46f
kv-cache : remove comment
ggerganov May 2, 2025
00cde5f
kv-cache : ggml_rope_ext_inplace -> ggml_rope_ext
ggerganov May 2, 2025
7e79a42
kv-cache : fix recurrent multi-user case
ggerganov May 2, 2025
5883c90
memory : remove comments [no ci]
ggerganov May 2, 2025
kv-cache : use separate KV cell structs for unified/recurrent
ggml-ci
ggerganov committed May 2, 2025
commit cb02ac80861dbf04df8943d86b2984e43117a42f
src/llama-kv-cache.cpp — 78 changes: 36 additions & 42 deletions
@@ -152,8 +152,6 @@ void llama_kv_cache_unified::clear() {
    for (int32_t i = 0; i < (int32_t) size; ++i) {
        cells[i].pos = -1;
        cells[i].seq_id.clear();
-       cells[i].src = -1;
-       cells[i].tail = -1;
    }
    head = 0;
    used = 0;
@@ -190,7 +188,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
        }

        cells[i].pos = -1;
-       cells[i].src = -1;

        if (new_head == size) {
            new_head = i;
@@ -245,7 +242,6 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
        }

        cells[i].pos = -1;
-       cells[i].src = -1;
        cells[i].seq_id.clear();

        if (new_head == size){
@@ -380,7 +376,6 @@ void llama_kv_cache_unified::restore() {
            }

            cells[i].pos = -1;
-           cells[i].src = -1;
        }

        new_head = std::min(new_head, range.c0);
@@ -847,7 +842,7 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {

uint32_t llama_kv_cache_unified::cell_max() const {
    for (uint32_t i = size; i > 0; --i) {
-       const llama_kv_cell & cell = cells[i - 1];
+       const kv_cell & cell = cells[i - 1];

        if (cell.pos >= 0 && !cell.is_empty()) {
            return i;
@@ -983,7 +978,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
        cells[i0 + nf] = cell1;

        // clear the old cell and move the head there
-       cell1 = llama_kv_cell();
+       cell1 = kv_cell();
        head = n_used;

        if (!cont) {
@@ -1226,7 +1221,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
    clear();

    for (uint32_t i = 0; i < cell_count; ++i) {
-       llama_kv_cell & cell = cells[i];
+       kv_cell & cell = cells[i];

        llama_pos pos;
        uint32_t n_seq_id;
@@ -1538,7 +1533,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
    if (0 <= seq_id) {
        int32_t & tail_id = cells[seq_id].tail;
        if (tail_id >= 0) {
-           const llama_kv_cell & cell = cells[tail_id];
+           const kv_cell & cell = cells[tail_id];
            // partial intersection is invalid
            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
                return false;
@@ -1572,23 +1567,22 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
    }

    if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
-       llama_kv_cell & tail_src = cells[seq_id_src];
-       llama_kv_cell & tail_dst = cells[seq_id_dst];
+       kv_cell & tail_src = cells[seq_id_src];
+       kv_cell & tail_dst = cells[seq_id_dst];
        if (tail_dst.tail >= 0) {
            // clear destination seq_id if it wasn't empty
-           llama_kv_cell & cell_dst = cells[tail_dst.tail];
+           kv_cell & cell_dst = cells[tail_dst.tail];

            cell_dst.seq_id.erase(seq_id_dst);
            tail_dst.tail = -1;
            if (cell_dst.seq_id.empty()) {
                cell_dst.pos = -1;
-               cell_dst.delta = -1;
                cell_dst.src = -1;
                used -= 1;
            }
        }
        if (tail_src.tail >= 0) {
-           llama_kv_cell & cell_src = cells[tail_src.tail];
+           kv_cell & cell_src = cells[tail_src.tail];

            cell_src.seq_id.insert(seq_id_dst);
            tail_dst.tail = tail_src.tail;
Expand Down Expand Up @@ -1650,7 +1644,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
if (0 <= seq_id && seq_id < (int64_t) size) {
const int32_t tail_id = cells[seq_id].tail;
if (tail_id >= 0) {
llama_kv_cell & cell = cells[tail_id];
kv_cell & cell = cells[tail_id];
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
cell.pos += delta;
}
Expand Down Expand Up @@ -1680,7 +1674,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
if (0 <= seq_id && seq_id < (int64_t) size) {
const int32_t tail_id = cells[seq_id].tail;
if (tail_id >= 0) {
llama_kv_cell & cell = cells[tail_id];
kv_cell & cell = cells[tail_id];
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
cell.pos /= d;
}
Expand Down Expand Up @@ -1731,19 +1725,19 @@ int32_t llama_kv_cache_recurrent::s_copy(int i) const {

//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<llama_kv_cell &>(cells[i]);
kv_cell & cell = const_cast<kv_cell &>(cells[i]);

// prevent out-of-bound sources
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= size) {
kv_cell.src = cell_id;
if (cell.src < 0 || (uint32_t) cell.src >= size) {
cell.src = cell_id;
}

int32_t res = kv_cell.src;
int32_t res = cell.src;

// TODO: do not mutate the KV cache
// ensure copy only happens once
if (kv_cell.src != (int32_t) cell_id) {
kv_cell.src = cell_id;
if (cell.src != (int32_t) cell_id) {
cell.src = cell_id;
}

return res;
@@ -1754,13 +1748,13 @@ float llama_kv_cache_recurrent::s_mask(int i) const {

    //////////////////////////////////////////////
    // TODO: this should not mutate the KV cache !
-   llama_kv_cell & kv_cell = const_cast<llama_kv_cell &>(cells[i]);
+   kv_cell & cell = const_cast<kv_cell &>(cells[i]);

-   float res = (float) (kv_cell.src >= 0);
+   float res = (float) (cell.src >= 0);

    // only clear once
-   if (kv_cell.src < 0) {
-       kv_cell.src = cell_id;
+   if (cell.src < 0) {
+       cell.src = cell_id;
    }

    return res;
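
Aside: s_copy() and s_mask() feed the graph code that shuffles recurrent states between cells. A minimal sketch of a hypothetical consumer, assuming a recurrent cache kv with n active cells — the variable names and loop are illustrative, not part of this diff:

    // s_copy(i) returns the index of the cell whose state cell head+i should
    // start from (after the first call it returns the cell's own id, so the
    // copy happens only once); s_mask(i) is 0.0f while the cell holds no
    // valid state yet, letting the graph zero it instead of reading stale data.
    std::vector<int32_t> copy_ids(n);
    std::vector<float>   masks(n);
    for (uint32_t i = 0; i < n; ++i) {
        copy_ids[i] = kv.s_copy(i);
        masks[i]    = kv.s_mask(i);
    }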
@@ -1802,9 +1796,9 @@ bool llama_kv_cache_recurrent::find_slot(
                return false;
            }
            if (j > 0) {
-               llama_kv_cell & seq = cells[seq_id];
+               kv_cell & seq = cells[seq_id];
                if (seq.tail >= 0) {
-                   llama_kv_cell & cell = cells[seq.tail];
+                   kv_cell & cell = cells[seq.tail];
                    // clear cells from seq_ids that become shared
                    // (should not normally happen, but let's handle it anyway)
                    cell.seq_id.erase(seq_id);
@@ -1824,7 +1818,7 @@ bool llama_kv_cache_recurrent::find_slot(
        std::vector<int32_t> tails_verif;
        tails_verif.assign(size, -1);
        for (uint32_t i = 0; i < size; ++i) {
-           llama_kv_cell & cell = cells[i];
+           kv_cell & cell = cells[i];
            for (llama_seq_id seq_id : cell.seq_id) {
                if (tails_verif[seq_id] != -1) {
                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
@@ -1845,28 +1839,28 @@ bool llama_kv_cache_recurrent::find_slot(

        for (uint32_t i = 0; i < size; ++i) {
            if (next_empty_cell >= size) { next_empty_cell -= size; }
-           llama_kv_cell & cell = cells[next_empty_cell];
+           kv_cell & cell = cells[next_empty_cell];
            if (cell.is_empty()) { break; }
            next_empty_cell += 1;
        }

        // find usable cell range
        for (uint32_t s = 0; s < n_seqs; ++s) {
            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-           llama_kv_cell & seq_meta = cells[seq_id];
+           kv_cell & seq_meta = cells[seq_id];
            bool has_cell = false;
            if (seq_meta.tail >= 0) {
-               llama_kv_cell & cell = cells[seq_meta.tail];
+               kv_cell & cell = cells[seq_meta.tail];
                GGML_ASSERT(cell.has_seq_id(seq_id));
                // does this seq_id "own" the cell?
                if (cell.seq_id.size() == 1) { has_cell = true; }
            }
            if (!has_cell) {
-               llama_kv_cell & empty_cell = cells[next_empty_cell];
+               kv_cell & empty_cell = cells[next_empty_cell];
                GGML_ASSERT(empty_cell.is_empty());
                // copy old tail into the empty cell
                if (seq_meta.tail >= 0) {
-                   llama_kv_cell & orig_cell = cells[seq_meta.tail];
+                   kv_cell & orig_cell = cells[seq_meta.tail];
                    empty_cell.pos = orig_cell.pos;
                    empty_cell.src = orig_cell.src;
                    orig_cell.seq_id.erase(seq_id);
@@ -1878,7 +1872,7 @@ bool llama_kv_cache_recurrent::find_slot(
                    next_empty_cell += 1;
                    for (uint32_t i = 0; i < size; ++i) {
                        if (next_empty_cell >= size) { next_empty_cell -= size; }
-                       llama_kv_cell & cell = cells[next_empty_cell];
+                       kv_cell & cell = cells[next_empty_cell];
                        if (cell.is_empty()) { break; }
                        next_empty_cell += 1;
                    }
@@ -1893,8 +1887,8 @@ bool llama_kv_cache_recurrent::find_slot(
            int32_t dst_id = s + min;
            int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
            if (dst_id != src_id) {
-               llama_kv_cell & dst_cell = cells[dst_id];
-               llama_kv_cell & src_cell = cells[src_id];
+               kv_cell & dst_cell = cells[dst_id];
+               kv_cell & src_cell = cells[src_id];

                std::swap(dst_cell.pos, src_cell.pos);
                std::swap(dst_cell.src, src_cell.src);
@@ -1914,7 +1908,7 @@ bool llama_kv_cache_recurrent::find_slot(
        for (uint32_t s = 0; s < n_seqs; ++s) {
            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
            int32_t cell_id = s + min;
-           llama_kv_cell & cell = cells[cell_id];
+           kv_cell & cell = cells[cell_id];

            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                // What should happen when the pos backtracks or skips a value?
@@ -1935,7 +1929,7 @@ bool llama_kv_cache_recurrent::find_slot(
        head = min;
        n = max - min + 1;
        used = std::count_if(cells.begin(), cells.end(),
-           [](const llama_kv_cell& cell){ return !cell.is_empty(); });
+           [](const kv_cell & cell){ return !cell.is_empty(); });

        // sanity check
        return n >= n_seqs;
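
Aside: a sketch of the invariant a successful find_slot() establishes, derived from the swap loop above (each sequence's state cell is moved to index s + min, and head is set to min). The check below is illustrative, not part of this diff:

    // the ubatch's sequences end up in the contiguous range [head, head + n),
    // one state cell per sequence, so graph code can address them as head + s
    for (uint32_t s = 0; s < n_seqs; ++s) {
        GGML_ASSERT(cells[head + s].has_seq_id(ubatch.seq_id[s][0]));
    }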
@@ -1958,7 +1952,7 @@ llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32

uint32_t llama_kv_cache_recurrent::cell_max() const {
    for (uint32_t i = size; i > 0; --i) {
-       const llama_kv_cell & cell = cells[i - 1];
+       const kv_cell & cell = cells[i - 1];

        if (cell.pos >= 0 && !cell.is_empty()) {
            return i;
@@ -2200,7 +2194,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
    clear();

    for (uint32_t i = 0; i < cell_count; ++i) {
-       llama_kv_cell & cell = cells[i];
+       kv_cell & cell = cells[i];

        llama_pos pos;
        uint32_t n_seq_id;
@@ -2412,7 +2406,7 @@ void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache
        view->cells_sequences = (llama_seq_id *)p;
    }

-   const std::vector<llama_kv_cell> & kv_cells = kvu->cells;
+   const std::vector<llama_kv_cache_unified::kv_cell> & kv_cells = kvu->cells;
    llama_kv_cache_view_cell * c_curr = view->cells;
    llama_seq_id * cs_curr = view->cells_sequences;
    int32_t used_cells = 0;
src/llama-kv-cache.h — 66 changes: 41 additions & 25 deletions
@@ -111,29 +111,6 @@ struct llama_kv_cache_guard {
    llama_kv_cache * kv;
};

-// TODO: create separate cells for unified/recurrent caches
-// TODO: move in the source file
-struct llama_kv_cell {
-    llama_pos pos = -1;
-    llama_pos delta = 0;
-    int32_t src = -1; // used by recurrent state models to copy states
-    int32_t tail = -1;
-
-    std::set<llama_seq_id> seq_id;
-
-    bool has_seq_id(const llama_seq_id & id) const {
-        return seq_id.find(id) != seq_id.end();
-    }
-
-    bool is_empty() const {
-        return seq_id.empty();
-    }
-
-    bool is_same_seq(const llama_kv_cell & other) const {
-        return seq_id == other.seq_id;
-    }
-};
-
//
// llama_kv_cache_unified
// ring-buffer of cached KV data
@@ -143,6 +120,25 @@ struct llama_kv_cell {
// TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache {
public:
+    struct kv_cell {
+        llama_pos pos = -1;
+        llama_pos delta = 0;
+
+        std::set<llama_seq_id> seq_id;
+
+        bool has_seq_id(const llama_seq_id & id) const {
+            return seq_id.find(id) != seq_id.end();
+        }
+
+        bool is_empty() const {
+            return seq_id.empty();
+        }
+
+        bool is_same_seq(const kv_cell & other) const {
+            return seq_id == other.seq_id;
+        }
+    };
+
    llama_kv_cache_unified(
            const llama_hparams & hparams,
            callbacks cbs,
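
Aside: a minimal usage sketch of the new nested cell type (illustrative only; delta is assumed here to accumulate position shifts applied via seq_add()):

    llama_kv_cache_unified::kv_cell cell;
    cell.pos = 5;          // this slot caches the KV data for token position 5
    cell.seq_id.insert(0); // ... shared by sequences 0 and 1
    cell.seq_id.insert(1);
    GGML_ASSERT( cell.has_seq_id(0));
    GGML_ASSERT(!cell.is_empty()); // empty == no sequence references the slot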
@@ -251,7 +247,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
    // required padding
    uint32_t padding = 1;

-   std::vector<llama_kv_cell> cells;
+   std::vector<kv_cell> cells;

    std::vector<ggml_tensor *> k_l; // per layer
    std::vector<ggml_tensor *> v_l;
@@ -294,6 +290,26 @@ class llama_kv_cache_unified : public llama_kv_cache {

class llama_kv_cache_recurrent : public llama_kv_cache {
public:
+    struct kv_cell {
+        llama_pos pos = -1;
+        int32_t src = -1; // used by recurrent state models to copy states
+        int32_t tail = -1;
+
+        std::set<llama_seq_id> seq_id;
+
+        bool has_seq_id(const llama_seq_id & id) const {
+            return seq_id.find(id) != seq_id.end();
+        }
+
+        bool is_empty() const {
+            return seq_id.empty();
+        }
+
+        bool is_same_seq(const kv_cell & other) const {
+            return seq_id == other.seq_id;
+        }
+    };
+
    llama_kv_cache_recurrent(
            const llama_hparams & hparams,
            callbacks cbs,
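
Aside: the recurrent cell keeps src and tail but drops delta. A sketch of the tail convention, as exercised by seq_add() and find_slot() in the .cpp diff above (illustrative, not part of this diff):

    // cells[seq_id] doubles as per-sequence metadata: its tail field holds
    // the index of the cell that currently stores that sequence's recurrent
    // state (-1 if none), and that cell's src says where its state is copied
    // from on the next graph build.
    const kv_cell & meta = cells[seq_id];
    if (meta.tail >= 0) {
        const kv_cell & state = cells[meta.tail];
        GGML_ASSERT(state.has_seq_id(seq_id)); // same assertion find_slot() makes
    }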
@@ -384,7 +400,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
    // computed before each graph build
    uint32_t n = 0;

-   std::vector<llama_kv_cell> cells;
+   std::vector<kv_cell> cells;

    std::vector<ggml_tensor *> k_l; // per layer
    std::vector<ggml_tensor *> v_l;
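
Aside: because each kv_cell is now nested inside its cache class, code outside the class must qualify the type, as the llama_kv_cache_view_update() hunk above already does:

    const std::vector<llama_kv_cache_unified::kv_cell> & kv_cells = kvu->cells;

The payoff of the split is that each cache now stores only the fields it actually uses — the unified cell sheds the recurrent-only src/tail, the recurrent cell sheds delta — so the two implementations can evolve their bookkeeping independently.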