kv-cache : remove const_cast when setting inputs for s_copy · ggml-org/llama.cpp@94c3d53 · GitHub
[go: up one dir, main page]

Skip to content

Commit 94c3d53

Browse files
committed
kv-cache : remove const_cast when setting inputs for s_copy
Also fix multi-user inference for recurrent models by using cell_id instead of i as the KV cell index when populating s_copy.
1 parent 791998b commit 94c3d53

File tree

3 files changed

+14
-16
lines changed

3 files changed

+14
-16
lines changed

src/llama-graph.cpp

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -286,27 +286,21 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
286286
for (uint32_t i = 0; i < n_kv; ++i) {
287287
const uint32_t cell_id = i + kv_self->head;
288288

289-
//////////////////////////////////////////////
290-
// TODO: this should not mutate the KV cache !
291-
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
289+
const llama_kv_cell & kv_cell = kv_self->cells[cell_id];
290+
291+
int32_t src = kv_cell.src0;
292292

293293
// prevent out-of-bound sources
294-
if (kv_cell.src < 0) {
294+
if (src < 0) {
295295
GGML_ASSERT(kv_self->rs_z >= 0); // Need a valid zero-ed cell as a source
296-
kv_cell.src = kv_self->rs_z;
296+
src = kv_self->rs_z;
297297
}
298-
if ((uint32_t) kv_cell.src >= kv_self->size) {
298+
if ((uint32_t) src >= kv_self->size) {
299299
// ignore out-of-bound sources
300-
kv_cell.src = cell_id;
300+
src = cell_id;
301301
}
302302

303-
data[i] = kv_cell.src;
304-
305-
// TODO: do not mutate the KV cache
306-
// ensure copy only happens once
307-
if (kv_cell.src != (int32_t) cell_id) {
308-
kv_cell.src = cell_id;
309-
}
303+
data[i] = src;
310304
}
311305
}
312306
}

src/llama-kv-cache.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -665,10 +665,13 @@ bool llama_kv_cache_unified::find_slot(
665665
// Find first to-be-cleared cell
666666
rs_z = -1;
667667
for (int i = min; i <= max; ++i) {
668-
if (cells[i].src == -1) {
668+
if (rs_z < 0 && cells[i].src == -1) {
669669
rs_z = i;
670-
break;
671670
}
671+
// Stage the source ids for all used cells to allow correct seq_* behavior
672+
// and still make these values available when setting the inputs
673+
cells[i].src0 = cells[i].src;
674+
cells[i].src = i;
672675
}
673676

674677
// allow getting the range of used cells, from head to head + n

src/llama-kv-cache.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ struct llama_kv_cell {
4747
llama_pos pos = -1;
4848
llama_pos delta = 0;
4949
int32_t src = -1; // used by recurrent state models to copy states
50+
int32_t src0 = -1; // like src, but used when setting the inputs (allowing to copy once)
5051
int32_t tail = -1;
5152

5253
std::set<llama_seq_id> seq_id;

0 commit comments

Comments
 (0)
0