kv-cache : adapt recurrent cache · ggml-org/llama.cpp@d23f887 · GitHub
[go: up one dir, main page]

Skip to content

Commit d23f887

Browse files
committed
kv-cache : adapt recurrent cache
ggml-ci
1 parent 051372c commit d23f887

File tree

4 files changed

+276
-366
lines changed

4 files changed

+276
-366
lines changed

src/llama-batch.cpp

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
#include <algorithm>
66

77
void llama_ubatch::update() {
8+
if (equal_seqs) {
9+
// TODO: for now don't compute min/max for recurrent batches since we don't need this.
10+
// the batches will be refactored anyway, so we'll fix this later
11+
return;
12+
}
13+
814
for (uint32_t i = 0; i < n_tokens; ++i) {
915
const llama_seq_id s = seq_id[i][0];
1016

@@ -24,26 +30,33 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
2430
break;
2531
}
2632
}
27-
ubatch_token.resize(!has_embd ? n_ubatch : 0);
28-
ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
29-
ubatch_pos.resize(n_ubatch);
30-
ubatch_n_seq_id.resize(n_ubatch);
31-
ubatch_seq_id.resize(n_ubatch);
32-
ubatch_output.resize(n_ubatch);
33+
34+
udatas.push_back({});
35+
36+
auto & udata = udatas.back();
37+
38+
udata.token.resize(!has_embd ? n_ubatch : 0);
39+
udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
40+
udata.pos.resize(n_ubatch);
41+
udata.n_seq_id.resize(n_ubatch);
42+
udata.seq_id.resize(n_ubatch);
43+
udata.output.resize(n_ubatch);
44+
3345
llama_ubatch ubatch = {
3446
/*equal_seqs =*/ true,
3547
/*n_tokens =*/ 0,
3648
/*n_seq_tokens =*/ 0,
3749
/*n_seqs =*/ 0,
3850
/*seq_pos_min =*/ {-1},
3951
/*seq_pos_max =*/ {-1},
40-
/*token =*/ !has_embd ? ubatch_token.data() : nullptr,
41-
/*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
42-
/*pos =*/ ubatch_pos.data(),
43-
/*n_seq_id =*/ ubatch_n_seq_id.data(),
44-
/*seq_id =*/ ubatch_seq_id.data(),
45-
/*output =*/ ubatch_output.data(),
52+
/*token =*/ !has_embd ? udata.token.data() : nullptr,
53+
/*embd =*/ has_embd ? udata.embd.data() : nullptr,
54+
/*pos =*/ udata.pos.data(),
55+
/*n_seq_id =*/ udata.n_seq_id.data(),
56+
/*seq_id =*/ udata.seq_id.data(),
57+
/*output =*/ udata.output.data(),
4658
};
59+
4760
return ubatch;
4861
}
4962

src/llama-batch.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,18 @@ struct llama_sbatch {
5555

5656
const llama_batch * batch = nullptr;
5757

58-
// buffers for the ubatch
59-
std::vector<llama_token> ubatch_token;
60-
std::vector<float> ubatch_embd;
61-
std::vector<llama_pos> ubatch_pos;
62-
std::vector<int32_t> ubatch_n_seq_id;
63-
std::vector<llama_seq_id *> ubatch_seq_id;
64-
std::vector<int8_t> ubatch_output;
58+
// buffers for the ubatches
59+
// TODO: very hacky, this needs a complete rework
60+
struct ubatch_data {
61+
std::vector<llama_token> token;
62+
std::vector<float> embd;
63+
std::vector<llama_pos> pos;
64+
std::vector<int32_t> n_seq_id;
65+
std::vector<llama_seq_id *> seq_id;
66+
std::vector<int8_t> output;
67+
};
68+
69+
std::vector<ubatch_data> udatas;
6570

6671
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
6772

0 commit comments

Comments (0)