File tree: 3 files changed, +14 -16 lines changed
Original file line number | Diff line number | Diff line change
@@ -286,27 +286,21 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
286
286
for (uint32_t i = 0 ; i < n_kv; ++i) {
287
287
const uint32_t cell_id = i + kv_self->head ;
288
288
289
- // ////////////////////////////////////////////
290
- // TODO: this should not mutate the KV cache !
291
- llama_kv_cell & kv_cell = const_cast < class llama_kv_cache_unified *>(kv_self)-> cells [i] ;
289
+ const llama_kv_cell & kv_cell = kv_self-> cells [cell_id];
290
+
291
+ int32_t src = kv_cell. src0 ;
292
292
293
293
// prevent out-of-bound sources
294
- if (kv_cell. src < 0 ) {
294
+ if (src < 0 ) {
295
295
GGML_ASSERT (kv_self->rs_z >= 0 ); // Need a valid zero-ed cell as a source
296
- kv_cell. src = kv_self->rs_z ;
296
+ src = kv_self->rs_z ;
297
297
}
298
- if ((uint32_t ) kv_cell. src >= kv_self->size ) {
298
+ if ((uint32_t ) src >= kv_self->size ) {
299
299
// ignore out-of-bound sources
300
- kv_cell. src = cell_id;
300
+ src = cell_id;
301
301
}
302
302
303
- data[i] = kv_cell.src ;
304
-
305
- // TODO: do not mutate the KV cache
306
- // ensure copy only happens once
307
- if (kv_cell.src != (int32_t ) cell_id) {
308
- kv_cell.src = cell_id;
309
- }
303
+ data[i] = src;
310
304
}
311
305
}
312
306
}
Original file line number | Diff line number | Diff line change
@@ -665,10 +665,13 @@ bool llama_kv_cache_unified::find_slot(
665
665
// Find first to-be-cleared cell
666
666
rs_z = -1 ;
667
667
for (int i = min; i <= max; ++i) {
668
- if (cells[i].src == -1 ) {
668
+ if (rs_z < 0 && cells[i].src == -1 ) {
669
669
rs_z = i;
670
- break ;
671
670
}
671
+ // Stage the source ids for all used cells to allow correct seq_* behavior
672
+ // and still make these values available when setting the inputs
673
+ cells[i].src0 = cells[i].src ;
674
+ cells[i].src = i;
672
675
}
673
676
674
677
// allow getting the range of used cells, from head to head + n
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@ struct llama_kv_cell {
47
47
llama_pos pos = -1 ;
48
48
llama_pos delta = 0 ;
49
49
int32_t src = -1 ; // used by recurrent state models to copy states
50
+ int32_t src0 = -1 ; // like src, but used when setting the inputs (allowing to copy once)
50
51
int32_t tail = -1 ;
51
52
52
53
std::set<llama_seq_id> seq_id;
You can’t perform that action at this time.
0 commit comments